""" ============================================================================ Example 6: PCA Denoising — Can Fewer Dimensions Improve Similarity? ============================================================================ AISE501 – AI in Software Engineering I Fachhochschule Graubünden HYPOTHESIS: Embedding vectors live in a 768-dimensional space, but most of the semantic signal may be concentrated in a small number of principal components. The remaining dimensions could add "noise" that dilutes cosine similarity. If true, projecting embeddings onto a small PCA subspace should INCREASE similarity within semantic groups and DECREASE similarity across groups — making code search sharper. WHAT YOU WILL LEARN: - How PCA decomposes the embedding space into ranked components - How to measure retrieval quality (intra- vs inter-group similarity) - Whether dimensionality reduction helps or hurts in practice - The concept of an "optimal" embedding dimension for a given task OUTPUT: Saves pca_denoising_analysis.png with three sub-plots. HARDWARE: Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac). 
============================================================================ """ import torch import numpy as np from transformers import AutoTokenizer, AutoModel import torch.nn.functional as F from sklearn.decomposition import PCA import matplotlib.pyplot as plt import matplotlib matplotlib.use("Agg") # ── Device selection ────────────────────────────────────────────────────── def get_device(): if torch.cuda.is_available(): return torch.device("cuda") elif torch.backends.mps.is_available(): return torch.device("mps") return torch.device("cpu") DEVICE = get_device() print(f"Using device: {DEVICE}\n") # ── Load model ──────────────────────────────────────────────────────────── MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base" print(f"Loading model: {MODEL_NAME} ...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) model.eval() print("Model loaded.\n") # ── Code snippets organized into semantic GROUPS ────────────────────────── # We need clear groups so we can measure intra-group vs inter-group similarity. 
# Six semantic categories, six snippets each. Snippet bodies are runtime
# string data fed to the embedding model — do not edit their contents.
groups = {
    "Sorting": {
        "bubble_sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr""",
        "quick_sort": """
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)""",
        "merge_sort": """
def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i]); i += 1
        else:
            merged.append(right[j]); j += 1
    return merged + left[i:] + right[j:]""",
        "insertion_sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i - 1
        while j >= 0 and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr""",
        "selection_sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i + 1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr""",
        "heap_sort": """
def heap_sort(arr):
    import heapq
    heapq.heapify(arr)
    return [heapq.heappop(arr) for _ in range(len(arr))]""",
    },
    "File I/O": {
        "read_json": """
import json
def read_json(path):
    with open(path, 'r') as f:
        return json.load(f)""",
        "write_file": """
def write_file(path, content):
    with open(path, 'w') as f:
        f.write(content)""",
        "read_csv": """
import csv
def read_csv(path):
    with open(path, 'r') as f:
        reader = csv.reader(f)
        return list(reader)""",
        "read_yaml": """
import yaml
def load_yaml(path):
    with open(path, 'r') as f:
        return yaml.safe_load(f)""",
        "write_json": """
import json
def write_json(path, data):
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)""",
        "read_lines": """
def read_lines(path):
    with open(path, 'r') as f:
        return f.readlines()""",
    },
    "Math": {
        "factorial": """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)""",
        "fibonacci": """
def fibonacci(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a""",
        "gcd": """
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a""",
        "is_prime": """
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True""",
        "power": """
def power(base, exp):
    if exp == 0:
        return 1
    if exp % 2 == 0:
        half = power(base, exp // 2)
        return half * half
    return base * power(base, exp - 1)""",
        "sum_digits": """
def sum_digits(n):
    total = 0
    while n > 0:
        total += n % 10
        n //= 10
    return total""",
    },
    "Networking": {
        "http_get": """
import requests
def http_get(url):
    response = requests.get(url)
    return response.json()""",
        "post_data": """
import requests
def post_data(url, payload):
    response = requests.post(url, json=payload)
    return response.status_code, response.json()""",
        "fetch_url": """
import urllib.request
def fetch_url(url):
    with urllib.request.urlopen(url) as resp:
        return resp.read().decode('utf-8')""",
        "download_file": """
import urllib.request
def download_file(url, dest):
    urllib.request.urlretrieve(url, dest)
    return dest""",
        "http_put": """
import requests
def http_put(url, data):
    response = requests.put(url, json=data)
    return response.status_code""",
        "http_delete": """
import requests
def http_delete(url):
    response = requests.delete(url)
    return response.status_code""",
    },
    "String ops": {
        "reverse_str": """
def reverse_string(s):
    return s[::-1]""",
        "is_palindrome": """
def is_palindrome(s):
    s = s.lower().replace(' ', '')
    return s == s[::-1]""",
        "count_vowels": """
def count_vowels(s):
    return sum(1 for c in s.lower() if c in 'aeiou')""",
        "capitalize_words": """
def capitalize_words(s):
    return ' '.join(w.capitalize() for w in s.split())""",
        "remove_duplicates": """
def remove_duplicate_chars(s):
    seen = set()
    result = []
    for c in s:
        if c not in seen:
            seen.add(c)
            result.append(c)
    return ''.join(result)""",
        "count_words": """
def count_words(text):
    words = text.lower().split()
    freq = {}
    for w in words:
        freq[w] = freq.get(w, 0) + 1
    return freq""",
    },
    "Data structures": {
        "stack_push_pop": """
class Stack:
    def __init__(self):
        self.items = []
    def push(self, item):
        self.items.append(item)
    def pop(self):
        return self.items.pop()""",
        "queue_impl": """
from collections import deque
class Queue:
    def __init__(self):
        self.items = deque()
    def enqueue(self, item):
        self.items.append(item)
    def dequeue(self):
        return self.items.popleft()""",
        "linked_list": """
class Node:
    def __init__(self, val):
        self.val = val
        self.next = None
class LinkedList:
    def __init__(self):
        self.head = None
    def append(self, val):
        node = Node(val)
        if not self.head:
            self.head = node
            return
        curr = self.head
        while curr.next:
            curr = curr.next
        curr.next = node""",
        "binary_tree": """
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
def inorder(root):
    if root:
        yield from inorder(root.left)
        yield root.val
        yield from inorder(root.right)""",
        "hash_map": """
class HashMap:
    def __init__(self, size=256):
        self.buckets = [[] for _ in range(size)]
    def put(self, key, value):
        idx = hash(key) % len(self.buckets)
        for i, (k, v) in enumerate(self.buckets[idx]):
            if k == key:
                self.buckets[idx][i] = (key, value)
                return
        self.buckets[idx].append((key, value))""",
        "priority_queue": """
import heapq
class PriorityQueue:
    def __init__(self):
        self.heap = []
    def push(self, priority, item):
        heapq.heappush(self.heap, (priority, item))
    def pop(self):
        return heapq.heappop(self.heap)[1]""",
    },
}


def embed_code(code: str) -> torch.Tensor:
    """Embed code into a normalized vector.

    Tokenizes the text, runs the transformer, mean-pools the token
    embeddings weighted by the attention mask, and L2-normalizes the
    result so downstream dot products equal cosine similarity.
    Returns a 1-D tensor of size 768 on DEVICE.
    """
    inputs = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Mask-weighted mean pooling over the sequence dimension: padding tokens
    # contribute zero to both numerator and denominator.
    mask = inputs["attention_mask"].unsqueeze(-1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(embedding, p=2, dim=1).squeeze(0)


# ── Step 1: Compute all embeddings ────────────────────────────────────────
print("Computing embeddings...")
all_names = []    # snippet identifiers, index-aligned with all_vectors
all_labels = []   # group name per snippet, same order
all_vectors = []  # one 768-d numpy vector per snippet
for group_name, snippets in groups.items():
    for snippet_name, code in snippets.items():
        vec = embed_code(code).cpu().numpy()
        all_names.append(snippet_name)
        all_labels.append(group_name)
        all_vectors.append(vec)
        print(f"  [{group_name:12s}] {snippet_name}")

X = np.stack(all_vectors)  # shape: [N, 768]
N, D = X.shape
print(f"\nEmbedding matrix: {N} snippets × {D} dimensions\n")


# ── Step 2: Define similarity metrics ─────────────────────────────────────
def cosine_matrix(vectors):
    """Compute pairwise cosine similarity for L2-normalized vectors.

    Re-normalizes defensively (PCA-projected vectors are no longer unit
    length), guarding against division by zero with a small floor.
    Returns an [n, n] symmetric matrix.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms = np.maximum(norms, 1e-10)  # avoid divide-by-zero for null vectors
    normed = vectors / norms
    return normed @ normed.T


def compute_metrics(sim_matrix, labels):
    """
    Compute intra-group (same category) and inter-group (different
    category) average similarities. The GAP between them measures
    discriminability.

    Iterates over the strict upper triangle only, so each unordered
    pair is counted once and self-similarity is excluded.
    Returns (intra_mean, inter_mean, gap).
    """
    intra_sims = []
    inter_sims = []
    n = len(labels)
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] == labels[j]:
                intra_sims.append(sim_matrix[i, j])
            else:
                inter_sims.append(sim_matrix[i, j])
    intra_mean = np.mean(intra_sims)
    inter_mean = np.mean(inter_sims)
    gap = intra_mean - inter_mean
    return intra_mean, inter_mean, gap


# ── Step 3: Sweep across PCA dimensions ──────────────────────────────────
# PCA can have at most min(N, D) components; cap accordingly
max_components = min(N, D)
dims_to_test = sorted(set(
    k for k in [2, 3, 5, 8, 10, 15, 20, 30, 50, 75, 100, 150, 200,
                300, 400, 500, 600, D]
    if k <= max_components
))
# Appended last so that results[-1] below is guaranteed to be the baseline.
dims_to_test.append(D)  # always include full dimensionality as baseline

print("=" * 70)
print("PCA DENOISING EXPERIMENT")
print("=" * 70)
print(f"\n{'Components':>12s} {'Intra-Group':>12s} {'Inter-Group':>12s} "
      f"{'Gap':>8s} {'vs Full':>8s}")
print("-" * 62)

results = []  # list of (k, intra_mean, inter_mean, gap) tuples
for k in dims_to_test:
    if k >= D:
        # Full dimensionality — no PCA, just use original vectors
        X_reduced = X.copy()
        actual_k = D
    else:
        pca = PCA(n_components=k, random_state=42)
        X_reduced = pca.fit_transform(X)
        actual_k = k
    sim = cosine_matrix(X_reduced)
    intra, inter, gap = compute_metrics(sim, all_labels)
    results.append((actual_k, intra, inter, gap))

# Compute full-dim gap for comparison (baseline is always the last entry)
full_intra, full_inter, full_gap = results[-1][1], results[-1][2], results[-1][3]
for k, intra, inter, gap in results:
    delta = gap - full_gap
    delta_str = f"{delta:+.4f}" if k < D else " (base)"
    print(f"{k:>12d} {intra:>12.4f} {inter:>12.4f} {gap:>8.4f} {delta_str:>8s}")

# ── Step 4: Find the optimal dimensionality ──────────────────────────────
# "Optimal" = the k that maximizes the intra/inter gap.
dims_arr = np.array([r[0] for r in results])
gaps_arr = np.array([r[3] for r in results])
best_idx = np.argmax(gaps_arr)
best_k, best_gap = int(dims_arr[best_idx]), gaps_arr[best_idx]

print(f"\n{'=' * 70}")
print(f"BEST DIMENSIONALITY: {best_k} components")
print(f"  Gap (intra - inter): {best_gap:.4f} vs {full_gap:.4f} at full 768-d")
print(f"  Improvement: {best_gap - full_gap:+.4f}")
print(f"{'=' * 70}")

# ── Step 5: Show detailed comparison at optimal k vs full ────────────────
print(f"\n── Detailed Similarity Matrix at k={best_k} vs k={D} ──\n")

if best_k < D:
    pca_best = PCA(n_components=best_k, random_state=42)
    X_best = pca_best.fit_transform(X)
else:
    X_best = X.copy()

sim_full = cosine_matrix(X)
sim_best = cosine_matrix(X_best)

# Show a selection of interesting pairs
print(f"{'Snippet A':>20s} {'Snippet B':>20s} {'Full 768d':>10s} "
      f"{'PCA {0}d'.format(best_k):>10s} {'Change':>8s}")
print("-" * 78)

interesting_pairs = [
    # Intra-group: should be high
    ("bubble_sort", "quick_sort"),
    ("bubble_sort", "merge_sort"),
    ("read_json", "read_csv"),
    ("http_get", "fetch_url"),
    ("factorial", "fibonacci"),
    ("reverse_str", "is_palindrome"),
    ("stack_push_pop", "queue_impl"),
    # Inter-group: should be low
    ("bubble_sort", "read_json"),
    ("factorial", "http_get"),
    ("reverse_str", "download_file"),
    ("is_prime", "write_file"),
    ("stack_push_pop", "count_vowels"),
]

for n1, n2 in interesting_pairs:
    i = all_names.index(n1)
    j = all_names.index(n2)
    s_full = sim_full[i, j]
    s_best = sim_best[i, j]
    same = all_labels[i] == all_labels[j]
    marker = "SAME" if same else "DIFF"
    change = s_best - s_full
    print(f"{n1:>20s} {n2:>20s} {s_full:>10.4f} {s_best:>10.4f} "
          f"{change:>+8.4f} [{marker}]")

# ── Step 6: Text-to-code search comparison ────────────────────────────────
print(f"\n── Text-to-Code Search: Full 768d vs PCA {best_k}d ──\n")

search_queries = [
    ("sort a list of numbers", "Sorting"),
    ("read a JSON config file", "File I/O"),
    ("compute factorial recursively", "Math"),
    ("make an HTTP GET request", "Networking"),
    ("check if a number is prime", "Math"),
]

if best_k < D:
    pca_search = PCA(n_components=best_k, random_state=42)
    X_search = pca_search.fit_transform(X)
else:
    X_search = X.copy()
    pca_search = None

for query, expected_group in search_queries:
    # Natural-language queries go through the same encoder as the code —
    # the model was trained for exactly this text-to-code matching.
    q_vec = embed_code(query).cpu().numpy().reshape(1, -1)

    # Full dimension search
    q_norm = q_vec / np.linalg.norm(q_vec)
    X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)
    scores_full = (X_norm @ q_norm.T).flatten()

    # PCA-reduced search: project the query with the SAME fitted PCA
    if pca_search is not None:
        q_reduced = pca_search.transform(q_vec)
    else:
        q_reduced = q_vec.copy()
    q_r_norm = q_reduced / np.linalg.norm(q_reduced)
    X_s_norm = X_search / np.linalg.norm(X_search, axis=1, keepdims=True)
    scores_pca = (X_s_norm @ q_r_norm.T).flatten()

    # Top-3 matches under each representation (descending score)
    top_full = np.argsort(-scores_full)[:3]
    top_pca = np.argsort(-scores_pca)[:3]

    print(f'  Query: "{query}"')
    print(f'    Full 768d: {all_names[top_full[0]]:>16s} ({scores_full[top_full[0]]:.3f})'
          f'  {all_names[top_full[1]]:>16s} ({scores_full[top_full[1]]:.3f})'
          f'  {all_names[top_full[2]]:>16s} ({scores_full[top_full[2]]:.3f})')
    print(f'    PCA {best_k:>3d}d: {all_names[top_pca[0]]:>16s} ({scores_pca[top_pca[0]]:.3f})'
          f'  {all_names[top_pca[1]]:>16s} ({scores_pca[top_pca[1]]:.3f})'
          f'  {all_names[top_pca[2]]:>16s} ({scores_pca[top_pca[2]]:.3f})')

    full_correct = all_labels[top_full[0]] == expected_group
    pca_correct = all_labels[top_pca[0]] == expected_group
    print(f'    Full correct: {full_correct} | PCA correct: {pca_correct}')
    print()

# ── Step 7: Visualization ─────────────────────────────────────────────────
# Six-panel figure for a comprehensive visual analysis.
# Fixed color per semantic group (standard Matplotlib "tab10" hex values).
group_colors = {
    "Sorting": "#1f77b4",
    "File I/O": "#ff7f0e",
    "Math": "#2ca02c",
    "Networking": "#d62728",
    "String ops": "#9467bd",
    "Data structures": "#8c564b",
}
label_colors = [group_colors[g] for g in all_labels]
# dict.fromkeys preserves first-seen order while deduplicating group names.
unique_groups = list(dict.fromkeys(all_labels))

fig = plt.figure(figsize=(20, 13))
fig.suptitle("PCA Denoising Analysis — Can Fewer Dimensions Improve Code Similarity?",
             fontsize=15, fontweight="bold", y=0.98)

# ── Row 1 ──
# Plot 1: Intra/inter similarity vs number of PCA components
ax1 = fig.add_subplot(2, 3, 1)
dims_plot = [r[0] for r in results]
intra_plot = [r[1] for r in results]
inter_plot = [r[2] for r in results]
# Shaded band between the two curves visualizes the gap directly.
ax1.fill_between(dims_plot, inter_plot, intra_plot, alpha=0.15, color="tab:green")
ax1.plot(dims_plot, intra_plot, "o-", color="tab:blue", linewidth=2,
         label="Intra-group (same category)", markersize=6)
ax1.plot(dims_plot, inter_plot, "s-", color="tab:red", linewidth=2,
         label="Inter-group (different category)", markersize=6)
ax1.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best gap at k={best_k}")
ax1.set_xlabel("Number of PCA Components", fontsize=10)
ax1.set_ylabel("Average Cosine Similarity", fontsize=10)
ax1.set_title("(a) Intra- vs Inter-Group Similarity", fontsize=11, fontweight="bold")
ax1.legend(fontsize=7, loc="center right")
ax1.set_xscale("log")  # tested dims span 2..768, log spacing reads better
ax1.grid(True, alpha=0.3)

# Plot 2: Gap (discriminability) vs number of PCA components
ax2 = fig.add_subplot(2, 3, 2)
gaps_plot = [r[3] for r in results]
ax2.plot(dims_plot, gaps_plot, "D-", color="tab:green", linewidth=2, markersize=7)
ax2.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best k={best_k} (gap={best_gap:.3f})")
ax2.axhline(y=full_gap, color="gray", linestyle=":", alpha=0.7,
            label=f"Full 768d (gap={full_gap:.3f})")
# Shade only where PCA beats the full-dimensional baseline.
ax2.fill_between(dims_plot, full_gap, gaps_plot, alpha=0.12, color="tab:green",
                 where=[g > full_gap for g in gaps_plot])
ax2.set_xlabel("Number of PCA Components", fontsize=10)
ax2.set_ylabel("Gap (Intra − Inter)", fontsize=10)
ax2.set_title("(b) Discriminability vs Dimensionality", fontsize=11, fontweight="bold")
ax2.legend(fontsize=8)
ax2.set_xscale("log")
ax2.grid(True, alpha=0.3)

# Plot 3: Cumulative variance explained
pca_full = PCA(n_components=min(N, D), random_state=42)
pca_full.fit(X)
cumvar = np.cumsum(pca_full.explained_variance_ratio_) * 100  # percent
ax3 = fig.add_subplot(2, 3, 3)
ax3.plot(range(1, len(cumvar) + 1), cumvar, "-", color="tab:purple", linewidth=2)
ax3.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best k={best_k}")
# Annotate how many components reach the 90/95/99% variance thresholds.
for threshold in [90, 95, 99]:
    # searchsorted gives the 0-based index of the first cumvar >= threshold;
    # +1 converts to a component count.
    k_thresh = np.searchsorted(cumvar, threshold) + 1
    if k_thresh <= len(cumvar):
        ax3.axhline(y=threshold, color="gray", linestyle=":", alpha=0.4)
        ax3.annotate(f"{threshold}% → k={k_thresh}", xy=(k_thresh, threshold),
                     fontsize=8, color="gray", ha="left",
                     xytext=(k_thresh + 1, threshold - 2))
ax3.set_xlabel("Number of PCA Components", fontsize=10)
ax3.set_ylabel("Cumulative Variance Explained (%)", fontsize=10)
ax3.set_title("(c) Variance Concentration", fontsize=11, fontweight="bold")
ax3.legend(fontsize=8)
ax3.set_xscale("log")
ax3.grid(True, alpha=0.3)

# ── Row 2 ──
# Plot 4 & 5: Side-by-side heatmaps (full vs PCA-denoised)
# Sort indices by group for a block-diagonal structure
sorted_idx = sorted(range(N), key=lambda i: all_labels[i])
sorted_names = [all_names[i] for i in sorted_idx]
sorted_labels = [all_labels[i] for i in sorted_idx]
# np.ix_ reorders rows AND columns of the similarity matrices together.
sim_full_sorted = sim_full[np.ix_(sorted_idx, sorted_idx)]
sim_best_sorted = sim_best[np.ix_(sorted_idx, sorted_idx)]

for panel_idx, (mat, title_str) in enumerate([
    (sim_full_sorted, f"(d) Similarity Heatmap — Full 768d"),
    (sim_best_sorted, f"(e) Similarity Heatmap — PCA {best_k}d (Denoised)"),
]):
    ax = fig.add_subplot(2, 3, 4 + panel_idx)
    im = ax.imshow(mat, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
    ax.set_xticks(range(N))
    ax.set_yticks(range(N))
    ax.set_xticklabels(sorted_names, rotation=90, fontsize=5)
    ax.set_yticklabels(sorted_names, fontsize=5)
    # Draw group boundary lines wherever the (sorted) group label changes
    prev_label = sorted_labels[0]
    for i, lab in enumerate(sorted_labels):
        if lab != prev_label:
            ax.axhline(y=i - 0.5, color="black", linewidth=1)
            ax.axvline(x=i - 0.5, color="black", linewidth=1)
        prev_label = lab
    ax.set_title(title_str, fontsize=11, fontweight="bold")
    plt.colorbar(im, ax=ax, shrink=0.8, label="Cosine Similarity")

# Plot 6: Bar chart comparing specific pairs at full vs PCA
ax6 = fig.add_subplot(2, 3, 6)
pair_labels = []
full_scores = []
pca_scores = []
pair_colors = []  # green = same group, red = different group
for n1, n2 in interesting_pairs:
    i = all_names.index(n1)
    j = all_names.index(n2)
    pair_labels.append(f"{n1}\nvs {n2}")
    full_scores.append(sim_full[i, j])
    pca_scores.append(sim_best[i, j])
    pair_colors.append("#2ca02c" if all_labels[i] == all_labels[j] else "#d62728")

y_pos = np.arange(len(pair_labels))
bar_h = 0.35  # half-offset gives two bars per pair without overlap
bars_full = ax6.barh(y_pos + bar_h / 2, full_scores, bar_h, label="Full 768d",
                     color="tab:blue", alpha=0.7)
bars_pca = ax6.barh(y_pos - bar_h / 2, pca_scores, bar_h, label=f"PCA {best_k}d",
                    color="tab:orange", alpha=0.7)

# Color labels by same/different group (dot marker left of the axis)
for i, (yl, col) in enumerate(zip(pair_labels, pair_colors)):
    ax6.annotate("●", xy=(-0.05, y_pos[i]), fontsize=10, color=col,
                 ha="right", va="center", fontweight="bold",
                 annotation_clip=False)

ax6.set_yticks(y_pos)
ax6.set_yticklabels(pair_labels, fontsize=6)
ax6.set_xlabel("Cosine Similarity", fontsize=10)
ax6.set_title("(f) Pair Comparison: Full vs PCA Denoised", fontsize=11,
              fontweight="bold")
ax6.legend(fontsize=8)
ax6.axvline(x=0, color="black", linewidth=0.5)
ax6.set_xlim(-1.1, 1.1)
ax6.grid(True, axis="x", alpha=0.3)
ax6.invert_yaxis()  # first pair at the top, matching the print order

# Custom legend for the dots (replaces the earlier bar-only legend)
from matplotlib.lines import Line2D
dot_legend = [Line2D([0], [0], marker="o", color="w", markerfacecolor="#2ca02c",
                     markersize=8, label="Same group"),
              Line2D([0], [0], marker="o", color="w", markerfacecolor="#d62728",
                     markersize=8, label="Different group")]
ax6.legend(handles=[bars_full, bars_pca] + dot_legend, fontsize=7,
           loc="lower right")
# rect leaves headroom for the suptitle placed at y=0.98.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig("pca_denoising_analysis.png", dpi=150, bbox_inches="tight")
print(f"\nSaved: pca_denoising_analysis.png")

# ── Summary ───────────────────────────────────────────────────────────────
# Single multi-line f-string; the conditional sentence in item 2 adapts
# the conclusion to whether PCA actually beat the full-dimensional baseline.
print(f"""
{'=' * 70}
CONCLUSIONS
{'=' * 70}

1. VARIANCE CONCENTRATION: The first few PCA components capture a
   disproportionate amount of variance. This means the embedding space
   has low effective dimensionality — most of the 768 dimensions are
   semi-redundant.

2. DENOISING EFFECT: At k={best_k}, the gap between intra-group and
   inter-group similarity is {best_gap:.4f} (vs {full_gap:.4f} at full 768d).
   {'PCA denoising IMPROVED discriminability by removing noisy dimensions.' if best_gap > full_gap else 'Full dimensionality was already optimal for this dataset.'}

3. PRACTICAL IMPLICATIONS:
   - For retrieval (code search), moderate PCA reduction can sharpen
     results while also reducing storage and computation.
   - Too few dimensions (k=2,3) lose important signal.
   - Too many dimensions may retain noise that dilutes similarity.
   - The "sweet spot" depends on the dataset and task.

4. TRADE-OFF: PCA denoising is a post-hoc technique. Newer embedding
   models are trained with Matryoshka Representation Learning (MRL)
   that makes the FIRST k dimensions maximally informative by design.
""")