"""
============================================================================
Example 2: Text-to-Code Semantic Search
============================================================================
AISE501 – AI in Software Engineering I
Fachhochschule Graubünden

GOAL:
  Build a mini code search engine: given a natural language query like
  "sort a list", find the most relevant code snippet from a collection.
  This is the core mechanism behind semantic code search in tools like
  Cursor, GitHub Copilot, and code search engines.

WHAT YOU WILL LEARN:
  - How the SAME embedding model maps both text and code into a shared
    vector space — this is what makes text-to-code search possible.
  - How to build a simple search index and query it.
  - Why embedding-based search beats keyword search for code.

HARDWARE:
  Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""

# PyTorch supplies tensor math and device handling; Hugging Face
# transformers supplies the pretrained embedding model + tokenizer.
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel
# ── Device selection ──────────────────────────────────────────────────────
def get_device():
    """Return the best available compute device: CUDA, then MPS, then CPU."""
    # Probe accelerators in preference order; fall through to CPU.
    candidates = (
        ("cuda", torch.cuda.is_available),
        ("mps", torch.backends.mps.is_available),
    )
    for device_name, is_available in candidates:
        if is_available():
            return torch.device(device_name)
    return torch.device("cpu")
# Resolved once at module import; every tensor below is created on or
# moved to this device.
DEVICE = get_device()
print(f"Using device: {DEVICE}\n")

# ── Load model ────────────────────────────────────────────────────────────
# NOTE(review): the checkpoint name suggests a sentence-transformers model
# trained for code search (text and code share one embedding space) —
# confirm against the model card on the Hugging Face Hub.
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
# Inference only: eval() disables dropout so embeddings are deterministic.
model.eval()
print("Model loaded.\n")
# ── Code "database" ──────────────────────────────────────────────────────
# Imagine these are functions in a large codebase that we want to search.
# Each entry pairs a display name with the raw source text that gets
# embedded; the snippets deliberately cover distinct tasks (search, sort,
# file I/O, networking, validation) so queries have one clear best match.
code_database = [
    {
        "name": "binary_search",
        "code": """
def binary_search(arr, target):
    low, high = 0, len(arr) - 1
    while low <= high:
        mid = (low + high) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            low = mid + 1
        else:
            high = mid - 1
    return -1
"""
    },
    {
        "name": "merge_sort",
        "code": """
def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)
"""
    },
    {
        "name": "read_json_file",
        "code": """
import json
def read_json_file(path):
    with open(path, 'r') as f:
        return json.load(f)
"""
    },
    {
        "name": "calculate_average",
        "code": """
def calculate_average(numbers):
    if not numbers:
        return 0.0
    return sum(numbers) / len(numbers)
"""
    },
    {
        "name": "connect_database",
        "code": """
import sqlite3
def connect_database(db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    return conn, cursor
"""
    },
    {
        "name": "send_http_request",
        "code": """
import requests
def send_http_request(url, method='GET', data=None):
    if method == 'GET':
        response = requests.get(url)
    else:
        response = requests.post(url, json=data)
    return response.json()
"""
    },
    {
        "name": "flatten_nested_list",
        "code": """
def flatten(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result
"""
    },
    {
        "name": "count_words",
        "code": """
def count_words(text):
    words = text.lower().split()
    word_count = {}
    for word in words:
        word_count[word] = word_count.get(word, 0) + 1
    return word_count
"""
    },
    {
        "name": "validate_email",
        "code": """
import re
def validate_email(email):
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'
    return bool(re.match(pattern, email))
"""
    },
    {
        "name": "fibonacci",
        "code": """
def fibonacci(n):
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b
"""
    },
]
def embed_text(text: str) -> torch.Tensor:
    """Embed a piece of text or code into a single L2-normalized vector.

    Mean-pools the model's last hidden states over non-padding tokens, so
    both natural-language queries and code snippets land in the same
    vector space and can be compared with a plain dot product.

    Args:
        text: A natural-language query or a source-code snippet.

    Returns:
        A 1-D tensor of shape [embedding_dim] with unit L2 norm.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)

    with torch.no_grad():
        outputs = model(**inputs)

    # Mean pooling over non-padding tokens. Cast the integer attention
    # mask to float so multiply/divide stay in floating point, and clamp
    # the token count so a degenerate all-padding input cannot divide by
    # zero (standard sentence-transformers pooling recipe).
    mask = inputs["attention_mask"].unsqueeze(-1).float()
    summed = (outputs.last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    embedding = summed / counts
    # Normalize so that downstream dot products equal cosine similarity.
    return F.normalize(embedding, p=2, dim=1).squeeze(0)
# ── Step 1: Index the code database ───────────────────────────────────────
# In a real system this would be stored in a vector database (ChromaDB,
# Pinecone, pgvector). Here we keep it simple with a list of tensors.
print("Indexing code database...")


def _embed_entry(entry):
    """Embed one snippet's source text and report indexing progress."""
    vector = embed_text(entry["code"])
    print(f"  Indexed: {entry['name']}")
    return vector


code_vectors = [_embed_entry(entry) for entry in code_database]

# Stack into a matrix: shape [num_snippets, embedding_dim]
code_matrix = torch.stack(code_vectors)
print(f"\nIndex built: {code_matrix.shape[0]} snippets, {code_matrix.shape[1]} dimensions\n")
# ── Step 2: Search with natural language queries ──────────────────────────
# None of these queries contains the target function's identifier, so a
# keyword search would struggle — the point of the demo.
queries = [
    "sort a list of numbers",
    "find an element in a sorted array",
    "compute the mean of a list",
    "make an HTTP API call",
    "open and read a JSON file",
    "check if an email address is valid",
    "count word frequencies in a string",
    "generate fibonacci numbers",
    "connect to a SQL database",
    "flatten a nested list into a single list",
]
print("=" * 70)
print("SEMANTIC CODE SEARCH RESULTS")
print("=" * 70)

for query in queries:
    # Embed the natural language query with the SAME model that indexed
    # the code — the shared vector space is what makes this search work.
    query_vec = embed_text(query)

    # Compute cosine similarity against all code embeddings.
    # Because vectors are normalized, dot product = cosine similarity.
    # Done on CPU so results interoperate with plain-Python indexing.
    similarities = torch.mv(code_matrix.cpu(), query_vec.cpu())

    # torch.topk returns scores and indices already ranked highest-first,
    # avoiding a full argsort when only the best few hits are shown.
    k = min(3, similarities.numel())  # show top 3
    top_scores, top_indices = torch.topk(similarities, k)

    print(f'\nQuery: "{query}"')
    print("  Rank  Score  Function")
    print("  ----  -----  --------")
    for rank, (score_tensor, idx_tensor) in enumerate(zip(top_scores, top_indices)):
        idx = int(idx_tensor)  # plain int: unambiguous as a list index
        score = score_tensor.item()
        name = code_database[idx]["name"]
        marker = " ← best match" if rank == 0 else ""
        print(f"  {rank+1:4d}  {score:.3f}  {name}{marker}")
print("\n" + "=" * 70)
print("KEY OBSERVATIONS:")
print("=" * 70)
# Closing lesson summary, printed verbatim for the student.
print("""
1. The model maps NATURAL LANGUAGE queries and CODE into the same vector
   space. This is why "sort a list" finds merge_sort and "find an element
   in a sorted array" finds binary_search — even though the queries
   contain none of the function identifiers.

2. This is fundamentally different from grep/keyword search:
   - grep "sort" would miss functions named "order" or "arrange"
   - grep "find element" would miss "binary_search"
   Embeddings understand MEANING, not just string matching.

3. This is exactly how Cursor, Copilot, and other AI coding tools
   retrieve relevant code from your project to feed into the LLM.
""")