AISE501_CLASS/AST Files/ex04_dependency_graph.py
2026-05-07 17:26:41 +02:00

349 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Exercise 4 – Full Dependency Graph with Visualisation
=====================================================
AISE501 · AST Exercises · Spring Semester 2026
Learning goals
--------------
* Combine all previous analyses into a comprehensive dependency graph.
* Track import dependencies (which modules does each class use?).
* Analyse the ``run_analysis_pipeline()`` function to discover how
classes are instantiated and wired together.
* Export the graph in DOT format and render it with Graphviz.
* (Optional) Use ``networkx`` and ``matplotlib`` for interactive display.
Tasks
-----
Part A – Map each class to its external library calls (TODOs 1-2).
Part B – Analyse run_analysis_pipeline for data flow (TODOs 3-5).
Part C – Export to DOT format (TODOs 6-7).
Part D – (Optional) Render with networkx + matplotlib (TODOs 8-9).
"""
import ast
from pathlib import Path
from collections import defaultdict
# The module under analysis lives next to this script; it is parsed once and
# every part below works on the resulting tree.
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Python source files are UTF-8 unless they declare otherwise, so decode
# explicitly instead of relying on the platform's locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")
# Passing the filename makes a SyntaxError point at the real file.
tree = ast.parse(source_code, filename=str(SOURCE_FILE))
# ── Reusable helpers from previous exercises ───────────────────────────────
def extract_calls(func_node: ast.FunctionDef) -> list[dict]:
    """Describe every call expression found anywhere under *func_node*.

    Two shapes of call are reported, in ``ast.walk`` order:

    * ``obj.method(...)`` -> ``{"type": "attribute", "object": ..., "method": ...}``
      where ``object`` is the bare identifier when the receiver is a plain
      name, otherwise the unparsed source of the receiver expression.
    * ``name(...)`` -> ``{"type": "name", "name": ...}``

    Calls whose callee is any other kind of expression are ignored.
    """
    found: list[dict] = []
    call_nodes = (n for n in ast.walk(func_node) if isinstance(n, ast.Call))
    for call in call_nodes:
        callee = call.func
        if isinstance(callee, ast.Name):
            found.append({"type": "name", "name": callee.id})
            continue
        if not isinstance(callee, ast.Attribute):
            # e.g. the outer call in ``factory()()``: nothing to name.
            continue
        receiver = callee.value
        if isinstance(receiver, ast.Name):
            receiver_text = receiver.id
        else:
            receiver_text = ast.unparse(receiver)
        found.append({
            "type": "attribute",
            "object": receiver_text,
            "method": callee.attr,
        })
    return found
# Collect class info: class name -> {"node": ClassDef, "methods": {name: FunctionDef}}.
# ast.walk visits the whole tree, so classes at any nesting depth are included;
# only plain ``def`` statements directly in the class body count as methods.
class_info: dict[str, dict] = {}
for node in ast.walk(tree):
    if not isinstance(node, ast.ClassDef):
        continue
    class_info[node.name] = {
        "node": node,
        "methods": {
            member.name: member
            for member in node.body
            if isinstance(member, ast.FunctionDef)
        },
    }

# Collect module-level functions: only ``def`` statements directly in the
# module body, keyed by function name.
module_functions: dict[str, ast.FunctionDef] = {
    stmt.name: stmt for stmt in tree.body if isinstance(stmt, ast.FunctionDef)
}
# ── Part A: External Library Dependencies Per Class ─────────────────────────
print("=" * 60)
print("Part A External library calls per class")
print("=" * 60)
# First, collect all imports to know which names are external modules.
# TODO 1: Walk tree.body and collect all imported module names.
#
#   For ast.Import:     names like "os", "csv", "json"
#   For ast.ImportFrom: the module, e.g. "numpy", "scipy.stats"
#
#   Also collect aliases: "import numpy as np" means "np" -> "numpy"
#
#   Store in a dict: alias_to_module = {"np": "numpy", "stats": "scipy.stats", ...}
#
# Maps each name bound by a top-level import statement to the dotted path it
# was imported from.
alias_to_module: dict[str, str] = {}
for node in tree.body:
    if isinstance(node, ast.Import):
        # "import numpy as np" -> "np": "numpy"; "import os" -> "os": "os"
        for alias in node.names:
            alias_to_module[alias.asname or alias.name] = alias.name
    elif isinstance(node, ast.ImportFrom):
        package = node.module or ""
        for alias in node.names:
            if alias.name == "*":
                # A star import binds no single name that calls could go through.
                continue
            # "from scipy import stats" must yield "stats" -> "scipy.stats"
            # (see the example above), so the imported name is appended to the
            # package. "from . import x" has no package; it keeps the empty
            # value, which the dependency lookup treats as "not external".
            origin = f"{package}.{alias.name}" if package else ""
            alias_to_module[alias.asname or alias.name] = origin
# TODO 2: For each class, find which external modules its methods call.
#
# For each method, look at extract_calls() results:
# - For attribute calls: check if the object name is in alias_to_module
# - Record the mapping: class_name -> set of module names
#
# Example: DescriptiveStats calls np.mean, np.var, stats.skew
# -> DescriptiveStats uses {"numpy", "scipy.stats"}
def get_external_deps(cls_name: str) -> set[str]:
    """Return the set of external module names used by *cls_name*.

    Every attribute-style call (``obj.method(...)``) in every method of the
    class is inspected. The receiver text is looked up in ``alias_to_module``;
    if the full text is not an imported name, progressively shorter dotted
    prefixes are tried, so ``np.linalg.norm(...)`` is attributed to whatever
    ``np`` maps to.

    Args:
        cls_name: Name of a class present in ``class_info``.

    Returns:
        The module names (values of ``alias_to_module``) reached through
        imported names. Empty when the class makes no such calls.

    Raises:
        KeyError: If *cls_name* is not a key of ``class_info``.
    """
    deps: set[str] = set()
    for method in class_info[cls_name]["methods"].values():
        for call in extract_calls(method):
            if call["type"] != "attribute":
                # Plain-name calls carry no receiver to resolve.
                continue
            parts = str(call["object"]).split(".")
            # Longest prefix first, so an exact match such as "scipy.stats"
            # (from "import scipy.stats") wins over a shorter one.
            for end in range(len(parts), 0, -1):
                module = alias_to_module.get(".".join(parts[:end]))
                if module:
                    deps.add(module)
                    break
    return deps
# Report the external modules of each class, one line per class.
for cls_name in class_info:
    deps = get_external_deps(cls_name)
    shown = sorted(deps) if deps else "(none)"
    print(f"\n {cls_name}: {shown}")
# ── Part B: Analyse run_analysis_pipeline for Data Flow ────────────────────
# This function is the "glue" that creates objects and passes them between
# classes. We want to discover:
#   - Which classes are instantiated
#   - Which methods are called on those instances
#   - How data flows: output of one call becomes input of another
print("\n" + "=" * 60)
print("Part B Data flow in run_analysis_pipeline")
print("=" * 60)
# TODO 3: Find the run_analysis_pipeline function node.
# module_functions already indexes the ``def`` statements in tree.body, so the
# lookup cannot accidentally pick up a method or nested function of that name.
pipeline_func = module_functions.get("run_analysis_pipeline")
# TODO 4: Walk the pipeline function and find all variable assignments.
#   For each ast.Assign where the right-hand side is a Call:
#     - Record what variable name receives the result
#     - Record what class/function was called
#
#   For example: cleaner = DataCleaner(raw_data)
#     -> variable "cleaner" is assigned an instance of "DataCleaner"
#
#   Store as: var_types = {"cleaner": "DataCleaner", "desc": "DescriptiveStats", ...}
var_types: dict[str, str] = {}
if pipeline_func:
    # ast.walk also reaches assignments nested in if/for/with/try blocks,
    # matching how extract_calls() finds calls at any depth.
    for node in ast.walk(pipeline_func):
        if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call):
            callee_text = ast.unparse(node.value.func)
            for target in node.targets:
                var_types[ast.unparse(target)] = callee_text
for var_name, callee_text in var_types.items():
    print(f"\n {var_name}: {callee_text}")
# TODO 5: Now trace method calls on those variables.
#   For each attribute call (e.g. cleaner.remove_nans()):
#     - Look up the variable in var_types to find its class
#     - Record the edge: "run_analysis_pipeline" -> "DataCleaner.remove_nans"
#
#   Build a list of edges: [(source, target), ...]
pipeline_edges: list[tuple[str, str]] = []
if pipeline_func:
    for call in extract_calls(pipeline_func):
        if call["type"] != "attribute":
            continue
        type_name = var_types.get(call["object"])
        if type_name is None:
            # Receiver is not a variable assigned from a call in the pipeline.
            continue
        # Qualify the method with the callee recorded for the variable, giving
        # the "Class.method" label described above. Single quotes inside the
        # f-string keep this valid on Python versions before 3.12.
        pipeline_edges.append(
            ("run_analysis_pipeline", f"{type_name}.{call['method']}")
        )
print("\n Data flow edges:")
for source, target in pipeline_edges:
    print(f" {source} -> {target}")
# ── Part C: Export to DOT Format ────────────────────────────────────────────
# DOT is the graph description language used by Graphviz.
print("\n" + "=" * 60)
print("Part C Export to DOT format")
print("=" * 60)
# TODO 6: Collect ALL edges into a single list:
#   - Internal calls (self.method -> self.other_method within a class)
#   - Cross-class calls (from Exercise 3 logic)
#   - Pipeline edges (from Part B)
#   - External dependency edges (Class -> module)
#
# Use the format: (source_label, target_label, edge_type)
# where edge_type is one of: "internal", "cross_class", "pipeline", "external"
from ex03_method_call_graph import find_internal_calls, find_cross_class_calls


def _edge_target(callee, default_owner=None) -> str:
    """Turn one item returned by an Exercise 3 helper into a node label.

    NOTE(review): the return shapes of ``find_internal_calls`` and
    ``find_cross_class_calls`` are not visible from this file. This assumes
    each item is either a string (a method name, or already "Class.method")
    or a (class, method) sequence -- confirm against ex03_method_call_graph.
    Anything else is stringified rather than rejected.

    Args:
        callee: One item yielded by an Exercise 3 helper.
        default_owner: Class name to prefix when *callee* is a bare name
            without a dot; ``None`` leaves bare names untouched.
    """
    if isinstance(callee, (tuple, list)):
        return ".".join(str(part) for part in callee)
    label = str(callee)
    if default_owner is not None and "." not in label:
        return f"{default_owner}.{label}"
    return label


all_edges: list[tuple[str, str, str]] = []
for cls_name, cls_info in class_info.items():
    for method_name in cls_info["methods"]:
        # Edges run from the calling method to the called method, both written
        # as "Class.method" like the DOT example for TODO 7.
        caller = f"{cls_name}.{method_name}"
        for internal in find_internal_calls(cls_name, method_name):
            all_edges.append((caller, _edge_target(internal, cls_name), "internal"))
        for cross_class in find_cross_class_calls(cls_name, method_name):
            all_edges.append((caller, _edge_target(cross_class), "cross_class"))
for pipeline_edge in pipeline_edges:
    all_edges.append((pipeline_edge[0], pipeline_edge[1], "pipeline"))
# External dependencies are collected for every class, not just whichever class
# the loop above happened to end on.
for cls_name in class_info:
    for external_dep in get_external_deps(cls_name):
        # external_dep is the module name itself; indexing it would pick a
        # single character out of the name.
        all_edges.append((cls_name, external_dep, "external"))
# TODO 7: Generate a DOT string and write it to "dependency_graph.dot".
#
# DOT format example:
# digraph G {
# rankdir=LR;
# node [shape=box, style=filled, fillcolor=lightblue];
# "DataCleaner.remove_nans" -> "DataCleaner.remove_outliers";
# "ReportGenerator.add_descriptive" -> "DescriptiveStats.full_report";
# }
#
# Use different colors for different edge types:
# internal -> black
# cross_class -> blue
# pipeline -> red
# external -> gray
def generate_dot(edges: list[tuple[str, str, str]]) -> str:
    """Return a DOT-format string for the dependency graph.

    Args:
        edges: ``(source_label, target_label, edge_type)`` triples. The edge
            type selects the colour: internal -> black, cross_class -> blue,
            pipeline -> red, external -> gray; unknown types fall back to
            black.

    Returns:
        A complete ``digraph G { ... }`` document ending in a newline, with
        one ``"source" -> "target" [color=...];`` line per edge, in input
        order. An empty edge list yields a graph with only the header lines.
    """
    edge_colors = {
        "internal": "black",
        "cross_class": "blue",
        "pipeline": "red",
        "external": "gray",
    }

    def quoted(label: str) -> str:
        # Labels may come from ast.unparse() and can contain quotes or
        # backslashes; unescaped, they would end the DOT string early.
        escaped = str(label).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    lines = [
        "digraph G {",
        " rankdir=LR;",
        ' node [shape=box, style=filled, fillcolor=lightblue];',
    ]
    for source, target, edge_type in edges:
        color = edge_colors.get(edge_type, "black")
        lines.append(f" {quoted(source)} -> {quoted(target)} [color={color}];")
    lines.append("}")
    return "\n".join(lines) + "\n"
# Render the collected edges and store them next to this script.
dot_string = generate_dot(all_edges)
dot_file = Path(__file__).parent / "dependency_graph.dot"
# Graphviz reads its input as UTF-8 by default, so write it that way instead
# of using the platform's locale encoding.
dot_file.write_text(dot_string, encoding="utf-8")
print(f"\n Written to {dot_file}")
print(" Render with: dot -Tpng dependency_graph.dot -o dependency_graph.png")
# ── Part D: (Optional) Render with networkx + matplotlib ───────────────────
print("\n" + "=" * 60)
print("Part D (Optional) Visualise with networkx")
print("=" * 60)
# TODO 8: Install networkx and matplotlib if not already available:
#   pip install networkx matplotlib
#
# TODO 9: Build a networkx DiGraph from all_edges and render it.
#
# This part is optional: when the plotting libraries are missing, say so and
# skip it instead of crashing after Parts A-C have already completed.
try:
    import networkx as nx
    import matplotlib.pyplot as plt
except ImportError as exc:
    print(f"\n (Skipping Part D: {exc}. Install with: pip install networkx matplotlib)")
else:
    G = nx.DiGraph()
    # Color map for edge types
    edge_colors = {
        "internal": "black",
        "cross_class": "blue",
        "pipeline": "red",
        "external": "gray",
    }
    # Add edges with colors
    for source, target, etype in all_edges:
        G.add_edge(source, target, color=edge_colors.get(etype, "black"))
    # Node colors: classes in light blue, functions in light green, modules in light gray
    node_colors = []
    for n in G.nodes():
        # The part before the first dot is the class for "Class.method" labels
        # and the whole label for a bare class node (the source of an external
        # edge), so both kinds are recognised as belonging to a class.
        if n.split(".")[0] in class_info:
            node_colors.append("lightblue")
        elif n in module_functions:
            node_colors.append("lightgreen")
        else:
            node_colors.append("lightgray")
    # Draw; the fixed seed keeps the layout reproducible between runs.
    pos = nx.spring_layout(G, k=2, seed=42)
    colors = [G[u][v]["color"] for u, v in G.edges()]
    plt.figure(figsize=(16, 10))
    nx.draw(G, pos,
            with_labels=True,
            node_color=node_colors,
            edge_color=colors,
            node_size=2000,
            font_size=7,
            arrowsize=15)
    plt.title("Dependency Graph sample_stats.py")
    plt.tight_layout()
    plt.savefig(Path(__file__).parent / "dependency_graph.png", dpi=150)
    plt.show()
    print(" Saved dependency_graph.png")
# ── Expected Output ────────────────────────────────────────────────────────
# Part A: Each class maps to numpy, scipy.stats, scipy.optimize, etc.
# Part B: Pipeline shows DataCleaner -> DescriptiveStats -> HypothesisTester
# -> ReportGenerator chain.
# Part C: A .dot file with ~20-30 edges in four colours.
# Part D: A visual graph showing the full architecture of sample_stats.py.