""" ============================================================================ Example 2: Text-to-Code Semantic Search ============================================================================ AISE501 – AI in Software Engineering I Fachhochschule Graubünden GOAL: Build a mini code search engine: given a natural language query like "sort a list", find the most relevant code snippet from a collection. This is the core mechanism behind semantic code search in tools like Cursor, GitHub Copilot, and code search engines. WHAT YOU WILL LEARN: - How the SAME embedding model maps both text and code into a shared vector space — this is what makes text-to-code search possible. - How to build a simple search index and query it. - Why embedding-based search beats keyword search for code. HARDWARE: Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac). ============================================================================ """ import torch from transformers import AutoTokenizer, AutoModel import torch.nn.functional as F # ── Device selection ────────────────────────────────────────────────────── def get_device(): if torch.cuda.is_available(): return torch.device("cuda") elif torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") DEVICE = get_device() print(f"Using device: {DEVICE}\n") # ── Load model ──────────────────────────────────────────────────────────── MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base" print(f"Loading model: {MODEL_NAME} ...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) model.eval() print("Model loaded.\n") # ── Code "database" ────────────────────────────────────────────────────── # Imagine these are functions in a large codebase that we want to search. 
code_database = [
    {
        "name": "binary_search",
        "code": """
def binary_search(arr, target):
    low, high = 0, len(arr) - 1
    while low <= high:
        mid = (low + high) // 2
        if arr[mid] == target:
            return mid
        elif arr[mid] < target:
            low = mid + 1
        else:
            high = mid - 1
    return -1
""",
    },
    {
        "name": "merge_sort",
        "code": """
def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)
""",
    },
    {
        "name": "read_json_file",
        "code": """
import json

def read_json_file(path):
    with open(path, 'r') as f:
        return json.load(f)
""",
    },
    {
        "name": "calculate_average",
        "code": """
def calculate_average(numbers):
    if not numbers:
        return 0.0
    return sum(numbers) / len(numbers)
""",
    },
    {
        "name": "connect_database",
        "code": """
import sqlite3

def connect_database(db_path):
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    return conn, cursor
""",
    },
    {
        "name": "send_http_request",
        "code": """
import requests

def send_http_request(url, method='GET', data=None):
    if method == 'GET':
        response = requests.get(url)
    else:
        response = requests.post(url, json=data)
    return response.json()
""",
    },
    {
        "name": "flatten_nested_list",
        "code": """
def flatten(nested_list):
    result = []
    for item in nested_list:
        if isinstance(item, list):
            result.extend(flatten(item))
        else:
            result.append(item)
    return result
""",
    },
    {
        "name": "count_words",
        "code": """
def count_words(text):
    words = text.lower().split()
    word_count = {}
    for word in words:
        word_count[word] = word_count.get(word, 0) + 1
    return word_count
""",
    },
    {
        "name": "validate_email",
        "code": """
import re

def validate_email(email):
    pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\.[a-zA-Z]{2,}$'
    return bool(re.match(pattern, email))
""",
    },
    {
        "name": "fibonacci",
        "code": """
def fibonacci(n):
    if n <= 1:
        return n
    a, b = 0, 1
    for _ in range(2, n + 1):
        a, b = b, a + b
    return b
""",
    },
]


def embed_text(text: str) -> torch.Tensor:
    """Embed a piece of text or code into a normalized vector.

    Mean-pools the model's last hidden states over non-padding tokens,
    then L2-normalizes, so dot products between embeddings are cosine
    similarities.

    Args:
        text: Natural language or source code. Inputs longer than 512
            tokens are truncated.

    Returns:
        A 1-D tensor of shape [embedding_dim] on DEVICE with unit norm.
    """
    inputs = tokenizer(
        text, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mean pooling over non-padding tokens only.
    mask = inputs["attention_mask"].unsqueeze(-1)
    # clamp(min=1) guards the division against a (theoretical) all-padding
    # mask; with any real tokenized input the sum is already >= 1.
    token_counts = mask.sum(dim=1).clamp(min=1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / token_counts
    return F.normalize(embedding, p=2, dim=1).squeeze(0)


# ── Step 1: Index the code database ───────────────────────────────────────
# In a real system this would be stored in a vector database (ChromaDB,
# Pinecone, pgvector). Here we keep it simple with a list of tensors.
print("Indexing code database...")
code_vectors = []
for entry in code_database:
    vec = embed_text(entry["code"])
    code_vectors.append(vec)
    print(f"  Indexed: {entry['name']}")

# Stack into a matrix: shape [num_snippets, embedding_dim]
code_matrix = torch.stack(code_vectors)
# The index is loop-invariant across queries, so move it to CPU ONCE here
# instead of copying device -> CPU on every query (the original called
# .cpu() inside the search loop).
code_matrix_cpu = code_matrix.cpu()
print(f"\nIndex built: {code_matrix.shape[0]} snippets, {code_matrix.shape[1]} dimensions\n")

# ── Step 2: Search with natural language queries ──────────────────────────
queries = [
    "sort a list of numbers",
    "find an element in a sorted array",
    "compute the mean of a list",
    "make an HTTP API call",
    "open and read a JSON file",
    "check if an email address is valid",
    "count word frequencies in a string",
    "generate fibonacci numbers",
    "connect to a SQL database",
    "flatten a nested list into a single list",
]

print("=" * 70)
print("SEMANTIC CODE SEARCH RESULTS")
print("=" * 70)

for query in queries:
    # Embed the natural language query with the SAME model, then move the
    # single query vector to CPU once for the similarity computation.
    query_vec = embed_text(query).cpu()

    # Compute cosine similarity against all code embeddings.
    # Because vectors are normalized, dot product = cosine similarity.
    similarities = torch.mv(code_matrix_cpu, query_vec)

    # Rank results by similarity (highest first)
    ranked_indices = torch.argsort(similarities, descending=True)

    print(f'\nQuery: "{query}"')
    print("  Rank  Score  Function")
    print("  ----  -----  --------")
    for rank, idx in enumerate(ranked_indices[:3]):  # show top 3
        score = similarities[idx].item()
        name = code_database[idx]["name"]
        marker = "  ← best match" if rank == 0 else ""
        print(f"  {rank+1:4d}  {score:.3f}  {name}{marker}")

print("\n" + "=" * 70)
print("KEY OBSERVATIONS:")
print("=" * 70)
print("""
1. The model maps NATURAL LANGUAGE queries and CODE into the same vector
   space. This is why "sort a list" finds merge_sort and "find an element
   in a sorted array" finds binary_search — even though the queries contain
   none of the function identifiers.

2. This is fundamentally different from grep/keyword search:
   - grep "sort" would miss functions named "order" or "arrange"
   - grep "find element" would miss "binary_search"
   Embeddings understand MEANING, not just string matching.

3. This is exactly how Cursor, Copilot, and other AI coding tools retrieve
   relevant code from your project to feed into the LLM.
""")