# AISE1_CLASS / Code embeddings / 01_basic_embeddings.py
"""
============================================================================
Example 1: Computing Code Embeddings and Measuring Similarity
============================================================================
AISE501 AI in Software Engineering I
Fachhochschule Graubünden
GOAL:
Load a pre-trained code embedding model, embed several code snippets,
and compute pairwise cosine similarities to see which snippets the
model considers semantically similar.
WHAT YOU WILL LEARN:
- How to load a code embedding model with PyTorch
- How code is tokenized and converted to vectors
- How cosine similarity reveals semantic relationships
- That similar functionality → high similarity, different purpose → low
HARDWARE:
Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
# ── Device selection ──────────────────────────────────────────────────────
# PyTorch supports three backends:
# - "cuda" → NVIDIA GPUs (Linux/Windows)
# - "mps" → Apple Silicon GPUs (macOS M1/M2/M3/M4)
# - "cpu" → always available, slower
def get_device():
    """Return the best available torch.device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # CPU is the universal fallback — always present, just slower.
    return torch.device("cpu")
# Resolve the device once at import time; everything below moves tensors here.
DEVICE = get_device()
print(f"Using device: {DEVICE}\n")
# ── Load model and tokenizer ─────────────────────────────────────────────
# We use st-codesearch-distilroberta-base — a DistilRoBERTa model (82M params)
# specifically fine-tuned on 1.38M code-comment pairs from CodeSearchNet using
# contrastive learning. It produces 768-dim embeddings optimized for matching
# natural language descriptions to code, making it ideal for code search and
# similarity tasks.
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
print(f"Loading model: {MODEL_NAME} ...")
# NOTE: from_pretrained downloads weights on first run (network access),
# then serves them from the local Hugging Face cache on later runs.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Move the weights to the selected device once, up front.
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval() # disable dropout — we want deterministic embeddings
print("Model loaded.\n")
# ── Define code snippets to compare ──────────────────────────────────────
# We intentionally include:
# - Two sorting functions (similar purpose, different implementation)
# - A function that does something completely different (JSON parsing)
# - A sorting function in a different style (list comprehension)
#
# Each value is a self-contained, syntactically valid Python snippet
# (the indentation matters — it is part of what the model embeds).
snippets = {
    "bubble_sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr
""",
    "quick_sort": """
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)
""",
    "sorted_builtin": """
def sort_list(data):
    return sorted(data)
""",
    "parse_json": """
import json
def parse_config(filepath):
    with open(filepath, 'r') as f:
        config = json.load(f)
    return config
""",
    "read_csv": """
import csv
def read_csv_file(filepath):
    rows = []
    with open(filepath, 'r') as f:
        reader = csv.reader(f)
        for row in reader:
            rows.append(row)
    return rows
""",
}
def embed_code(code_text: str) -> torch.Tensor:
    """
    Embed one code snippet as a single 768-dim, L2-normalized vector.

    Full pipeline from the lecture:
        raw code → tokens → contextual token embeddings → mean pooling → unit vector

    A transformer emits one vector *per token*; to compare whole snippets
    with cosine similarity we pool those vectors into one, then normalize
    it so cosine similarity reduces to a plain dot product.

    Returns:
        torch.Tensor of shape [768]: unit-length embedding of the snippet.
    """
    # Step 1 — Tokenization. The model reads integer token IDs, not text.
    # The tokenizer also emits an attention mask (1 = real token, 0 = padding).
    # We truncate at 512 tokens because that is the model's training context
    # length; anything beyond would be out-of-distribution.
    encoded = tokenizer(
        code_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding=True
    ).to(DEVICE)

    # Step 2 — Forward pass. torch.no_grad() skips gradient bookkeeping
    # (inference only — saves memory and time). The output's
    # last_hidden_state has shape [1, seq_len, 768]: one CONTEXTUAL vector
    # per token, already transformed by self-attention over the whole input.
    with torch.no_grad():
        model_out = model(**encoded)

    # Step 3 — Mean pooling: many token vectors → one snippet vector.
    # Multiply by the mask so padding positions contribute zero, then divide
    # by the count of REAL tokens (not the padded length) so padding cannot
    # dilute the average.
    mask = encoded["attention_mask"].unsqueeze(-1)  # [1, seq_len, 1]
    summed = (model_out.last_hidden_state * mask).sum(dim=1)
    pooled = summed / mask.sum(dim=1)

    # Step 4 — L2 normalization: project onto the unit hypersphere.
    # With unit vectors, cos(θ) = a · b, which is exactly how production
    # embedding systems (and vector databases) compare vectors.
    pooled = F.normalize(pooled, p=2, dim=1)
    return pooled.squeeze(0)  # drop the batch dim → shape [768]
# ── Compute embeddings for all snippets ───────────────────────────────────
print("Computing embeddings...")
# Embed every snippet up front, then report token counts per snippet.
embeddings = {name: embed_code(code) for name, code in snippets.items()}
for name, code in snippets.items():
    num_tokens = len(tokenizer.encode(code))
    print(f" {name:20s}{num_tokens:3d} tokens → vector of dim {embeddings[name].shape[0]}")
print()
# ── Compute pairwise cosine similarities ──────────────────────────────────
# cosine_similarity = dot product of unit vectors (we already normalized above)
names = list(embeddings.keys())
print("Pairwise Cosine Similarities:")
# Header row: blank corner cell, then one right-aligned column per snippet.
print(f"{'':22s}", end="")
for n in names:
    print(f"{n:>16s}", end="")
print()
# One row per snippet; the indices from enumerate were never used, so we
# iterate the names directly.
for n1 in names:
    print(f"{n1:22s}", end="")
    for n2 in names:
        # Move to CPU for the dot product so this works regardless of DEVICE.
        sim = torch.dot(embeddings[n1].cpu(), embeddings[n2].cpu()).item()
        print(f"{sim:16.3f}", end="")
    print()
# ── Interpretation ────────────────────────────────────────────────────────
# Print a framed summary explaining what the similarity matrix should show.
rule = "=" * 70
print(f"\n{rule}")
print("INTERPRETATION:")
print(rule)
print("""
- bubble_sort, quick_sort, and sorted_builtin should have HIGH similarity
(all perform sorting, despite very different implementations).
- parse_json and read_csv should be similar to each other (both read files)
but DISSIMILAR to the sorting functions (different purpose).
- This demonstrates that code embeddings capture WHAT code does,
not just HOW it looks syntactically.
""")