# AISE1_CLASS/Code embeddings/04_clone_detection.py
"""
============================================================================
Example 4: Code Clone Detection
============================================================================
AISE501 AI in Software Engineering I
Fachhochschule Graubünden
GOAL:
Detect code clones (duplicate/similar code) in a collection of
functions using embeddings. We simulate a real-world scenario
where a codebase contains multiple near-duplicate implementations
that should be refactored into a single function.
WHAT YOU WILL LEARN:
- The four types of code clones (Type 14)
- How embeddings detect clones that text-based tools miss
- Ranking-based clone detection via cosine similarity
- Practical application: finding refactoring opportunities
CLONE TYPES:
Type 1: Exact copy (trivial — grep can find these)
Type 2: Renamed variables (grep misses these)
Type 3: Modified structure (added/removed lines)
Type 4: Same functionality, completely different implementation
HARDWARE:
Works on CPU, CUDA (NVIDIA), and MPS (Apple Silicon Mac).
============================================================================
"""
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from itertools import combinations
# ── Device selection ──────────────────────────────────────────────────────
def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

DEVICE = get_device()
print(f"Using device: {DEVICE}\n")
# ── Load model ────────────────────────────────────────────────────────────
MODEL_NAME = "flax-sentence-embeddings/st-codesearch-distilroberta-base"
print(f"Loading model: {MODEL_NAME} ...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
print("Model loaded.\n")
# ── Simulated codebase ────────────────────────────────────────────────────
# These functions simulate what you'd find in a messy, real-world codebase
# where different developers wrote similar functionality independently.
#
# IMPORTANT: The clone groups share ZERO common words (besides Python
# keywords). This demonstrates that embeddings capture semantics, not
# surface-level text overlap. grep would never find these.
codebase = {
    # ── Clone group 1: Computing the maximum of a list ──
    # Three completely different implementations — no shared identifiers,
    # no shared structure, but identical purpose.
    "utils/find_max.py": """
def find_max(numbers):
    result = numbers[0]
    for candidate in numbers[1:]:
        if candidate > result:
            result = candidate
    return result
""",
    "legacy/find_max_old.py": """
def find_max(numbers):
    result = numbers[0]
    for candidate in numbers[1:]:
        if candidate > result:
            result = candidate
    return result
""",
    "analytics/top_scorer.py": """
import heapq

def fetch_top_element(collection):
    return heapq.nlargest(1, collection)[0]
""",
    "stats/dominant_value.py": """
def extract_peak(dataset):
    dataset = sorted(dataset, reverse=True)
    return dataset[0]
""",
    # ── Clone group 2: String reversal ──
    # Two implementations with zero lexical overlap — slicing vs index-based.
    "text/flip_text.py": """
def flip_text(content):
    return content[::-1]
""",
    "helpers/mirror.py": """
def mirror_characters(phrase):
    output = []
    idx = len(phrase) - 1
    while idx >= 0:
        output.append(phrase[idx])
        idx -= 1
    return ''.join(output)
""",
    # ── Not a clone: completely different functionality ──
    # Each uses a different Python construct and domain to ensure
    # they don't cluster with each other or with the clone groups.
    "math/square_root.py": """
def square_root(x):
    return x ** 0.5
""",
    "calendar/leap_year.py": """
def is_leap_year(year):
    return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
""",
    "formatting/currency.py": """
def format_currency(amount, symbol="$"):
    return f"{symbol}{amount:,.2f}"
""",
}
def embed_code(code: str) -> torch.Tensor:
    """Embed code into a normalized vector."""
    inputs = tokenizer(
        code, return_tensors="pt", truncation=True, max_length=512, padding=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = model(**inputs)
    mask = inputs["attention_mask"].unsqueeze(-1)
    embedding = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1)
    return F.normalize(embedding, p=2, dim=1).squeeze(0)
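# To see what the masked mean pooling in embed_code actually does, here is a
# toy sketch with hand-made tensors (no model involved; the tensor values are
# made up for illustration): padded positions are zeroed out by the attention
# mask before averaging, so padding never dilutes the embedding.

```python
import torch  # already imported at the top of this script

hidden = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]]])  # (batch=1, seq=3, dim=2)
attn = torch.tensor([[1.0, 1.0, 0.0]])                         # last position is padding
m = attn.unsqueeze(-1)                                         # (1, 3, 1), broadcastable
pooled = (hidden * m).sum(dim=1) / m.sum(dim=1)                # mean over real tokens only
# pooled == [[2.0, 3.0]] — the average of the two unmasked vectors
```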
# ── Embed all functions ───────────────────────────────────────────────────
print("Embedding all functions in the codebase...")
embeddings = {}
for path, code in codebase.items():
embeddings[path] = embed_code(code)
print(f" {path}")
print()
# ── Compute pairwise similarity matrix ────────────────────────────────────
paths = list(embeddings.keys())
n = len(paths)
def short_name(path):
    """Extract a readable label from the file path."""
    return path.split("/")[-1].replace(".py", "")

labels = [short_name(p) for p in paths]
sim_matrix = {}
for i in range(n):
    for j in range(n):
        sim = torch.dot(embeddings[paths[i]].cpu(), embeddings[paths[j]].cpu()).item()
        sim_matrix[(i, j)] = sim
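# The double loop above is fine for a handful of files, but it is worth knowing
# the vectorized equivalent. A sketch with toy vectors (not the model
# embeddings): because each row is L2-normalized, stacking the vectors into a
# matrix E of shape (n, dim) and computing E @ E.T yields the full
# cosine-similarity matrix in a single matmul.

```python
import torch                      # both already imported at the top of this script
import torch.nn.functional as F

torch.manual_seed(0)
E = F.normalize(torch.randn(4, 8), p=2, dim=1)  # 4 toy "embeddings", unit length
S = E @ E.T                                     # (4, 4) pairwise cosine similarities
# diagonal entries are 1.0 (each vector vs. itself) and S is symmetric
```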
# ── Print similarity matrix ───────────────────────────────────────────────
col_w = max(len(l) for l in labels) + 2
header_w = col_w
print("=" * 70)
print("SIMILARITY MATRIX")
print("=" * 70)
print(f"\n{'':>{header_w}}", end="")
for label in labels:
print(f"{label:>{col_w}}", end="")
print()
for i in range(n):
print(f"{labels[i]:>{header_w}}", end="")
for j in range(n):
print(f"{sim_matrix[(i, j)]:>{col_w}.3f}", end="")
print()
# ── Most similar match per function ───────────────────────────────────────
print()
print(f"{'BEST MATCH':>{header_w}}", end="")
for i in range(n):
    best_j, best_sim = -1, -1.0
    for j in range(n):
        if i != j and sim_matrix[(i, j)] > best_sim:
            best_sim = sim_matrix[(i, j)]
            best_j = j
    print(f"{labels[best_j]:>{col_w}}", end="")
print()
print(f"{'(similarity)':>{header_w}}", end="")
for i in range(n):
    best_sim = max(sim_matrix[(i, j)] for j in range(n) if i != j)
    print(f"{best_sim:>{col_w}.3f}", end="")
print()
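# A possible next step after ranking best matches (a sketch; the 0.8 threshold
# is an assumption and should be tuned for your model and codebase): turn the
# matrix into an explicit list of clone candidates by keeping every unordered
# pair above the threshold. This is also where the `combinations` import from
# the top of the file comes in.

```python
from itertools import combinations  # already imported at the top of this script

def clone_pairs(sim, names, threshold=0.8):
    """Return (name_a, name_b, similarity) for every pair at or above threshold."""
    pairs = [
        (names[i], names[j], sim[(i, j)])
        for i, j in combinations(range(len(names)), 2)
        if sim[(i, j)] >= threshold
    ]
    # Most similar pairs first — the strongest refactoring candidates.
    return sorted(pairs, key=lambda p: p[2], reverse=True)

# e.g. in this script: clone_pairs(sim_matrix, labels)
```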
print(f"""
{'=' * 70}
INTERPRETATION:
{'=' * 70}
HOW TO READ THE TABLE:
Each cell shows the cosine similarity between two functions.
The BEST MATCH row shows which other function is most similar
to each column — these are the clone candidates a developer
would investigate.
EXPECTED CLONE GROUPS:
1. find_max ↔ find_max_old (Type 1: exact copy)
→ Similarity ≈ 1.000
2. find_max / fetch_top_element / extract_peak (Type 4 clones)
→ Same purpose (find the largest value), completely different
code: for-loop vs heapq.nlargest() vs sorted(reverse=True)
→ Zero shared identifiers between implementations
3. flip_text ↔ mirror_characters (Type 4 clone)
→ Same purpose (reverse a string), completely different code:
slicing ([::-1]) vs while-loop with index
→ Zero shared identifiers
NON-CLONES:
square_root, is_leap_year, format_currency each use a different
domain and code structure. Their best matches should have low
similarity compared to the clone groups.
KEY INSIGHT:
The clone groups share NO common words (besides Python keywords
like def/return/if). grep or any text-matching tool would never
find these clones. Only semantic understanding — which is what
embeddings provide — can detect that these functions do the same
thing despite having completely different code.
""")