235 lines
9.0 KiB
Python
235 lines
9.0 KiB
Python
"""
|
||
Exercise 3 – Build a Method Call Graph Between Classes
======================================================

AISE501 · AST Exercises · Spring Semester 2026

Learning goals
--------------
* Resolve method calls (``self.method()``, ``obj.method()``) to their
  owning class using AST analysis.
* Build a call graph that records which methods call which other methods.
* Detect cross-class calls (e.g. ``ReportGenerator.add_descriptive``
  receives a ``DescriptiveStats`` object and calls ``desc.full_report()``).
* Output the call graph as an adjacency list.

Tasks
-----
Part A  Find all method calls within each method (TODOs 1-2).
Part B  Resolve self.method() calls within the same class (TODOs 3-4).
Part C  Detect cross-class calls using method parameter type annotations
        (TODOs 5-7).
Part D  Print the full call graph as an adjacency list (TODO 8).
|
||
"""
|
||
|
||
import ast
|
||
from pathlib import Path
|
||
from collections import defaultdict
|
||
|
||
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"

# Decode explicitly as UTF-8 (the default encoding for Python source files)
# so the result does not depend on the platform's locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")

# Passing the filename makes any SyntaxError point at the sample file
# instead of the generic "<unknown>".
tree = ast.parse(source_code, filename=str(SOURCE_FILE))
|
||
|
||
|
||
# ── Helper: collect class information ───────────────────────────────────────
# class_info maps every class name found anywhere in the parsed module
# (ast.walk also reaches nested classes) to
#     {"node": <ast.ClassDef>, "methods": {method_name: <function node>}}.
# Only functions defined directly in the class body count as methods. Both
# plain and ``async def`` methods are collected so async methods are not
# silently missing from the call graph. Two classes with the same name
# overwrite each other; the last one visited wins.

class_info: dict[str, dict] = {}

for node in ast.walk(tree):
    if isinstance(node, ast.ClassDef):
        methods = {
            item.name: item
            for item in node.body
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        }
        class_info[node.name] = {"node": node, "methods": methods}
|
||
|
||
|
||
# ── Part A: Find All Calls Within Each Method ──────────────────────────────
# Banner: a 60-character rule, the section title, and the rule again.

print("=" * 60, "Part A – Raw calls inside each method", "=" * 60, sep="\n")
|
||
|
||
# TODO 1: Write a function `extract_calls(func_node)` that returns a list
|
||
# of call descriptions found inside *func_node*.
|
||
#
|
||
# For each ast.Call node found (use ast.walk):
|
||
# - If func is ast.Attribute (e.g. self.remove_nans(), np.mean()):
|
||
# record {"type": "attribute", "object": <name>, "method": <attr>}
|
||
# - If func is ast.Name (e.g. print(), len()):
|
||
# record {"type": "name", "name": <id>}
|
||
#
|
||
# To get the object name for Attribute calls:
|
||
# - If node.func.value is ast.Name -> use .id (e.g. "self", "np")
|
||
# - Otherwise use ast.unparse(node.func.value) as a fallback
|
||
|
||
def extract_calls(func_node: ast.FunctionDef) -> list[dict]:
    """Return a list of call descriptions inside *func_node*.

    Every ``ast.Call`` reachable from *func_node* (in ``ast.walk`` order,
    nested functions included) is described as one of:

    * ``{"type": "attribute", "object": <receiver>, "method": <attr>}`` for
      calls such as ``self.remove_nans()`` or ``np.mean(x)``. A bare-name
      receiver yields its identifier (``"self"``, ``"np"``); any other
      receiver expression is rendered with ``ast.unparse`` (``"self.data"``).
    * ``{"type": "name", "name": <id>}`` for calls such as ``len(x)``.

    Calls whose callee is neither form (e.g. ``handlers[0]()``) are skipped.
    """
    descriptions: list[dict] = []
    for node in ast.walk(func_node):
        if not isinstance(node, ast.Call):
            continue
        callee = node.func
        if isinstance(callee, ast.Name):
            descriptions.append({"type": "name", "name": callee.id})
        elif isinstance(callee, ast.Attribute):
            receiver = callee.value
            owner = (
                receiver.id
                if isinstance(receiver, ast.Name)
                else ast.unparse(receiver)
            )
            descriptions.append(
                {"type": "attribute", "object": owner, "method": callee.attr}
            )
    return descriptions
|
||
|
||
|
||
# TODO 2: For each class and method, print the extracted calls.
# Attribute calls print as "obj.method()", plain calls as "name()".

for cls_name, info in class_info.items():
    print(f"\n class {cls_name}:")
    for method_name, method_node in info["methods"].items():
        found = extract_calls(method_node)
        print(f" {method_name}():")
        for entry in found:
            if entry["type"] == "attribute":
                label = f"{entry['object']}.{entry['method']}"
            else:
                label = entry["name"]
            print(f" -> {label}()")
|
||
|
||
|
||
# ── Part B: Resolve self.method() Calls ────────────────────────────────────
# When a method calls self.xxx(), that is an internal call to another
# method of the same class.

print(f"\n{'=' * 60}", "Part B – Internal calls (self.method())", "=" * 60, sep="\n")
|
||
|
||
# TODO 3: Write a function `find_internal_calls(cls_name, method_name)`
|
||
# that returns a list of method names called via self.xxx()
|
||
# where xxx is also a method of the same class.
|
||
#
|
||
# Approach:
|
||
# 1. Get the method node from class_info[cls_name]["methods"][method_name]
|
||
# 2. Call extract_calls() on it
|
||
# 3. Filter for calls where type=="attribute" and object=="self"
|
||
# 4. Further filter: check that the called method name exists in the
|
||
# same class's method list.
|
||
|
||
def find_internal_calls(cls_name: str, method_name: str) -> list[str]:
    """Return method names called as self.xxx() within the same class.

    Looks up ``class_info[cls_name]["methods"][method_name]`` (KeyError if
    either name is unknown) and keeps, in ``extract_calls`` order, each
    ``self.<name>(...)`` call whose ``<name>`` is a method collected for that
    same class. Calls on other receivers (``np``, ``self.data``) and on
    ``self`` attributes that are not methods of the class are dropped.
    Repeated calls produce repeated entries.
    """
    own_methods = class_info[cls_name]["methods"]
    body = own_methods[method_name]

    result: list[str] = []
    for call in extract_calls(body):
        # Only attribute calls whose receiver is literally `self` qualify.
        if call["type"] != "attribute" or call["object"] != "self":
            continue
        if call["method"] in own_methods:
            result.append(call["method"])
    return result
|
||
|
||
|
||
# TODO 4: Print internal call edges for each class.
# A method without self.xxx() calls yields an empty list and prints nothing.

for cls_name, info in class_info.items():
    print(f"\n class {cls_name}:")
    for method_name in info["methods"]:
        for callee in find_internal_calls(cls_name, method_name):
            print(f" {method_name}() -> self.{callee}()")
|
||
|
||
|
||
# ── Part C: Detect Cross-Class Calls ──────────────────────────────────────
# Cross-class calls occur when a method receives an object of another class
# and calls a method on it. For example:
#     def add_descriptive(self, desc: DescriptiveStats) -> None:
#         self.sections.append({"content": desc.full_report()})
# Here, desc.full_report() is a cross-class call.

print(f"\n{'=' * 60}", "Part C – Cross-class calls", "=" * 60, sep="\n")
|
||
|
||
# TODO 5: Write a function `get_param_types(func_node)` that returns a dict
|
||
# mapping parameter names to their annotation type names (as strings).
|
||
#
|
||
# For each arg in func_node.args.args (skip 'self'):
|
||
# - If arg.annotation is an ast.Name, use .id
|
||
# - If arg.annotation is an ast.Attribute, use ast.unparse()
|
||
# - Otherwise skip it
|
||
#
|
||
# Example: for add_descriptive(self, desc: DescriptiveStats)
|
||
# return {"desc": "DescriptiveStats"}
|
||
|
||
def get_param_types(func_node: ast.FunctionDef) -> dict[str, str]:
    """Map parameter name -> annotation type name.

    Every named parameter (positional-only, regular and keyword-only) except
    ``self`` is inspected:

    * ``desc: DescriptiveStats`` (``ast.Name``)     -> ``"DescriptiveStats"``
    * ``arr: np.ndarray`` (``ast.Attribute``)       -> ``"np.ndarray"``
    * ``desc: "DescriptiveStats"`` (quoted forward reference) -> its text

    Unannotated parameters, ``*args`` / ``**kwargs`` and any other
    annotation form (``list[int]``, ``X | None``, ...) are skipped.

    Example: for ``add_descriptive(self, desc: DescriptiveStats)`` the
    result is ``{"desc": "DescriptiveStats"}``.
    """
    signature = func_node.args
    param_types: dict[str, str] = {}
    for arg in [*signature.posonlyargs, *signature.args, *signature.kwonlyargs]:
        if arg.arg == "self":
            continue
        annotation = arg.annotation
        if isinstance(annotation, ast.Name):
            param_types[arg.arg] = annotation.id
        elif isinstance(annotation, ast.Attribute):
            param_types[arg.arg] = ast.unparse(annotation)
        elif isinstance(annotation, ast.Constant) and isinstance(annotation.value, str):
            # Quoted forward reference, e.g. desc: "DescriptiveStats".
            param_types[arg.arg] = annotation.value.strip()
    return param_types
|
||
|
||
|
||
# TODO 6: Write a function `find_cross_class_calls(cls_name, method_name)`
|
||
# that returns a list of tuples: (target_class, target_method).
|
||
#
|
||
# Approach:
|
||
# 1. Get parameter types via get_param_types().
|
||
# 2. Get all calls via extract_calls().
|
||
# 3. For each attribute call where the object name matches a parameter
|
||
# whose type is a known class (in class_info), record the edge.
|
||
|
||
def find_cross_class_calls(cls_name: str, method_name: str) -> list[tuple[str, str]]:
    """Return [(target_class, target_method), ...] for cross-class calls.

    A call ``obj.meth()`` inside ``cls_name.method_name`` is reported as
    ``(T, "meth")`` when ``obj`` is a parameter of that method whose
    annotation (via ``get_param_types``) names a class ``T`` present in
    ``class_info``. For example, ``desc.full_report()`` with
    ``desc: DescriptiveStats`` yields ``("DescriptiveStats", "full_report")``.

    Edges follow ``extract_calls`` order and repeated calls give repeated
    edges. Whether ``meth`` is defined directly on ``T`` is not checked,
    because it may be inherited.

    Raises:
        KeyError: if *cls_name* or *method_name* is not in ``class_info``.
    """
    method_node = class_info[cls_name]["methods"][method_name]
    param_types = get_param_types(method_node)

    edges: list[tuple[str, str]] = []
    for call in extract_calls(method_node):
        if call["type"] != "attribute":
            continue
        # Receivers that are not annotated parameters (self, np, locals)
        # resolve to None; annotations naming non-project types are not in
        # class_info. Neither produces an edge.
        target_cls = param_types.get(call["object"])
        if target_cls is not None and target_cls in class_info:
            edges.append((target_cls, call["method"]))
    return edges
|
||
|
||
|
||
# TODO 7: Print cross-class call edges.
# One line per edge: "Caller.method() -> TargetClass.target_method()".

for cls_name, info in class_info.items():
    for method_name in info["methods"]:
        cross = find_cross_class_calls(cls_name, method_name)
        for target_cls, target_method in cross:
            print(f" {cls_name}.{method_name}() -> {target_cls}.{target_method}()")
|
||
|
||
|
||
# ── Part D: Full Call Graph as Adjacency List ──────────────────────────────
# Banner: blank line, 60-character rule, section title, rule again.

print(f"\n{'=' * 60}", "Part D – Full call graph (adjacency list)", "=" * 60, sep="\n")
|
||
|
||
# TODO 8: Combine internal and cross-class calls into a single adjacency
# list (dict mapping "Class.method" -> list of "Class.method").
# Include module-level functions calling class methods too.
#
# Also analyse run_analysis_pipeline() which instantiates classes
# and calls their methods.

call_graph: defaultdict[str, list[str]] = defaultdict(list)


def _add_edge(source: str, target: str) -> None:
    """Record the edge source -> target once; repeated calls add nothing."""
    if target not in call_graph[source]:
        call_graph[source].append(target)


def _module_receiver_types(func_node: ast.FunctionDef) -> dict[str, str]:
    """Map names inside a module-level function to the known class they hold.

    A name is resolved when it is a parameter annotated with a class in
    ``class_info``, or when it is assigned directly from a constructor call
    ``name = KnownClass(...)`` (plain or annotated assignment). Tuple
    unpacking, attribute targets and indirect construction are not tracked.
    If a name is bound several times, the last binding visited wins.
    """
    receivers = {
        name: type_name
        for name, type_name in get_param_types(func_node).items()
        if type_name in class_info
    }
    for node in ast.walk(func_node):
        if isinstance(node, ast.Assign):
            bound, value = node.targets, node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            bound, value = [node.target], node.value
        else:
            continue
        if (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id in class_info
        ):
            for name_node in bound:
                if isinstance(name_node, ast.Name):
                    receivers[name_node.id] = value.func.id
    return receivers


# Method-level edges: self.xxx() calls plus calls on class-typed parameters.
for cls_name, info in class_info.items():
    for method_name in info["methods"]:
        source = f"{cls_name}.{method_name}"
        for target in find_internal_calls(cls_name, method_name):
            _add_edge(source, f"{cls_name}.{target}")
        for target_cls, target_method in find_cross_class_calls(cls_name, method_name):
            _add_edge(source, f"{target_cls}.{target_method}")

# Module-level functions (e.g. run_analysis_pipeline): resolve calls made on
# variables that hold instances of known classes.
for top_node in tree.body:
    if not isinstance(top_node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        continue
    receivers = _module_receiver_types(top_node)
    for call in extract_calls(top_node):
        if call["type"] == "attribute" and call["object"] in receivers:
            _add_edge(top_node.name, f"{receivers[call['object']]}.{call['method']}")

# Print the graph, sorted by caller for stable output.
for source, targets in sorted(call_graph.items()):
    print(f" {source}")
    for t in targets:
        print(f" -> {t}")
|
||
|
||
|
||
# ── Expected Output (key edges) ────────────────────────────────────────────
|
||
# DataCleaner.remove_outliers -> DataCleaner.remove_nans (internal)
|
||
# ReportGenerator.add_descriptive -> DescriptiveStats.full_report (cross-class)
|
||
# ReportGenerator.add_hypothesis -> HypothesisTester.interpret_result (cross-class)
|
||
# ReportGenerator.add_curve_fit -> CurveFitter.r_squared (cross-class)
|