AISE1_CLASS/Code embeddings/06_pca_denoising.py

717 lines
25 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
============================================================================
Example 6: PCA Denoising — Can Fewer Dimensions Improve Similarity?
============================================================================
AISE501 AI in Software Engineering I
Fachhochschule Graubünden
HYPOTHESIS:
Embedding vectors live in a 768-dimensional space, but most of the
semantic signal may be concentrated in a small number of principal
components. The remaining dimensions could add "noise" that dilutes
cosine similarity. If true, projecting embeddings onto a small PCA
subspace should INCREASE similarity within semantic groups and
DECREASE similarity across groups — making code search sharper.
WHAT YOU WILL LEARN:
- How PCA decomposes the embedding space into ranked components
- How to measure retrieval quality (intra- vs inter-group similarity)
- Whether dimensionality reduction helps or hurts in practice
- The concept of an "optimal" embedding dimension for a given task
OUTPUT:
Saves pca_denoising_analysis.png with three sub-plots.
HARDWARE:
Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
# Select the non-interactive backend BEFORE pyplot is imported — the
# conventional (and portable) ordering; the original called use() after.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModel
# ── Device selection ──────────────────────────────────────────────────────
def get_device():
    """Return the best available torch device: CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
DEVICE = get_device()
print(f"Using device: {DEVICE}\n")
# ── Load model ────────────────────────────────────────────────────────────
# Code-search model: maps code snippets and natural-language queries into a
# shared embedding space (768-d per the module docstring).
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference mode (disables dropout etc.)
print("Model loaded.\n")
# ── Code snippets organized into semantic GROUPS ──────────────────────────
# We need clear groups so we can measure intra-group vs inter-group similarity.
# NOTE(review): the snippet strings are embedded verbatim by the model, so
# they are kept as well-formed, conventionally indented Python.
groups = {
    "Sorting": {
        "bubble_sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr""",
        "quick_sort": """
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)""",
        "merge_sort": """
def merge_sort(arr):
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    merged = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i]); i += 1
        else:
            merged.append(right[j]); j += 1
    return merged + left[i:] + right[j:]""",
        "insertion_sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i - 1
        while j >= 0 and arr[j] > key:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr""",
        "selection_sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i + 1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr""",
        "heap_sort": """
def heap_sort(arr):
    import heapq
    heapq.heapify(arr)
    return [heapq.heappop(arr) for _ in range(len(arr))]""",
    },
    "File I/O": {
        "read_json": """
import json
def read_json(path):
    with open(path, 'r') as f:
        return json.load(f)""",
        "write_file": """
def write_file(path, content):
    with open(path, 'w') as f:
        f.write(content)""",
        "read_csv": """
import csv
def read_csv(path):
    with open(path, 'r') as f:
        reader = csv.reader(f)
        return list(reader)""",
        "read_yaml": """
import yaml
def load_yaml(path):
    with open(path, 'r') as f:
        return yaml.safe_load(f)""",
        "write_json": """
import json
def write_json(path, data):
    with open(path, 'w') as f:
        json.dump(data, f, indent=2)""",
        "read_lines": """
def read_lines(path):
    with open(path, 'r') as f:
        return f.readlines()""",
    },
    "Math": {
        "factorial": """
def factorial(n):
    if n <= 1:
        return 1
    return n * factorial(n - 1)""",
        "fibonacci": """
def fibonacci(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a""",
        "gcd": """
def gcd(a, b):
    while b:
        a, b = b, a % b
    return a""",
        "is_prime": """
def is_prime(n):
    if n < 2:
        return False
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True""",
        "power": """
def power(base, exp):
    if exp == 0:
        return 1
    if exp % 2 == 0:
        half = power(base, exp // 2)
        return half * half
    return base * power(base, exp - 1)""",
        "sum_digits": """
def sum_digits(n):
    total = 0
    while n > 0:
        total += n % 10
        n //= 10
    return total""",
    },
    "Networking": {
        "http_get": """
import requests
def http_get(url):
    response = requests.get(url)
    return response.json()""",
        "post_data": """
import requests
def post_data(url, payload):
    response = requests.post(url, json=payload)
    return response.status_code, response.json()""",
        "fetch_url": """
import urllib.request
def fetch_url(url):
    with urllib.request.urlopen(url) as resp:
        return resp.read().decode('utf-8')""",
        "download_file": """
import urllib.request
def download_file(url, dest):
    urllib.request.urlretrieve(url, dest)
    return dest""",
        "http_put": """
import requests
def http_put(url, data):
    response = requests.put(url, json=data)
    return response.status_code""",
        "http_delete": """
import requests
def http_delete(url):
    response = requests.delete(url)
    return response.status_code""",
    },
    "String ops": {
        "reverse_str": """
def reverse_string(s):
    return s[::-1]""",
        "is_palindrome": """
def is_palindrome(s):
    s = s.lower().replace(' ', '')
    return s == s[::-1]""",
        "count_vowels": """
def count_vowels(s):
    return sum(1 for c in s.lower() if c in 'aeiou')""",
        "capitalize_words": """
def capitalize_words(s):
    return ' '.join(w.capitalize() for w in s.split())""",
        "remove_duplicates": """
def remove_duplicate_chars(s):
    seen = set()
    result = []
    for c in s:
        if c not in seen:
            seen.add(c)
            result.append(c)
    return ''.join(result)""",
        "count_words": """
def count_words(text):
    words = text.lower().split()
    freq = {}
    for w in words:
        freq[w] = freq.get(w, 0) + 1
    return freq""",
    },
    "Data structures": {
        "stack_push_pop": """
class Stack:
    def __init__(self):
        self.items = []
    def push(self, item):
        self.items.append(item)
    def pop(self):
        return self.items.pop()""",
        "queue_impl": """
from collections import deque
class Queue:
    def __init__(self):
        self.items = deque()
    def enqueue(self, item):
        self.items.append(item)
    def dequeue(self):
        return self.items.popleft()""",
        "linked_list": """
class Node:
    def __init__(self, val):
        self.val = val
        self.next = None
class LinkedList:
    def __init__(self):
        self.head = None
    def append(self, val):
        node = Node(val)
        if not self.head:
            self.head = node
            return
        curr = self.head
        while curr.next:
            curr = curr.next
        curr.next = node""",
        "binary_tree": """
class TreeNode:
    def __init__(self, val):
        self.val = val
        self.left = None
        self.right = None
def inorder(root):
    if root:
        yield from inorder(root.left)
        yield root.val
        yield from inorder(root.right)""",
        "hash_map": """
class HashMap:
    def __init__(self, size=256):
        self.buckets = [[] for _ in range(size)]
    def put(self, key, value):
        idx = hash(key) % len(self.buckets)
        for i, (k, v) in enumerate(self.buckets[idx]):
            if k == key:
                self.buckets[idx][i] = (key, value)
                return
        self.buckets[idx].append((key, value))""",
        "priority_queue": """
import heapq
class PriorityQueue:
    def __init__(self):
        self.heap = []
    def push(self, priority, item):
        heapq.heappush(self.heap, (priority, item))
    def pop(self):
        return heapq.heappop(self.heap)[1]""",
    },
}
def embed_code(code: str) -> torch.Tensor:
    """Embed code (or a natural-language query) into an L2-normalized vector.

    Mean-pools the last hidden layer, weighted by the attention mask so
    padding tokens do not contribute, then L2-normalizes the result.
    Returns a 1-D tensor on DEVICE.
    """
    inputs = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # [1, seq_len] -> [1, seq_len, 1] so the mask broadcasts over hidden dims
    mask = inputs["attention_mask"].unsqueeze(-1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(embedding, p=2, dim=1).squeeze(0)
# ── Step 1: Compute all embeddings ────────────────────────────────────────
print("Computing embeddings...")
all_names = []    # snippet identifiers, e.g. "bubble_sort"
all_labels = []   # semantic group of each snippet (parallel to all_names)
all_vectors = []  # embedding vectors (the three lists share one index space)
for group_name, snippets in groups.items():
    for snippet_name, code in snippets.items():
        vec = embed_code(code).cpu().numpy()
        all_names.append(snippet_name)
        all_labels.append(group_name)
        all_vectors.append(vec)
        print(f" [{group_name:12s}] {snippet_name}")
X = np.stack(all_vectors)  # shape: [N, 768]
N, D = X.shape
print(f"\nEmbedding matrix: {N} snippets × {D} dimensions\n")
# ── Step 2: Define similarity metrics ─────────────────────────────────────
def cosine_matrix(vectors):
    """Compute the pairwise cosine-similarity matrix for row vectors.

    Rows need not be pre-normalized: each row is L2-normalized here, with a
    small floor on the norm to avoid division by zero for all-zero rows.
    """
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    norms = np.maximum(norms, 1e-10)
    normed = vectors / norms
    return normed @ normed.T
def compute_metrics(sim_matrix, labels):
    """
    Compute intra-group (same category) and inter-group (different category)
    average similarities. The GAP between them measures discriminability.

    Returns (intra_mean, inter_mean, gap) over all unordered pairs i < j.
    """
    intra_sims = []
    inter_sims = []
    n = len(labels)
    for i in range(n):
        for j in range(i + 1, n):
            if labels[i] == labels[j]:
                intra_sims.append(sim_matrix[i, j])
            else:
                inter_sims.append(sim_matrix[i, j])
    intra_mean = np.mean(intra_sims)
    inter_mean = np.mean(inter_sims)
    gap = intra_mean - inter_mean
    return intra_mean, inter_mean, gap
# ── Step 3: Sweep across PCA dimensions ──────────────────────────────────
# PCA can have at most min(N, D) components; cap accordingly
max_components = min(N, D)
dims_to_test = sorted(set(
    k for k in [2, 3, 5, 8, 10, 15, 20, 30, 50, 75, 100, 150, 200,
                300, 400, 500, 600, D]
    if k <= max_components
))
# Always include full dimensionality as the baseline. Guard against a
# duplicate: if D <= max_components it already survived the filter above,
# and appending it again would sweep (and print) it twice.
if D not in dims_to_test:
    dims_to_test.append(D)
print("=" * 70)
print("PCA DENOISING EXPERIMENT")
print("=" * 70)
print(f"\n{'Components':>12s} {'Intra-Group':>12s} {'Inter-Group':>12s} "
      f"{'Gap':>8s} {'vs Full':>8s}")
print("-" * 62)
results = []  # (k, intra_mean, inter_mean, gap) per tested dimensionality
for k in dims_to_test:
    if k >= D:
        # Full dimensionality — no PCA, just use original vectors
        X_reduced = X.copy()
        actual_k = D
    else:
        pca = PCA(n_components=k, random_state=42)
        X_reduced = pca.fit_transform(X)
        actual_k = k
    sim = cosine_matrix(X_reduced)
    intra, inter, gap = compute_metrics(sim, all_labels)
    results.append((actual_k, intra, inter, gap))
# Full-dim entry is always the last one (D is appended/sorted last)
full_intra, full_inter, full_gap = results[-1][1], results[-1][2], results[-1][3]
for k, intra, inter, gap in results:
    delta = gap - full_gap
    delta_str = f"{delta:+.4f}" if k < D else " (base)"
    print(f"{k:>12d} {intra:>12.4f} {inter:>12.4f} {gap:>8.4f} {delta_str:>8s}")
# ── Step 4: Find the optimal dimensionality ──────────────────────────────
# "Optimal" = the k whose intra-minus-inter similarity gap is largest.
dims_arr = np.array([r[0] for r in results])
gaps_arr = np.array([r[3] for r in results])
best_idx = np.argmax(gaps_arr)  # index of the largest gap in the sweep
best_k, best_gap = int(dims_arr[best_idx]), gaps_arr[best_idx]
print(f"\n{'=' * 70}")
print(f"BEST DIMENSIONALITY: {best_k} components")
print(f" Gap (intra - inter): {best_gap:.4f} vs {full_gap:.4f} at full 768-d")
print(f" Improvement: {best_gap - full_gap:+.4f}")
print(f"{'=' * 70}")
# ── Step 5: Show detailed comparison at optimal k vs full ────────────────
print(f"\n── Detailed Similarity Matrix at k={best_k} vs k={D} ──\n")
if best_k < D:
    pca_best = PCA(n_components=best_k, random_state=42)
    X_best = pca_best.fit_transform(X)
else:
    X_best = X.copy()
sim_full = cosine_matrix(X)       # similarities in the original 768-d space
sim_best = cosine_matrix(X_best)  # similarities in the best PCA subspace
# Show a selection of interesting pairs
# Pre-format the PCA column header (clearer than nesting str.format inside
# an f-string; the rendered text is identical).
pca_col = f"PCA {best_k}d"
print(f"{'Snippet A':>20s} {'Snippet B':>20s} {'Full 768d':>10s} "
      f"{pca_col:>10s} {'Change':>8s}")
print("-" * 78)
# Pairs inspected both in the printed table and in panel (f) of the figure.
interesting_pairs = [
    # Intra-group: should be high
    ("bubble_sort", "quick_sort"),
    ("bubble_sort", "merge_sort"),
    ("read_json", "read_csv"),
    ("http_get", "fetch_url"),
    ("factorial", "fibonacci"),
    ("reverse_str", "is_palindrome"),
    ("stack_push_pop", "queue_impl"),
    # Inter-group: should be low
    ("bubble_sort", "read_json"),
    ("factorial", "http_get"),
    ("reverse_str", "download_file"),
    ("is_prime", "write_file"),
    ("stack_push_pop", "count_vowels"),
]
# Print each pair's similarity before/after PCA and whether it is intra-group.
for n1, n2 in interesting_pairs:
    i = all_names.index(n1)
    j = all_names.index(n2)
    s_full = sim_full[i, j]
    s_best = sim_best[i, j]
    same = all_labels[i] == all_labels[j]
    marker = "SAME" if same else "DIFF"
    change = s_best - s_full
    print(f"{n1:>20s} {n2:>20s} {s_full:>10.4f} {s_best:>10.4f} "
          f"{change:>+8.4f} [{marker}]")
# ── Step 6: Text-to-code search comparison ────────────────────────────────
print(f"\n── Text-to-Code Search: Full 768d vs PCA {best_k}d ──\n")
search_queries = [
    ("sort a list of numbers", "Sorting"),
    ("read a JSON config file", "File I/O"),
    ("compute factorial recursively", "Math"),
    ("make an HTTP GET request", "Networking"),
    ("check if a number is prime", "Math"),
]
# Reuse the Step-5 projection: it was fitted with the same n_components and
# random_state, so refitting PCA here would just redo identical work.
if best_k < D:
    pca_search = pca_best
    X_search = X_best
else:
    X_search = X.copy()
    pca_search = None
# Normalize the corpus matrices once — they are invariant across queries.
X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)
X_s_norm = X_search / np.linalg.norm(X_search, axis=1, keepdims=True)
for query, expected_group in search_queries:
    q_vec = embed_code(query).cpu().numpy().reshape(1, -1)
    # Full dimension search
    q_norm = q_vec / np.linalg.norm(q_vec)
    scores_full = (X_norm @ q_norm.T).flatten()
    # PCA-reduced search: project the query with the SAME fitted PCA
    if pca_search is not None:
        q_reduced = pca_search.transform(q_vec)
    else:
        q_reduced = q_vec.copy()
    q_r_norm = q_reduced / np.linalg.norm(q_reduced)
    scores_pca = (X_s_norm @ q_r_norm.T).flatten()
    top_full = np.argsort(-scores_full)[:3]
    top_pca = np.argsort(-scores_pca)[:3]
    print(f' Query: "{query}"')
    print(f' Full 768d: {all_names[top_full[0]]:>16s} ({scores_full[top_full[0]]:.3f})'
          f' {all_names[top_full[1]]:>16s} ({scores_full[top_full[1]]:.3f})'
          f' {all_names[top_full[2]]:>16s} ({scores_full[top_full[2]]:.3f})')
    print(f' PCA {best_k:>3d}d: {all_names[top_pca[0]]:>16s} ({scores_pca[top_pca[0]]:.3f})'
          f' {all_names[top_pca[1]]:>16s} ({scores_pca[top_pca[1]]:.3f})'
          f' {all_names[top_pca[2]]:>16s} ({scores_pca[top_pca[2]]:.3f})')
    full_correct = all_labels[top_full[0]] == expected_group
    pca_correct = all_labels[top_pca[0]] == expected_group
    print(f' Full correct: {full_correct} | PCA correct: {pca_correct}')
    print()
# ── Step 7: Visualization ─────────────────────────────────────────────────
# Six-panel figure for a comprehensive visual analysis.
# One fixed hex color per semantic group (matplotlib "tab10"-style palette).
group_colors = {
    "Sorting": "#1f77b4", "File I/O": "#ff7f0e", "Math": "#2ca02c",
    "Networking": "#d62728", "String ops": "#9467bd", "Data structures": "#8c564b",
}
label_colors = [group_colors[g] for g in all_labels]  # one color per snippet
unique_groups = list(dict.fromkeys(all_labels))  # insertion-ordered, de-duplicated
fig = plt.figure(figsize=(20, 13))
fig.suptitle("PCA Denoising Analysis — Can Fewer Dimensions Improve Code Similarity?",
             fontsize=15, fontweight="bold", y=0.98)
# ── Row 1 ──
# Plot 1: Intra/inter similarity vs number of PCA components
ax1 = fig.add_subplot(2, 3, 1)
# Unpack the sweep results (k, intra, inter, gap) into plot series.
dims_plot = [r[0] for r in results]
intra_plot = [r[1] for r in results]
inter_plot = [r[2] for r in results]
# Shade the band between the two curves — its height is the gap.
ax1.fill_between(dims_plot, inter_plot, intra_plot, alpha=0.15, color="tab:green")
ax1.plot(dims_plot, intra_plot, "o-", color="tab:blue", linewidth=2,
         label="Intra-group (same category)", markersize=6)
ax1.plot(dims_plot, inter_plot, "s-", color="tab:red", linewidth=2,
         label="Inter-group (different category)", markersize=6)
ax1.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best gap at k={best_k}")
ax1.set_xlabel("Number of PCA Components", fontsize=10)
ax1.set_ylabel("Average Cosine Similarity", fontsize=10)
ax1.set_title("(a) Intra- vs Inter-Group Similarity", fontsize=11, fontweight="bold")
ax1.legend(fontsize=7, loc="center right")
ax1.set_xscale("log")  # tested dims span 2..768, so log spacing
ax1.grid(True, alpha=0.3)
# Plot 2: Gap (discriminability) vs number of PCA components
ax2 = fig.add_subplot(2, 3, 2)
gaps_plot = [r[3] for r in results]
ax2.plot(dims_plot, gaps_plot, "D-", color="tab:green", linewidth=2, markersize=7)
ax2.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best k={best_k} (gap={best_gap:.3f})")
ax2.axhline(y=full_gap, color="gray", linestyle=":", alpha=0.7,
            label=f"Full 768d (gap={full_gap:.3f})")
# Shade only where PCA beats the full-dimensional baseline
ax2.fill_between(dims_plot, full_gap, gaps_plot, alpha=0.12, color="tab:green",
                 where=[g > full_gap for g in gaps_plot])
ax2.set_xlabel("Number of PCA Components", fontsize=10)
# BUGFIX: label previously read "Gap (Intra Inter)" — the minus sign was
# lost to the file's Unicode mangling; restore it.
ax2.set_ylabel("Gap (Intra - Inter)", fontsize=10)
ax2.set_title("(b) Discriminability vs Dimensionality", fontsize=11, fontweight="bold")
ax2.legend(fontsize=8)
ax2.set_xscale("log")
ax2.grid(True, alpha=0.3)
# Plot 3: Cumulative variance explained
# Fit once with the maximum possible components to get the full spectrum.
pca_full = PCA(n_components=min(N, D), random_state=42)
pca_full.fit(X)
cumvar = np.cumsum(pca_full.explained_variance_ratio_) * 100
ax3 = fig.add_subplot(2, 3, 3)
ax3.plot(range(1, len(cumvar) + 1), cumvar, "-", color="tab:purple", linewidth=2)
ax3.axvline(x=best_k, color="green", linestyle="--", alpha=0.7,
            label=f"Best k={best_k}")
for threshold in [90, 95, 99]:
    # searchsorted returns the first 0-based index with cumvar >= threshold;
    # +1 converts to a component count.
    k_thresh = np.searchsorted(cumvar, threshold) + 1
    if k_thresh <= len(cumvar):
        ax3.axhline(y=threshold, color="gray", linestyle=":", alpha=0.4)
        ax3.annotate(f"{threshold}% → k={k_thresh}", xy=(k_thresh, threshold),
                     fontsize=8, color="gray", ha="left",
                     xytext=(k_thresh + 1, threshold - 2))
ax3.set_xlabel("Number of PCA Components", fontsize=10)
ax3.set_ylabel("Cumulative Variance Explained (%)", fontsize=10)
ax3.set_title("(c) Variance Concentration", fontsize=11, fontweight="bold")
ax3.legend(fontsize=8)
ax3.set_xscale("log")
ax3.grid(True, alpha=0.3)
# ── Row 2 ──
# Plot 4 & 5: Side-by-side heatmaps (full vs PCA-denoised)
# Sort indices by group for a block-diagonal structure
sorted_idx = sorted(range(N), key=lambda i: all_labels[i])
sorted_names = [all_names[i] for i in sorted_idx]
sorted_labels = [all_labels[i] for i in sorted_idx]
sim_full_sorted = sim_full[np.ix_(sorted_idx, sorted_idx)]
sim_best_sorted = sim_best[np.ix_(sorted_idx, sorted_idx)]
for panel_idx, (mat, title_str) in enumerate([
    (sim_full_sorted, "(d) Similarity Heatmap — Full 768d"),
    (sim_best_sorted, f"(e) Similarity Heatmap — PCA {best_k}d (Denoised)"),
]):
    ax = fig.add_subplot(2, 3, 4 + panel_idx)
    im = ax.imshow(mat, cmap="RdBu_r", vmin=-1, vmax=1, aspect="auto")
    ax.set_xticks(range(N))
    ax.set_yticks(range(N))
    ax.set_xticklabels(sorted_names, rotation=90, fontsize=5)
    ax.set_yticklabels(sorted_names, fontsize=5)
    # Draw group boundary lines wherever the (sorted) group label changes
    prev_label = sorted_labels[0]
    for i, lab in enumerate(sorted_labels):
        if lab != prev_label:
            ax.axhline(y=i - 0.5, color="black", linewidth=1)
            ax.axvline(x=i - 0.5, color="black", linewidth=1)
        prev_label = lab
    ax.set_title(title_str, fontsize=11, fontweight="bold")
    plt.colorbar(im, ax=ax, shrink=0.8, label="Cosine Similarity")
# Plot 6: Bar chart comparing specific pairs at full vs PCA
ax6 = fig.add_subplot(2, 3, 6)
pair_labels = []
full_scores = []
pca_scores = []
pair_colors = []  # green = same group, red = different group
for n1, n2 in interesting_pairs:
    i = all_names.index(n1)
    j = all_names.index(n2)
    pair_labels.append(f"{n1}\nvs {n2}")
    full_scores.append(sim_full[i, j])
    pca_scores.append(sim_best[i, j])
    pair_colors.append("#2ca02c" if all_labels[i] == all_labels[j] else "#d62728")
y_pos = np.arange(len(pair_labels))
bar_h = 0.35
bars_full = ax6.barh(y_pos + bar_h / 2, full_scores, bar_h, label="Full 768d",
                     color="tab:blue", alpha=0.7)
bars_pca = ax6.barh(y_pos - bar_h / 2, pca_scores, bar_h, label=f"PCA {best_k}d",
                    color="tab:orange", alpha=0.7)
# Mark each pair with a colored dot (same vs different group).
# BUGFIX: the original annotated an empty string, so no dot ever rendered
# even though the legend below advertises them.
for i, col in enumerate(pair_colors):
    ax6.annotate("●", xy=(-0.05, y_pos[i]), fontsize=10, color=col,
                 ha="right", va="center", fontweight="bold",
                 annotation_clip=False)
ax6.set_yticks(y_pos)
ax6.set_yticklabels(pair_labels, fontsize=6)
ax6.set_xlabel("Cosine Similarity", fontsize=10)
ax6.set_title("(f) Pair Comparison: Full vs PCA Denoised", fontsize=11, fontweight="bold")
ax6.axvline(x=0, color="black", linewidth=0.5)
ax6.set_xlim(-1.1, 1.1)
ax6.grid(True, axis="x", alpha=0.3)
ax6.invert_yaxis()
# Single combined legend: bar handles plus proxy artists for the dot colors.
# (The original also called ax6.legend() earlier, which this call superseded.)
from matplotlib.lines import Line2D
dot_legend = [Line2D([0], [0], marker="o", color="w", markerfacecolor="#2ca02c",
                     markersize=8, label="Same group"),
              Line2D([0], [0], marker="o", color="w", markerfacecolor="#d62728",
                     markersize=8, label="Different group")]
ax6.legend(handles=[bars_full, bars_pca] + dot_legend, fontsize=7, loc="lower right")
# Leave headroom for the suptitle (top 4% of the canvas), then save.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.savefig("pca_denoising_analysis.png", dpi=150, bbox_inches="tight")
print(f"\nSaved: pca_denoising_analysis.png")
# ── Summary ───────────────────────────────────────────────────────────────
# Final interpretation of the sweep; the DENOISING EFFECT sentence adapts
# at runtime to whether the best PCA gap beat the full-dimensional gap.
print(f"""
{'=' * 70}
CONCLUSIONS
{'=' * 70}
1. VARIANCE CONCENTRATION:
The first few PCA components capture a disproportionate amount of
variance. This means the embedding space has low effective
dimensionality — most of the 768 dimensions are semi-redundant.
2. DENOISING EFFECT:
At k={best_k}, the gap between intra-group and inter-group similarity
is {best_gap:.4f} (vs {full_gap:.4f} at full 768d).
{'PCA denoising IMPROVED discriminability by removing noisy dimensions.' if best_gap > full_gap else 'Full dimensionality was already optimal for this dataset.'}
3. PRACTICAL IMPLICATIONS:
- For retrieval (code search), moderate PCA reduction can sharpen
results while also reducing storage and computation.
- Too few dimensions (k=2,3) lose important signal.
- Too many dimensions may retain noise that dilutes similarity.
- The "sweet spot" depends on the dataset and task.
4. TRADE-OFF:
PCA denoising is a post-hoc technique. Newer embedding models are
trained with Matryoshka Representation Learning (MRL) that makes
the FIRST k dimensions maximally informative by design.
""")