"""
============================================================================
Example 3: Cross-Language Code Similarity
============================================================================
AISE501 – AI in Software Engineering I
Fachhochschule Graubünden

GOAL:
Demonstrate that code embeddings capture FUNCTIONALITY, not syntax.
The same algorithm written in Python, JavaScript, Java, and C++
should produce similar embedding vectors — even though the surface
syntax is completely different.

WHAT YOU WILL LEARN:
- Code embedding models create a language-agnostic semantic space.
- Functionally equivalent code clusters together regardless of language.
- This enables cross-language code search (e.g., find the Java
  equivalent of a Python function).

HARDWARE:
Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# ── Device selection ──────────────────────────────────────────────────────
def get_device():
    """Return the best available torch device, preferring CUDA > MPS > CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
DEVICE = get_device()
print(f"Using device: {DEVICE}\n")

# ── Load model ────────────────────────────────────────────────────────────
# A code-search embedding model (name suggests CodeSearchNet training —
# see the model card on the Hugging Face Hub for details).
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()  # inference only — disables dropout
print("Model loaded.\n")
# ── Same algorithm in four languages ──────────────────────────────────────
# Task A: Factorial — a simple recursive/iterative computation
# Task B: Reverse a string
# If embeddings are truly semantic, Task A functions should cluster together
# and Task B functions should cluster together, regardless of language.

code_snippets = {
    # ── Task A: Factorial ──
    "factorial_python": """
def factorial(n):
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
""",
    "factorial_javascript": """
function factorial(n) {
    let result = 1;
    for (let i = 2; i <= n; i++) {
        result *= i;
    }
    return result;
}
""",
    "factorial_java": """
public static int factorial(int n) {
    int result = 1;
    for (int i = 2; i <= n; i++) {
        result *= i;
    }
    return result;
}
""",
    "factorial_cpp": """
int factorial(int n) {
    int result = 1;
    for (int i = 2; i <= n; i++) {
        result *= i;
    }
    return result;
}
""",

    # ── Task B: Reverse a string ──
    "reverse_python": """
def reverse_string(s):
    return s[::-1]
""",
    "reverse_javascript": """
function reverseString(s) {
    return s.split('').reverse().join('');
}
""",
    "reverse_java": """
public static String reverseString(String s) {
    return new StringBuilder(s).reverse().toString();
}
""",
    "reverse_cpp": """
std::string reverseString(std::string s) {
    std::reverse(s.begin(), s.end());
    return s;
}
""",
}
def embed_code(code: str) -> torch.Tensor:
    """Embed a code snippet into a single L2-normalized vector.

    Tokenizes (truncating to 512 tokens), runs the model without gradient
    tracking, mean-pools the last hidden states over non-padding tokens
    (using the attention mask), and L2-normalizes the result — so dot
    products between returned vectors are cosine similarities.

    Args:
        code: Source-code string in any language.

    Returns:
        A 1-D normalized embedding tensor on DEVICE.
    """
    inputs = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    # Masked mean pooling: ignore padding positions when averaging.
    mask = inputs["attention_mask"].unsqueeze(-1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(embedding, p=2, dim=1).squeeze(0)
# ── Compute all embeddings ────────────────────────────────────────────────
print("Computing embeddings for all snippets...")
embeddings = {name: embed_code(code) for name, code in code_snippets.items()}
print(f"Done. {len(embeddings)} embeddings computed.\n")
# ── Compute similarity matrix ─────────────────────────────────────────────
names = list(embeddings.keys())
n = len(names)

print("=" * 70)
print("CROSS-LANGUAGE SIMILARITY MATRIX")
print("=" * 70)

# Print header (abbreviated names for readability: F: = factorial, R: = reverse).
# NOTE: loop variable renamed from `n` to `name` to avoid confusion with
# the count `n` above (comprehension scope made it harmless, but it read badly).
short_names = [name.replace("factorial_", "F:").replace("reverse_", "R:") for name in names]

print(f"\n{'':14s}", end="")
for sn in short_names:
    print(f"{sn:>10s}", end="")
print()

for i in range(n):
    print(f"{short_names[i]:14s}", end="")
    for j in range(n):
        # Embeddings are unit-norm, so the dot product IS cosine similarity.
        sim = torch.dot(embeddings[names[i]].cpu(), embeddings[names[j]].cpu()).item()
        print(f"{sim:10.3f}", end="")
    print()
# ── Compute average within-task and across-task similarities ──────────────
factorial_names = [name for name in names if "factorial" in name]
reverse_names = [name for name in names if "reverse" in name]

within_factorial = []
within_reverse = []
across_tasks = []

# Iterate over unordered pairs (i < j) so each pair is counted exactly once.
for i, n1 in enumerate(names):
    for n2 in names[i + 1:]:
        sim = torch.dot(embeddings[n1].cpu(), embeddings[n2].cpu()).item()
        if n1 in factorial_names and n2 in factorial_names:
            within_factorial.append(sim)
        elif n1 in reverse_names and n2 in reverse_names:
            within_reverse.append(sim)
        else:
            across_tasks.append(sim)
# Report the averages; within-task values should dominate across-task ones.
print("\n" + "=" * 70)
print("ANALYSIS")
print("=" * 70)
print(f"\nAvg similarity WITHIN factorial (across languages): "
      f"{sum(within_factorial)/len(within_factorial):.3f}")
print(f"Avg similarity WITHIN reverse (across languages): "
      f"{sum(within_reverse)/len(within_reverse):.3f}")
print(f"Avg similarity ACROSS tasks (factorial vs reverse): "
      f"{sum(across_tasks)/len(across_tasks):.3f}")
# Closing explanation for the learner (printed verbatim).
print("""
EXPECTED RESULT:
Within-task similarity should be MUCH HIGHER than across-task similarity.
This proves that the embedding model groups code by WHAT IT DOES,
not by WHAT LANGUAGE it is written in.

factorial_python ≈ factorial_java ≈ factorial_cpp ≈ factorial_javascript
reverse_python ≈ reverse_java ≈ reverse_cpp ≈ reverse_javascript
factorial_* ≠ reverse_*

This is what enables cross-language code search: you can find a Java
implementation by providing a Python query, or vice versa.
""")