""" ast_demo.py – Walkthrough: How Python's ast Module Works ========================================================= AISE501 · AST Exercises · Spring Semester 2026 This file demonstrates how to parse Python source code into an Abstract Syntax Tree (AST) and walk through it to find nodes and relationships. WHAT IS AN AST? An AST is a tree representation of source code. Every construct in Python (import, function definition, class, assignment, call, ...) becomes a node in the tree. Nodes have children: for example, a ClassDef node contains FunctionDef children (its methods). WHY USE THE ast MODULE? Python ships with a built-in ``ast`` module that can parse any valid Python source string into an AST. This lets you analyse code *without* running it (static analysis). Use cases include: - code linters and formatters - documentation generators - dependency analysis - AI-powered code search (RAG pipelines) HOW TO RUN THIS FILE: cd ast_exercises python ast_demo.py Read the printed output step by step. Each step introduces one core technique that you will need in the exercises. """ import ast from pathlib import Path # --------------------------------------------------------------------------- # Load the target source file. # We read it as a plain string -- ast.parse() works on strings, not files. # --------------------------------------------------------------------------- SOURCE_FILE = Path(__file__).parent / "sample_stats.py" source_code = SOURCE_FILE.read_text() # =========================================================================== # STEP 1: Parse source code into an AST # =========================================================================== # # ast.parse(source_string) takes a string of Python source code and returns # the root node of the AST. The root is always an ast.Module node. # # The Module node has a .body attribute -- a Python list containing one AST # node per top-level statement in the file. "Top-level" means statements # that are NOT inside a class or function. Examples: # - import statements -> ast.Import or ast.ImportFrom # - function definitions -> ast.FunctionDef # - class definitions -> ast.ClassDef # - if __name__ == "__main__" -> ast.If # - bare expressions (docstrings) -> ast.Expr # # Each node carries metadata: # - node.lineno : the line number in the original source (1-based) # - node.col_offset : the column offset (0-based) # - node-specific attributes (e.g. ClassDef has .name, .body, .bases, ...) print("=" * 70) print("STEP 1 – Parse source code into an AST") print("=" * 70) # Parse the source string into an AST. This does NOT execute the code. tree = ast.parse(source_code) print(f"\nParsed '{SOURCE_FILE.name}' successfully.") print(f"Root node type : {type(tree).__name__}") # -> "Module" print(f"Number of top-level statements: {len(tree.body)}\n") # Iterate over top-level statements and print their type and metadata. # We use isinstance() to check what kind of node each statement is, # then access node-specific attributes (.name for classes/functions, # .names for imports, etc.). for i, node in enumerate(tree.body): # getattr with a default handles nodes that lack .lineno (rare) line = getattr(node, "lineno", "?") # type(node).__name__ gives us the AST class name, e.g. "ClassDef" print(f" [{i:2d}] Line {line:>3}: {type(node).__name__}", end="") # Print extra info depending on the node type if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): # FunctionDef.name is the function name as a string print(f" -> function '{node.name}'", end="") elif isinstance(node, ast.ClassDef): # ClassDef.name is the class name as a string print(f" -> class '{node.name}'", end="") elif isinstance(node, (ast.Import, ast.ImportFrom)): # Import/ImportFrom have a .names list of ast.alias objects. # Each alias has .name (the real name) and .asname (the alias or None). # ImportFrom also has .module (e.g. "scipy.stats" in "from scipy.stats import ..."). names = [alias.name for alias in node.names] module = getattr(node, "module", None) if module: print(f" -> from {module} import {', '.join(names)}", end="") else: print(f" -> import {', '.join(names)}", end="") print() # newline # =========================================================================== # STEP 2: Walk the tree with ast.walk() # =========================================================================== # # ast.walk(node) is a generator that yields EVERY node in the subtree # rooted at `node`, in breadth-first order. It visits all children, # grandchildren, etc. -- the entire tree. # # This is the simplest traversal method. It is ideal when you want to # count or collect nodes of a certain type regardless of where they appear # in the tree (e.g. "how many function calls are there in total?"). # # Limitation: ast.walk() does NOT tell you the parent of each node. # If you need parent information, see Step 6. print("\n" + "=" * 70) print("STEP 2 – Walk the tree: count node types") print("=" * 70) # Count how often each AST node type appears in the entire tree. # type(node).__name__ gives us the class name (e.g. "Call", "Name", "Assign"). node_counts: dict[str, int] = {} for node in ast.walk(tree): name = type(node).__name__ node_counts[name] = node_counts.get(name, 0) + 1 # Print the 15 most frequent node types. # Common types you will see: # Name - every variable reference (e.g. `x`, `self`, `np`) # Load - context marker: the name is being read (not assigned) # Store - context marker: the name is being assigned to # Call - a function/method call (e.g. np.mean(...)) # Attribute - a dotted access (e.g. self.data, np.array) # Constant - a literal value (number, string, None, True, False) # Assign - an assignment statement (x = ...) # Return - a return statement print("\nTop 15 most frequent node types:") for name, count in sorted(node_counts.items(), key=lambda x: -x[1])[:15]: print(f" {name:25s} {count:4d}") # =========================================================================== # STEP 3: Use ast.NodeVisitor for targeted traversal # =========================================================================== # # While ast.walk() visits everything, ast.NodeVisitor lets you define # handler methods for specific node types. The naming convention is: # # def visit_(self, node): # ... # your logic # self.generic_visit(node) # continue to children # # When you call visitor.visit(tree), the visitor walks the tree and # dispatches to the matching visit_* method for each node. # # IMPORTANT: If you forget to call self.generic_visit(node) at the end # of your visit_* method, the visitor will NOT recurse into that node's # children. This is a common mistake! # # Use case here: find every class and list its methods. print("\n" + "=" * 70) print("STEP 3 – NodeVisitor: find all classes and their methods") print("=" * 70) class ClassMethodVisitor(ast.NodeVisitor): """Visit every ClassDef and collect its method names. How it works: 1. The visitor walks the tree starting from the Module root. 2. When it encounters a ClassDef node, visit_ClassDef is called. 3. We iterate over the class body (node.body) and pick out FunctionDef nodes -- these are the class's methods. 4. We store them in a dict: {class_name: [method_name, ...]}. 5. We call generic_visit to allow recursion into nested classes. """ def __init__(self): # Maps class name -> list of method names self.classes: dict[str, list[str]] = {} def visit_ClassDef(self, node: ast.ClassDef): # node.body is a list of AST nodes inside the class. # Methods are FunctionDef nodes; there can also be Assign nodes # (class-level attributes), Expr nodes (docstrings), etc. methods = [] for item in node.body: if isinstance(item, ast.FunctionDef): # item.name is the method name as a string methods.append(item.name) self.classes[node.name] = methods # IMPORTANT: call generic_visit to recurse into nested classes # (classes defined inside this class). Without this call, nested # classes would be silently skipped. self.generic_visit(node) # Create an instance of the visitor and run it on the tree. visitor = ClassMethodVisitor() visitor.visit(tree) # Print the results: each class and its methods. print() for cls_name, methods in visitor.classes.items(): print(f" class {cls_name}:") for m in methods: print(f" - {m}()") print() # =========================================================================== # STEP 4: Find function calls inside a method # =========================================================================== # # Function calls in the AST are represented by ast.Call nodes. # A Call node has a .func attribute that tells us WHAT is being called: # # Simple call: len(data) # -> Call(func=Name(id='len'), args=[Name(id='data')]) # # Attribute call: np.mean(data) # -> Call(func=Attribute(value=Name(id='np'), attr='mean'), ...) # # Chained call: self.sections.append(...) # -> Call(func=Attribute( # value=Attribute(value=Name(id='self'), attr='sections'), # attr='append'), ...) # # To reconstruct the full dotted name (e.g. "self.sections.append"), we # walk up the chain of Attribute nodes until we reach a Name node. print("=" * 70) print("STEP 4 – Find function calls inside a specific method") print("=" * 70) class CallFinder(ast.NodeVisitor): """Collect all function/method calls inside a node. For each ast.Call found, reconstruct the callable's name: - Simple call (Name): e.g. "len", "print", "DataCleaner" - Attribute call (obj.attr): e.g. "self.remove_nans", "np.mean" - Chained attribute call: e.g. "self.sections.append" """ def __init__(self): self.calls: list[str] = [] def visit_Call(self, node: ast.Call): # Case 1: Simple name call, e.g. len(), print(), DataCleaner() # node.func is an ast.Name with .id = the function name if isinstance(node.func, ast.Name): self.calls.append(node.func.id) # Case 2: Attribute call, e.g. self.remove_nans(), np.mean() # node.func is an ast.Attribute with .attr = the method name # and .value = the object (which can itself be an Attribute for chains) elif isinstance(node.func, ast.Attribute): # Walk the chain of Attribute nodes to reconstruct the full # dotted name. For example, for self.sections.append: # Attribute(attr='append', # value=Attribute(attr='sections', # value=Name(id='self'))) # We collect: ['append', 'sections', 'self'] then reverse. parts = [] current = node.func while isinstance(current, ast.Attribute): parts.append(current.attr) # collect each .attr current = current.value # move to the parent object if isinstance(current, ast.Name): parts.append(current.id) # the root object name # Reverse to get the natural order: "self.sections.append" self.calls.append(".".join(reversed(parts))) # Continue visiting children -- a Call can contain other Calls # in its arguments. For example: print(len(data)) has two Calls. self.generic_visit(node) # Find the run_analysis_pipeline function in the tree and extract its calls. # We use ast.walk() to search for the specific FunctionDef by name. for node in ast.walk(tree): if isinstance(node, ast.FunctionDef) and node.name == "run_analysis_pipeline": finder = CallFinder() finder.visit(node) # visit only THIS function's subtree print(f"\nCalls inside 'run_analysis_pipeline()':") for call in finder.calls: print(f" -> {call}()") break # stop after finding the first match # =========================================================================== # STEP 5: Inspect function signatures (arguments) # =========================================================================== # # Every FunctionDef node has an .args attribute of type ast.arguments. # This contains: # .args : list of ast.arg -- the positional parameters # .vararg : ast.arg or None -- the *args parameter # .kwonlyargs : list of ast.arg -- keyword-only parameters # .kwarg : ast.arg or None -- the **kwargs parameter # .defaults : list of default values for positional args # .kw_defaults: list of default values for keyword-only args # # Each ast.arg has: # .arg : str -- the parameter name (e.g. "self", "data", "x") # .annotation : AST node or None -- the type annotation # # FunctionDef also has: # .returns : AST node or None -- the return type annotation # # To convert an annotation node back to readable Python, use: # ast.unparse(node) -- returns a string like "np.ndarray" # ast.dump(node) -- returns the raw AST representation (for debugging) print("\n" + "=" * 70) print("STEP 5 – Extract function signatures") print("=" * 70) print() for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): # node.args is the ast.arguments object args = node.args # Extract parameter names from the list of ast.arg objects param_names = [arg.arg for arg in args.args] # Check for type annotations on each parameter. # arg.annotation is None if no annotation, otherwise an AST node. # We use ast.dump() here for demonstration; in the exercises you # will use ast.unparse() for cleaner output. annotations = {} for arg in args.args: if arg.annotation: annotations[arg.arg] = ast.dump(arg.annotation) # Check for return type annotation. # node.returns is None if there is no "-> ..." annotation. ret = ast.dump(node.returns) if node.returns else "None" print(f" def {node.name}({', '.join(param_names)}) -> {ret}") print() # =========================================================================== # STEP 6: Linking nodes – parent-child relationships # =========================================================================== # # A common question: "Which class does this method belong to?" # The AST does NOT store a .parent reference on each node by default. # But we can add one ourselves using a simple pre-processing step: # # for node in ast.walk(tree): # for child in ast.iter_child_nodes(node): # child._parent = node # # ast.iter_child_nodes(node) yields the direct children of a node. # By setting child._parent = node, we create a back-link from each # child to its parent. # # After this, for any FunctionDef node we can check: # - Is its parent a ClassDef? -> it's a method # - Is its parent a Module? -> it's a module-level function # - Is its parent a FunctionDef? -> it's a nested function print("=" * 70) print("STEP 6 – Parent-child relationships") print("=" * 70) # Pre-processing: annotate every node with a _parent reference. # We use a private attribute name (_parent) by convention because # this is not part of the official AST API -- we are adding it ourselves. for node in ast.walk(tree): for child in ast.iter_child_nodes(node): child._parent = node # type: ignore[attr-defined] # Now find every FunctionDef and report whether it is a method (parent is # ClassDef) or a module-level function (parent is Module). print("\nMethod -> Parent mapping:") for node in ast.walk(tree): if isinstance(node, ast.FunctionDef): # Retrieve the parent we stored above parent = getattr(node, "_parent", None) parent_type = type(parent).__name__ if parent else "?" parent_name = getattr(parent, "name", "") if isinstance(parent, ast.ClassDef): # This FunctionDef is a method of a class print(f" {parent_name}.{node.name}() [method of class]") elif isinstance(parent, ast.Module): # This FunctionDef is a top-level function print(f" {node.name}() [module-level function]") # (We skip nested functions for brevity -- they would have a # FunctionDef as their parent.) print("\n" + "=" * 70) print("DEMO COMPLETE – You are now ready to start the exercises!") print("=" * 70)