"""
============================================================================
Example 6: PCA Denoising — Can Fewer Dimensions Improve Similarity?
============================================================================
AISE501 – AI in Software Engineering I
Fachhochschule Graubünden

HYPOTHESIS:
    Embedding vectors live in a 768-dimensional space, but most of the
    semantic signal may be concentrated in a small number of principal
    components. The remaining dimensions could add "noise" that dilutes
    cosine similarity. If true, projecting embeddings onto a small PCA
    subspace should INCREASE similarity within semantic groups and
    DECREASE similarity across groups — making code search sharper.

WHAT YOU WILL LEARN:
    - How PCA decomposes the embedding space into ranked components
    - How to measure retrieval quality (intra- vs inter-group similarity)
    - Whether dimensionality reduction helps or hurts in practice
    - The concept of an "optimal" embedding dimension for a given task

OUTPUT:
    Saves pca_denoising_analysis.png, a six-panel analysis figure.

HARDWARE:
    Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""

import torch
|
||
import numpy as np
|
||
from transformers import AutoTokenizer, AutoModel
|
||
import torch.nn.functional as F
|
||
from sklearn.decomposition import PCA
|
||
import matplotlib.pyplot as plt
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
|
||
# ── Device selection ──────────────────────────────────────────────────────
|
||
def get_device():
|
||
if torch.cuda.is_available():
|
||
return torch.device("cuda")
|
||
elif torch.backends.mps.is_available():
|
||
return torch.device("mps")
|
||
return torch.device("cpu")
|
||
|
||
DEVICE = get_device()
|
||
print(f"Using device: {DEVICE}\n")
|
||
|
||
# ── Load model ────────────────────────────────────────────────────────────
|
||
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
|
||
print(f"Loading model: {MODEL_NAME} ...")
|
||
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
||
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
|
||
model.eval()
|
||
print("Model loaded.\n")
|
||
|
||
# ── Code snippets organized into semantic GROUPS ──────────────────────────
|
||
# We need clear groups so we can measure intra-group vs inter-group similarity.
|
||
groups = {
|
||
"Sorting": {
|
||
"bubble_sort": """
|
||
def bubble_sort(arr):
|
||
n = len(arr)
|
||
for i in range(n):
|
||
for j in range(0, n - i - 1):
|
||
if arr[j] > arr[j + 1]:
|
||
arr[j], arr[j + 1] = arr[j + 1], arr[j]
|
||
return arr""",
|
||
"quick_sort": """
|
||
def quick_sort(arr):
|
||
if len(arr) <= 1:
|
||
return arr
|
||
pivot = arr[len(arr) // 2]
|
||
left = [x for x in arr if x < pivot]
|
||
middle = [x for x in arr if x == pivot]
|
||
right = [x for x in arr if x > pivot]
|
||
return quick_sort(left) + middle + quick_sort(right)""",
|
||
"merge_sort": """
|
||
def merge_sort(arr):
|
||
if len(arr) <= 1:
|
||
return arr
|
||
mid = len(arr) // 2
|
||
left = merge_sort(arr[:mid])
|
||
right = merge_sort(arr[mid:])
|
||
merged = []
|
||
i = j = 0
|
||
while i < len(left) and j < len(right):
|
||
if left[i] <= right[j]:
|
||
merged.append(left[i]); i += 1
|
||
else:
|
||
merged.append(right[j]); j += 1
|
||
return merged + left[i:] + right[j:]""",
|
||
"insertion_sort": """
|
||
def insertion_sort(arr):
|
||
for i in range(1, len(arr)):
|
||
key = arr[i]
|
||
j = i - 1
|
||
while j >= 0 and arr[j] > key:
|
||
arr[j + 1] = arr[j]
|
||
j -= 1
|
||
arr[j + 1] = key
|
||
return arr""",
|
||
"selection_sort": """
|
||
def selection_sort(arr):
|
||
for i in range(len(arr)):
|
||
min_idx = i
|
||
for j in range(i + 1, len(arr)):
|
||
if arr[j] < arr[min_idx]:
|
||
min_idx = j
|
||
arr[i], arr[min_idx] = arr[min_idx], arr[i]
|
||
return arr""",
|
||
"heap_sort": """
|
||
def heap_sort(arr):
|
||
import heapq
|
||
heapq.heapify(arr)
|
||
return [heapq.heappop(arr) for _ in range(len(arr))]""",
|
||
},
|
||
"File I/O": {
|
||
"read_json": """
|
||
import json
|
||
def read_json(path):
|
||
with open(path, 'r') as f:
|
||
return json.load(f)""",
|
||
"write_file": """
|
||
def write_file(path, content):
|
||
with open(path, 'w') as f:
|
||
f.write(content)""",
|
||
"read_csv": """
|
||
import csv
|
||
def read_csv(path):
|
||
with open(path, 'r') as f:
|
||
reader = csv.reader(f)
|
||
return list(reader)""",
|
||
"read_yaml": """
|
||
import yaml
|
||
def load_yaml(path):
|
||
with open(path, 'r') as f:
|
||
return yaml.safe_load(f)""",
|
||
"write_json": """
|
||
import json
|
||
def write_json(path, data):
|
||
with open(path, 'w') as f:
|
||
json.dump(data, f, indent=2)""",
|
||
"read_lines": """
|
||
def read_lines(path):
|
||
with open(path, 'r') as f:
|
||
return f.readlines()""",
|
||
},
|
||
"Math": {
|
||
"factorial": """
|
||
def factorial(n):
|
||
if n <= 1:
|
||
return 1
|
||
return n * factorial(n - 1)""",
|
||
"fibonacci": """
|
||
def fibonacci(n):
|
||
a, b = 0, 1
|
||
for _ in range(n):
|
||
a, b = b, a + b
|
||
return a""",
|
||
"gcd": """
|
||
def gcd(a, b):
|
||
while b:
|
||
a, b = b, a % b
|
||
return a""",
|
||
"is_prime": """
|
||
def is_prime(n):
|
||
if n < 2:
|
||
return False
|
||
for i in range(2, int(n**0.5) + 1):
|
||
if n % i == 0:
|
||
return False
|
||
return True""",
|
||
"power": """
|
||
def power(base, exp):
|
||
if exp == 0:
|
||
return 1
|
||
if exp % 2 == 0:
|
||
half = power(base, exp // 2)
|
||
return half * half
|
||
return base * power(base, exp - 1)""",
|
||
"sum_digits": """
|
||
def sum_digits(n):
|
||
total = 0
|
||
while n > 0:
|
||
total += n % 10
|
||
n //= 10
|
||
return total""",
|
||
},
|
||
"Networking": {
|
||
"http_get": """
|
||
import requests
|
||
def http_get(url):
|
||
response = requests.get(url)
|
||
return response.json()""",
|
||
"post_data": """
|
||
import requests
|
||
def post_data(url, payload):
|
||
response = requests.post(url, json=payload)
|
||
return response.status_code, response.json()""",
|
||
"fetch_url": """
|
||
import urllib.request
|
||
def fetch_url(url):
|
||
with urllib.request.urlopen(url) as resp:
|
||
return resp.read().decode('utf-8')""",
|
||
"download_file": """
|
||
import urllib.request
|
||
def download_file(url, dest):
|
||
urllib.request.urlretrieve(url, dest)
|
||
return dest""",
|
||
"http_put": """
|
||
import requests
|
||
def http_put(url, data):
|
||
response = requests.put(url, json=data)
|
||
return response.status_code""",
|
||
"http_delete": """
|
||
import requests
|
||
def http_delete(url):
|
||
response = requests.delete(url)
|
||
return response.status_code""",
|
||
},
|
||
"String ops": {
|
||
"reverse_str": """
|
||
def reverse_string(s):
|
||
return s[::-1]""",
|
||
"is_palindrome": """
|
||
def is_palindrome(s):
|
||
s = s.lower().replace(' ', '')
|
||
return s == s[::-1]""",
|
||
"count_vowels": """
|
||
def count_vowels(s):
|
||
return sum(1 for c in s.lower() if c in 'aeiou')""",
|
||
"capitalize_words": """
|
||
def capitalize_words(s):
|
||
return ' '.join(w.capitalize() for w in s.split())""",
|
||
"remove_duplicates": """
|
||
def remove_duplicate_chars(s):
|
||
seen = set()
|
||
result = []
|
||
for c in s:
|
||
if c not in seen:
|
||
seen.add(c)
|
||
result.append(c)
|
||
return ''.join(result)""",
|
||
"count_words": """
|
||
def count_words(text):
|
||
words = text.lower().split()
|
||
freq = {}
|
||
for w in words:
|
||
freq[w] = freq.get(w, 0) + 1
|
||
return freq""",
|
||
},
|
||
"Data structures": {
|
||
"stack_push_pop": """
|
||
class Stack:
|
||
def __init__(self):
|
||
self.items = []
|
||
def push(self, item):
|
||
self.items.append(item)
|
||
def pop(self):
|
||
return self.items.pop()""",
|
||
"queue_impl": """
|
||
from collections import deque
|
||
class Queue:
|
||
def __init__(self):
|
||
self.items = deque()
|
||
def enqueue(self, item):
|
||
self.items.append(item)
|
||
def dequeue(self):
|
||
return self.items.popleft()""",
|
||
"linked_list": """
|
||
class Node:
|
||
def __init__(self, val):
|
||
self.val = val
|
||
self.next = None
|
||
class LinkedList:
|
||
def __init__(self):
|
||
self.head = None
|
||
def append(self, val):
|
||
node = Node(val)
|
||
if not self.head:
|
||
self.head = node
|
||
return
|
||
curr = self.head
|
||
while curr.next:
|
||
curr = curr.next
|
||
curr.next = node""",
|
||
"binary_tree": """
|
||
class TreeNode:
|
||
def __init__(self, val):
|
||
self.val = val
|
||
self.left = None
|
||
self.right = None
|
||
def inorder(root):
|
||
if root:
|
||
yield from inorder(root.left)
|
||
yield root.val
|
||
yield from inorder(root.right)""",
|
||
"hash_map": """
|
||
class HashMap:
|
||
def __init__(self, size=256):
|
||
self.buckets = [[] for _ in range(size)]
|
||
def put(self, key, value):
|
||
idx = hash(key) % len(self.buckets)
|
||
for i, (k, v) in enumerate(self.buckets[idx]):
|
||
if k == key:
|
||
self.buckets[idx][i] = (key, value)
|
||
return
|
||
self.buckets[idx].append((key, value))""",
|
||
"priority_queue": """
|
||
import heapq
|
||
class PriorityQueue:
|
||
def __init__(self):
|
||
self.heap = []
|
||
def push(self, priority, item):
|
||
heapq.heappush(self.heap, (priority, item))
|
||
def pop(self):
|
||
return heapq.heappop(self.heap)[1]""",
|
||
},
|
||
}
|
||
|
||
|
||
def embed_code(code: str) -> torch.Tensor:
|
||
"""Embed code into a normalized vector."""
|
||
inputs = tokenizer(
|
||
code, return_tensors="pt", truncation=True, max_length=512, padding=True
|
||
).to(DEVICE)
|
||
with torch.no_grad():
|
||
outputs = model(**inputs)
|
||
mask = inputs["attention_mask"].unsqueeze(-1)
|
||
embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
|
||
return F.normalize(embedding, p=2, dim=1).squeeze(0)
|
||
|
||
|
||
# ── Step 1: Compute all embeddings ────────────────────────────────────────
|
||
print("Computing embeddings...")
|
||
all_names = []
|
||
all_labels = []
|
||
all_vectors = []
|
||
|
||
for group_name, snippets in groups.items():
|
||
for snippet_name, code in snippets.items():
|
||
vec = embed_code(code).cpu().numpy()
|
||
all_names.append(snippet_name)
|
||
all_labels.append(group_name)
|
||
all_vectors.append(vec)
|
||
print(f" [{group_name:12s}] {snippet_name}")
|
||
|
||
X = np.stack(all_vectors) # shape: [N, 768]
|
||
N, D = X.shape
|
||
print(f"\nEmbedding matrix: {N} snippets × {D} dimensions\n")
|
||
|
||
# ── Step 2: Define similarity metrics ─────────────────────────────────────
|
||
def cosine_matrix(vectors):
|
||
"""Compute pairwise cosine similarity for L2-normalized vectors."""
|
||
norms = np.linalg.norm(vectors, axis=1, keepdims=True)
|
||
norms = np.maximum(norms, 1e-10)
|
||
normed = vectors / norms
|
||
return normed @ normed.T
|
||
|
||
def compute_metrics(sim_matrix, labels):
|
||
"""
|
||
Compute intra-group (same category) and inter-group (different category)
|
||
average similarities. The GAP between them measures discriminability.
|
||
"""
|
||
intra_sims = []
|
||
inter_sims = []
|
||
n = len(labels)
|
||
for i in range(n):
|
||
for j in range(i + 1, n):
|
||
if labels[i] == labels[j]:
|
||
intra_sims.append(sim_matrix[i, j])
|
||
else:
|
||
inter_sims.append(sim_matrix[i, j])
|
||
intra_mean = np.mean(intra_sims)
|
||
inter_mean = np.mean(inter_sims)
|
||
gap = intra_mean - inter_mean
|
||
return intra_mean, inter_mean, gap
|
||
|
||
|
||
# ── Step 3: Sweep across PCA dimensions ──────────────────────────────────
|
||
# PCA can have at most min(N, D) components; cap accordingly
|
||
max_components = min(N, D)
|
||
dims_to_test = sorted(set(
|
||
k for k in [2, 3, 5, 8, 10, 15, 20, 30, 50, 75, 100, 150, 200,
|
||
300, 400, 500, 600, D]
|
||
if k <= max_components
|
||
))
|
||
dims_to_test.append(D) # always include full dimensionality as baseline
|
||
|
||
print("=" * 70)
|
||
print("PCA DENOISING EXPERIMENT")
|
||
print("=" * 70)
|
||
print(f"\n{'Components':>12s} {'Intra-Group':>12s} {'Inter-Group':>12s} "
|
||
f"{'Gap':>8s} {'vs Full':>8s}")
|
||
print("-" * 62)
|
||
|
||
results = []
|
||
for k in dims_to_test:
|
||
if k >= D:
|
||
# Full dimensionality — no PCA, just use original vectors
|
||
X_reduced = X.copy()
|
||
actual_k = D
|
||
else:
|
||
pca = PCA(n_components=k, random_state=42)
|
||
X_reduced = pca.fit_transform(X)
|
||
actual_k = k
|
||
|
||
sim = cosine_matrix(X_reduced)
|
||
intra, inter, gap = compute_metrics(sim, all_labels)
|
||
results.append((actual_k, intra, inter, gap))
|
||
|
||
# Compute full-dim gap for comparison
|
||
full_intra, full_inter, full_gap = results[-1][1], results[-1][2], results[-1][3]
|
||
|
||
for k, intra, inter, gap in results:
|
||
delta = gap - full_gap
|
||
delta_str = f"{delta:+.4f}" if k < D else " (base)"
|
||
print(f"{k:>12d} {intra:>12.4f} {inter:>12.4f} {gap:>8.4f} {delta_str:>8s}")
|
||
|
||
# ── Step 4: Find the optimal dimensionality ──────────────────────────────
|
||
dims_arr = np.array([r[0] for r in results])
|
||
gaps_arr = np.array([r[3] for r in results])
|
||
best_idx = np.argmax(gaps_arr)
|
||
best_k, best_gap = int(dims_arr[best_idx]), gaps_arr[best_idx]
|
||
|
||
print(f"\n{'=' * 70}")
|
||
print(f"BEST DIMENSIONALITY: {best_k} components")
|
||
print(f" Gap (intra - inter): {best_gap:.4f} vs {full_gap:.4f} at full 768-d")
|
||
print(f" Improvement: {best_gap - full_gap:+.4f}")
|
||
print(f"{'=' * 70}")
|
||
|
||
# ── Step 5: Show detailed comparison at optimal k vs full ────────────────
|
||
print(f"\n── Detailed Similarity Matrix at k={best_k} vs k={D} ──\n")
|
||
|
||
if best_k < D:
|
||
pca_best = PCA(n_components=best_k, random_state=42)
|
||
X_best = pca_best.fit_transform(X)
|
||
else:
|
||
X_best = X.copy()
|
||
|
||
sim_full = cosine_matrix(X)
|
||
sim_best = cosine_matrix(X_best)
|
||
|
||
# Show a selection of interesting pairs
|
||
print(f"{'Snippet A':>20s} {'Snippet B':>20s} {'Full 768d':>10s} "
|
||
f"{'PCA {0}d'.format(best_k):>10s} {'Change':>8s}")
|
||
print("-" * 78)
|
||
|
||
interesting_pairs = [
|
||
# Intra-group: should be high
|
||
("bubble_sort", "quick_sort"),
|
||
("bubble_sort", "merge_sort"),
|
||
("read_json", "read_csv"),
|
||
("http_get", "fetch_url"),
|
||
("factorial", "fibonacci"),
|
||
("reverse_str", "is_palindrome"),
|
||
("stack_push_pop", "queue_impl"),
|
||
# Inter-group: should be low
|
||
("bubble_sort", "read_json"),
|
||
("factorial", "http_get"),
|
||
("reverse_str", "download_file"),
|
||
("is_prime", "write_file"),
|
||
("stack_push_pop", "count_vowels"),
|
||
]
|
||
|
||
for n1, n2 in interesting_pairs:
|
||
i = all_names.index(n1)
|
||
j = all_names.index(n2)
|
||
s_full = sim_full[i, j]
|
||
s_best = sim_best[i, j]
|
||
same = all_labels[i] == all_labels[j]
|
||
marker = "SAME" if same else "DIFF"
|
||
change = s_best - s_full
|
||
print(f"{n1:>20s} {n2:>20s} {s_full:>10.4f} {s_best:>10.4f} "
|
||
f"{change:>+8.4f} [{marker}]")
|
||
|
||
|
||
# ── Step 6: Text-to-code search comparison ────────────────────────────────
|
||
print(f"\n── Text-to-Code Search: Full 768d vs PCA {best_k}d ──\n")
|
||
|
||
search_queries = [
|
||
("sort a list of numbers", "Sorting"),
|
||
("read a JSON config file", "File I/O"),
|
||
("compute factorial recursively", "Math"),
|
||
("make an HTTP GET request", "Networking"),
|
||
("check if a number is prime", "Math"),
|
||
]
|
||
|
||
if best_k < D:
|
||
pca_search = PCA(n_components=best_k, random_state=42)
|
||
X_search = pca_search.fit_transform(X)
|
||
else:
|
||
X_search = X.copy()
|
||
pca_search = None
|
||
|
||
for query, expected_group in search_queries:
|
||
q_vec = embed_code(query).cpu().numpy().reshape(1, -1)
|
||
|
||
# Full dimension search
|
||
q_norm = q_vec / np.linalg.norm(q_vec)
|
||
X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)
|
||
scores_full = (X_norm @ q_norm.T).flatten()
|
||
|
||
# PCA-reduced search
|
||
if pca_search is not None:
|
||
q_reduced = pca_search.transform(q_vec)
|
||
else:
|
||
q_reduced = q_vec.copy()
|
||
q_r_norm = q_reduced / np.linalg.norm(q_reduced)
|
||
X_s_norm = X_search / np.linalg.norm(X_search, axis=1, keepdims=True)
|
||
scores_pca = (X_s_norm @ q_r_norm.T).flatten()
|
||
|
||
top_full = np.argsort(-scores_full)[:3]
|
||
top_pca = np.argsort(-scores_pca)[:3]
|
||
|
||
print(f' Query: "{query}"')
|
||
print(f' Full 768d: {all_names[top_full[0]]:>16s} ({scores_full[top_full[0]]:.3f})'
|
||
f' {all_names[top_full[1]]:>16s} ({scores_full[top_full[1]]:.3f})'
|
||
f' {all_names[top_full[2]]:>16s} ({scores_full[top_full[2]]:.3f})')
|
||
print(f' PCA {best_k:>3d}d: {all_names[top_pca[0]]:>16s} ({scores_pca[top_pca[0]]:.3f})'
|
||
f' {all_names[top_pca[1]]:>16s} ({scores_pca[top_pca[1]]:.3f})'
|
||
f' {all_names[top_pca[2]]:>16s} ({scores_pca[top_pca[2]]:.3f})')
|
||
|
||
full_correct = all_labels[top_full[0]] == expected_group
|
||
pca_correct = all_labels[top_pca[0]] == expected_group
|
||
print(f' Full correct: {full_correct} | PCA correct: {pca_correct}')
|
||
print()
|
||
|
||
|
||
# ── Step 7: Visualization ─────────────────────────────────────────────────
|
||
# Six-panel figure for a comprehensive visual analysis.
|
||
|
||
group_colors = {
|
||
"Sorting": "#1f77b4", "File I/O": "#ff7f0e", "Math": "#2ca02c",
|
||
"Networking": "#d62728", "String ops": "#9467bd", "Data structures": "#8c564b",
|
||
}
|
||
label_colors = [group_colors[g] for g in all_labels]
|
||
unique_groups = list(dict.fromkeys(all_labels))
|
||
|
||
fig = plt.figure(figsize=(20, 13))
|
||
fig.suptitle("PCA Denoising Analysis — Can Fewer Dimensions Improve Code Similarity?",
|
||
fontsize=15, fontweight="bold", y=0.98)
|
||
|
||
# ── Row 1 ──
|
||
|
||
# Plot 1: Intra/inter similarity vs number of PCA components
|
||
ax1 = fig.add_subplot(2, 3, 1)
|
||
dims_plot = [r[0] for r in results]
|
||
intra_plot = [r[1] for r in results]
|
||
inter_plot = [r[2] for r in results]
|
||
ax1.fill_between(dims_plot, inter_plot, intra_plot, alpha=0.15, color="tab:green")
|
||
ax1.plot(dims_plot, intra_plot, "o-", color="tab:blue", linewidth=2,
|
||
label="Intra-group (same category)", markersize=6)
|
||
ax1.plot(dims_plot, inter_plot, "s-", color="tab:red", linewidth=2,
|
||
label="Inter-group (different category)", markersize=6)
|
||
ax1.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
|
||
label=f"Best gap at k={best_k}")
|
||
ax1.set_xlabel("Number of PCA Components", fontsize=10)
|
||
ax1.set_ylabel("Average Cosine Similarity", fontsize=10)
|
||
ax1.set_title("(a) Intra- vs Inter-Group Similarity", fontsize=11, fontweight="bold")
|
||
ax1.legend(fontsize=7, loc="center right")
|
||
ax1.set_xscale("log")
|
||
ax1.grid(True, alpha=0.3)
|
||
|
||
# Plot 2: Gap (discriminability) vs number of PCA components
|
||
ax2 = fig.add_subplot(2, 3, 2)
|
||
gaps_plot = [r[3] for r in results]
|
||
ax2.plot(dims_plot, gaps_plot, "D-", color="tab:green", linewidth=2, markersize=7)
|
||
ax2.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
|
||
label=f"Best k={best_k} (gap={best_gap:.3f})")
|
||
ax2.axhline(y=full_gap, color="gray", linestyle=":", alpha=0.7,
|
||
label=f"Full 768d (gap={full_gap:.3f})")
|
||
ax2.fill_between(dims_plot, full_gap, gaps_plot, alpha=0.12, color="tab:green",
|
||
where=[g > full_gap for g in gaps_plot])
|
||
ax2.set_xlabel("Number of PCA Components", fontsize=10)
|
||
ax2.set_ylabel("Gap (Intra − Inter)", fontsize=10)
|
||
ax2.set_title("(b) Discriminability vs Dimensionality", fontsize=11, fontweight="bold")
|
||
ax2.legend(fontsize=8)
|
||
ax2.set_xscale("log")
|
||
ax2.grid(True, alpha=0.3)
|
||
|
||
# Plot 3: Cumulative variance explained
|
||
pca_full = PCA(n_components=min(N, D), random_state=42)
|
||
pca_full.fit(X)
|
||
cumvar = np.cumsum(pca_full.explained_variance_ratio_) * 100
|
||
ax3 = fig.add_subplot(2, 3, 3)
|
||
ax3.plot(range(1, len(cumvar) + 1), cumvar, "-", color="tab:purple", linewidth=2)
|
||
ax3.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
|
||
label=f"Best k={best_k}")
|
||
for threshold in [90, 95, 99]:
|
||
k_thresh = np.searchsorted(cumvar, threshold) + 1
|
||
if k_thresh <= len(cumvar):
|
||
ax3.axhline(y=threshold, color="gray", linestyle=":", alpha=0.4)
|
||
ax3.annotate(f"{threshold}% → k={k_thresh}", xy=(k_thresh, threshold),
|
||
fontsize=8, color="gray", ha="left",
|
||
xytext=(k_thresh + 1, threshold - 2))
|
||
ax3.set_xlabel("Number of PCA Components", fontsize=10)
|
||
ax3.set_ylabel("Cumulative Variance Explained (%)", fontsize=10)
|
||
ax3.set_title("(c) Variance Concentration", fontsize=11, fontweight="bold")
|
||
ax3.legend(fontsize=8)
|
||
ax3.set_xscale("log")
|
||
ax3.grid(True, alpha=0.3)
|
||
|
||
# ── Row 2 ──
|
||
|
||
# Plot 4 & 5: Side-by-side heatmaps (full vs PCA-denoised)
|
||
# Sort indices by group for a block-diagonal structure
|
||
sorted_idx = sorted(range(N), key=lambda i: all_labels[i])
|
||
sorted_names = [all_names[i] for i in sorted_idx]
|
||
sorted_labels = [all_labels[i] for i in sorted_idx]
|
||
|
||
sim_full_sorted = sim_full[np.ix_(sorted_idx, sorted_idx)]
|
||
sim_best_sorted = sim_best[np.ix_(sorted_idx, sorted_idx)]
|
||
|
||
for panel_idx, (mat, title_str) in enumerate([
|
||
(sim_full_sorted, f"(d) Similarity Heatmap — Full 768d"),
|
||
(sim_best_sorted, f"(e) Similarity Heatmap — PCA {best_k}d (Denoised)"),
|
||
]):
|
||
ax = fig.add_subplot(2, 3, 4 + panel_idx)
|
||
im = ax.imshow(mat, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
|
||
ax.set_xticks(range(N))
|
||
ax.set_yticks(range(N))
|
||
ax.set_xticklabels(sorted_names, rotation=90, fontsize=5)
|
||
ax.set_yticklabels(sorted_names, fontsize=5)
|
||
|
||
# Draw group boundary lines
|
||
prev_label = sorted_labels[0]
|
||
for i, lab in enumerate(sorted_labels):
|
||
if lab != prev_label:
|
||
ax.axhline(y=i - 0.5, color="black", linewidth=1)
|
||
ax.axvline(x=i - 0.5, color="black", linewidth=1)
|
||
prev_label = lab
|
||
|
||
ax.set_title(title_str, fontsize=11, fontweight="bold")
|
||
plt.colorbar(im, ax=ax, shrink=0.8, label="Cosine Similarity")
|
||
|
||
# Plot 6: Bar chart comparing specific pairs at full vs PCA
|
||
ax6 = fig.add_subplot(2, 3, 6)
|
||
pair_labels = []
|
||
full_scores = []
|
||
pca_scores = []
|
||
pair_colors = []
|
||
|
||
for n1, n2 in interesting_pairs:
|
||
i = all_names.index(n1)
|
||
j = all_names.index(n2)
|
||
pair_labels.append(f"{n1}\nvs {n2}")
|
||
full_scores.append(sim_full[i, j])
|
||
pca_scores.append(sim_best[i, j])
|
||
pair_colors.append("#2ca02c" if all_labels[i] == all_labels[j] else "#d62728")
|
||
|
||
y_pos = np.arange(len(pair_labels))
|
||
bar_h = 0.35
|
||
bars_full = ax6.barh(y_pos + bar_h / 2, full_scores, bar_h, label="Full 768d",
|
||
color="tab:blue", alpha=0.7)
|
||
bars_pca = ax6.barh(y_pos - bar_h / 2, pca_scores, bar_h, label=f"PCA {best_k}d",
|
||
color="tab:orange", alpha=0.7)
|
||
|
||
# Color labels by same/different group
|
||
for i, (yl, col) in enumerate(zip(pair_labels, pair_colors)):
|
||
ax6.annotate("●", xy=(-0.05, y_pos[i]), fontsize=10, color=col,
|
||
ha="right", va="center", fontweight="bold",
|
||
annotation_clip=False)
|
||
|
||
ax6.set_yticks(y_pos)
|
||
ax6.set_yticklabels(pair_labels, fontsize=6)
|
||
ax6.set_xlabel("Cosine Similarity", fontsize=10)
|
||
ax6.set_title("(f) Pair Comparison: Full vs PCA Denoised", fontsize=11, fontweight="bold")
|
||
ax6.legend(fontsize=8)
|
||
ax6.axvline(x=0, color="black", linewidth=0.5)
|
||
ax6.set_xlim(-1.1, 1.1)
|
||
ax6.grid(True, axis="x", alpha=0.3)
|
||
ax6.invert_yaxis()
|
||
|
||
# Custom legend for the dots
|
||
from matplotlib.lines import Line2D
|
||
dot_legend = [Line2D([0], [0], marker="o", color="w", markerfacecolor="#2ca02c",
|
||
markersize=8, label="Same group"),
|
||
Line2D([0], [0], marker="o", color="w", markerfacecolor="#d62728",
|
||
markersize=8, label="Different group")]
|
||
ax6.legend(handles=[bars_full, bars_pca] + dot_legend, fontsize=7, loc="lower right")
|
||
|
||
plt.tight_layout(rect=[0, 0, 1, 0.96])
|
||
plt.savefig("pca_denoising_analysis.png", dpi=150, bbox_inches="tight")
|
||
print(f"\nSaved: pca_denoising_analysis.png")
|
||
|
||
# ── Summary ───────────────────────────────────────────────────────────────
|
||
print(f"""
|
||
{'=' * 70}
|
||
CONCLUSIONS
|
||
{'=' * 70}
|
||
|
||
1. VARIANCE CONCENTRATION:
|
||
The first few PCA components capture a disproportionate amount of
|
||
variance. This means the embedding space has low effective
|
||
dimensionality — most of the 768 dimensions are semi-redundant.
|
||
|
||
2. DENOISING EFFECT:
|
||
At k={best_k}, the gap between intra-group and inter-group similarity
|
||
is {best_gap:.4f} (vs {full_gap:.4f} at full 768d).
|
||
{'PCA denoising IMPROVED discriminability by removing noisy dimensions.' if best_gap > full_gap else 'Full dimensionality was already optimal for this dataset.'}
|
||
|
||
3. PRACTICAL IMPLICATIONS:
|
||
- For retrieval (code search), moderate PCA reduction can sharpen
|
||
results while also reducing storage and computation.
|
||
- Too few dimensions (k=2,3) lose important signal.
|
||
- Too many dimensions may retain noise that dilutes similarity.
|
||
- The "sweet spot" depends on the dataset and task.
|
||
|
||
4. TRADE-OFF:
|
||
PCA denoising is a post-hoc technique. Newer embedding models are
|
||
trained with Matryoshka Representation Learning (MRL) that makes
|
||
the FIRST k dimensions maximally informative by design.
|
||
""")
|