217 lines
11 KiB
Python
217 lines
11 KiB
Python
"""
|
||
============================================================================
|
||
Example 5: Visualizing Code Embeddings with PCA and t-SNE
|
||
============================================================================
|
||
AISE501 – AI in Software Engineering I
|
||
Fachhochschule Graubünden
|
||
|
||
GOAL:
|
||
Reduce 768-dimensional code embeddings to 2D and plot them.
|
||
This makes the embedding space visible: you can SEE that similar
|
||
code clusters together and different code is far apart.
|
||
|
||
WHAT YOU WILL LEARN:
|
||
- How PCA projects high-dimensional vectors to 2D (linear reduction)
|
||
- How t-SNE creates a non-linear 2D map that preserves neighborhoods
|
||
- How to interpret embedding space visualizations
|
||
- That code functionality determines position, not syntax or language
|
||
|
||
OUTPUT:
|
||
Saves two PNG plots: code_embeddings_pca.png and code_embeddings_tsne.png
|
||
|
||
HARDWARE:
|
||
Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
|
||
============================================================================
|
||
"""
|
||
|
||
import torch
|
||
import numpy as np
|
||
from transformers import AutoTokenizer, AutoModel
|
||
import torch.nn.functional as F
|
||
from sklearn.decomposition import PCA
|
||
from sklearn.manifold import TSNE
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
|
||
# Headless-safe plotting: pick the non-interactive Agg backend before any
# figure is created, so the script also runs without a display attached.
matplotlib.use("Agg")


# ── Device selection ──────────────────────────────────────────────────────
def get_device():
    """Return the best available torch device: CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = get_device()
print(f"Using device: {DEVICE}\n")
|
||
|
||
# ── Load model ────────────────────────────────────────────────────────────
# A sentence-transformers checkpoint fine-tuned for code search; its encoder
# produces the embedding vectors visualized below.
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"

print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference mode: disables dropout and similar training-only layers
print("Model loaded.\n")
|
||
|
||
# ── Code snippets organized by CATEGORY ───────────────────────────────────
# Five task families, four snippets each. If embeddings really capture
# purpose, snippets sharing a category should land near one another in 2D —
# including the JavaScript sorter among the Python ones.
categories = {
    "Sorting": {
        "bubble_sort_py": "def bubble_sort(arr):\n n = len(arr)\n for i in range(n):\n for j in range(n-i-1):\n if arr[j] > arr[j+1]:\n arr[j], arr[j+1] = arr[j+1], arr[j]\n return arr",
        "quick_sort_py": "def quick_sort(a):\n if len(a) <= 1: return a\n p = a[0]\n return quick_sort([x for x in a[1:] if x < p]) + [p] + quick_sort([x for x in a[1:] if x >= p])",
        "sort_js": "function sortArray(arr) { return arr.sort((a, b) => a - b); }",
        "insertion_sort": "def insertion_sort(arr):\n for i in range(1, len(arr)):\n key = arr[i]\n j = i - 1\n while j >= 0 and arr[j] > key:\n arr[j+1] = arr[j]\n j -= 1\n arr[j+1] = key\n return arr",
    },
    "File I/O": {
        "read_json": "import json\ndef read_json(path):\n with open(path) as f:\n return json.load(f)",
        "write_file": "def write_file(path, content):\n with open(path, 'w') as f:\n f.write(content)",
        "read_csv": "import csv\ndef read_csv(path):\n with open(path) as f:\n return list(csv.reader(f))",
        "read_yaml": "import yaml\ndef read_yaml(path):\n with open(path) as f:\n return yaml.safe_load(f)",
    },
    "String ops": {
        "reverse_str": "def reverse(s): return s[::-1]",
        "capitalize": "def capitalize_words(s): return ' '.join(w.capitalize() for w in s.split())",
        "count_chars": "def count_chars(s):\n return {c: s.count(c) for c in set(s)}",
        "is_palindrome": "def is_palindrome(s): return s == s[::-1]",
    },
    "Math": {
        "factorial": "def factorial(n):\n r = 1\n for i in range(2, n+1): r *= i\n return r",
        "fibonacci": "def fib(n):\n a, b = 0, 1\n for _ in range(n): a, b = b, a+b\n return a",
        "gcd": "def gcd(a, b):\n while b: a, b = b, a % b\n return a",
        "is_prime": "def is_prime(n):\n if n < 2: return False\n for i in range(2, int(n**0.5)+1):\n if n % i == 0: return False\n return True",
    },
    "Networking": {
        "http_get": "import requests\ndef http_get(url): return requests.get(url).json()",
        "fetch_url": "import urllib.request\ndef fetch(url):\n with urllib.request.urlopen(url) as r:\n return r.read().decode()",
        "post_data": "import requests\ndef post_json(url, data): return requests.post(url, json=data).status_code",
        "download_file": "import urllib.request\ndef download(url, path): urllib.request.urlretrieve(url, path)",
    },
}
|
||
|
||
|
||
def embed_code(code: str) -> np.ndarray:
    """Embed a code snippet into a single L2-normalized vector.

    Tokenizes `code`, runs the encoder, mean-pools the token hidden states
    using the attention mask (so padding positions do not dilute the mean),
    and L2-normalizes the pooled vector.

    Args:
        code: Source code snippet to embed.

    Returns:
        A 1-D numpy array on the CPU (the model's hidden size; 768 per the
        module docstring). Note: the result is numpy, not a torch.Tensor —
        the previous `-> torch.Tensor` annotation was incorrect.
    """
    inputs = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Masked mean pooling: zero out padded positions, then divide by the
    # count of real tokens in each sequence.
    mask = inputs["attention_mask"].unsqueeze(-1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    # Unit-length vectors make cosine similarity a plain dot product.
    return F.normalize(embedding, p=2, dim=1).squeeze(0).cpu().numpy()
|
||
|
||
|
||
# ── Compute embeddings ────────────────────────────────────────────────────
print("Computing embeddings...")
all_embeddings = []
all_labels = []
all_categories = []

# Walk category → snippet, embedding each snippet and recording its label
# and category so the plots below can color and annotate every point.
for category, snippets in categories.items():
    for label, code in snippets.items():
        all_embeddings.append(embed_code(code))
        all_labels.append(label)
        all_categories.append(category)
        print(f"  [{category:12s}] {label}")

# Stack the per-snippet vectors into a [num_snippets, hidden_size] matrix;
# row order matches all_labels / all_categories.
X = np.vstack(all_embeddings)
print(f"\nEmbedding matrix: {X.shape[0]} snippets × {X.shape[1]} dimensions\n")
|
||
|
||
# ── Color map for categories ──────────────────────────────────────────────
# One distinct Set1 color per category, then expanded to one color per
# plotted point (aligned with all_categories / the embedding rows).
category_names = list(categories.keys())
palette = plt.cm.Set1(np.linspace(0, 1, len(category_names)))
color_map = dict(zip(category_names, palette))
point_colors = [color_map[cat] for cat in all_categories]
|
||
|
||
# ── Plot 1: PCA ──────────────────────────────────────────────────────────
# PCA finds the two directions of maximum variance in the 768-dim embedding
# space (not 1024-dim as a previous comment claimed — see the module
# docstring and the printed X.shape) and projects all points onto them.
print("Computing PCA (2 components)...")
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

fig, ax = plt.subplots(figsize=(10, 8))
for i, (x, y) in enumerate(X_pca):
    ax.scatter(x, y, c=[point_colors[i]], s=100, edgecolors="black", linewidth=0.5, zorder=3)
    ax.annotate(all_labels[i], (x, y), fontsize=7, ha="center", va="bottom",
                xytext=(0, 6), textcoords="offset points")

# Legend: one empty scatter per category, just to create legend handles.
for cat in category_names:
    ax.scatter([], [], c=[color_map[cat]], s=80, label=cat, edgecolors="black", linewidth=0.5)
ax.legend(loc="best", fontsize=9, title="Category", title_fontsize=10)

# Report how much of the original variance survives the 2D projection.
variance_explained = pca.explained_variance_ratio_
ax.set_title(f"Code Embeddings — PCA Projection\n"
             f"(PC1: {variance_explained[0]:.1%} variance, PC2: {variance_explained[1]:.1%} variance)",
             fontsize=13)
ax.set_xlabel("Principal Component 1", fontsize=11)
ax.set_ylabel("Principal Component 2", fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig("code_embeddings_pca.png", dpi=150)
print("  Saved: code_embeddings_pca.png")
print(f"  Variance explained: PC1={variance_explained[0]:.1%}, PC2={variance_explained[1]:.1%}\n")
|
||
|
||
# ── Plot 2: t-SNE ────────────────────────────────────────────────────────
# t-SNE is a non-linear method that preserves LOCAL neighborhood structure:
# points close in the original 768-dim embedding space (a previous comment
# wrongly said 1024-dim) stay close in 2D. Perplexity balances local vs.
# global structure and must be below the sample count (20 snippets here).
# NOTE: `max_iter` is the scikit-learn >= 1.5 parameter name (was `n_iter`).
print("Computing t-SNE (this may take a few seconds)...")
tsne = TSNE(n_components=2, perplexity=5, random_state=42, max_iter=1000)
X_tsne = tsne.fit_transform(X)

fig, ax = plt.subplots(figsize=(10, 8))
for i, (x, y) in enumerate(X_tsne):
    ax.scatter(x, y, c=[point_colors[i]], s=100, edgecolors="black", linewidth=0.5, zorder=3)
    ax.annotate(all_labels[i], (x, y), fontsize=7, ha="center", va="bottom",
                xytext=(0, 6), textcoords="offset points")

# Legend handles: one empty scatter per category.
for cat in category_names:
    ax.scatter([], [], c=[color_map[cat]], s=80, label=cat, edgecolors="black", linewidth=0.5)
ax.legend(loc="best", fontsize=9, title="Category", title_fontsize=10)

ax.set_title("Code Embeddings — t-SNE Projection\n"
             "(non-linear dimensionality reduction)", fontsize=13)
ax.set_xlabel("t-SNE Dimension 1", fontsize=11)
ax.set_ylabel("t-SNE Dimension 2", fontsize=11)
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig("code_embeddings_tsne.png", dpi=150)
print("  Saved: code_embeddings_tsne.png\n")
|
||
|
||
print("=" * 70)
|
||
print("INTERPRETATION")
|
||
print("=" * 70)
|
||
print(f"""
|
||
Both plots project {X.shape[1]}-dimensional embedding vectors to 2D:
|
||
|
||
PCA (Principal Component Analysis):
|
||
- Linear projection onto the two axes of maximum variance.
|
||
- Preserves global structure: large distances are meaningful.
|
||
- Good for seeing overall separation between categories.
|
||
- The % variance tells you how much information is retained.
|
||
|
||
t-SNE (t-distributed Stochastic Neighbor Embedding):
|
||
- Non-linear: distorts distances but preserves neighborhoods.
|
||
- Points that are close in the original space stay close in 2D.
|
||
- Better at revealing tight clusters within categories.
|
||
- Distances BETWEEN clusters are not meaningful.
|
||
|
||
EXPECTED RESULT:
|
||
You should see 5 distinct clusters, one per category:
|
||
- Sorting functions (bubble, quick, insertion, JS sort) cluster together
|
||
- File I/O functions cluster together
|
||
- String operations cluster together
|
||
- Math functions cluster together
|
||
- Networking functions cluster together
|
||
|
||
This visually confirms that code embeddings organize code by
|
||
PURPOSE, not by surface syntax or programming language.
|
||
""")
|