"""
ast_demo.py – Walkthrough: How Python's ast Module Works
=========================================================
AISE501 · AST Exercises · Spring Semester 2026

This file demonstrates how to parse Python source code into an Abstract
Syntax Tree (AST) and walk through it to find nodes and relationships.

WHAT IS AN AST?
    An AST is a tree representation of source code. Every construct in
    Python (import, function definition, class, assignment, call, ...)
    becomes a node in the tree. Nodes have children: for example, a
    ClassDef node contains FunctionDef children (its methods).

WHY USE THE ast MODULE?
    Python ships with a built-in ``ast`` module that can parse any valid
    Python source string into an AST. This lets you analyse code
    *without* running it (static analysis). Use cases include:
      - code linters and formatters
      - documentation generators
      - dependency analysis
      - AI-powered code search (RAG pipelines)

HOW TO RUN THIS FILE:
    cd ast_exercises
    python ast_demo.py

Read the printed output step by step. Each step introduces one core
technique that you will need in the exercises.
"""

import ast
from pathlib import Path

# ---------------------------------------------------------------------------
# Load the target source file.
# ast.parse() works on strings, not on file objects, so the file is read into
# a plain string first.  The encoding is passed explicitly: Python source is
# UTF-8 by default, whereas read_text() without an encoding falls back to the
# platform's locale encoding and could mis-decode non-ASCII characters.
# ---------------------------------------------------------------------------
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
source_code = SOURCE_FILE.read_text(encoding="utf-8")


# ===========================================================================
# STEP 1: Parse source code into an AST
# ===========================================================================
#
# ast.parse(source_string) takes a string of Python source code and returns
# the root node of the AST.  In the default mode used here the root is an
# ast.Module node.
#
# The Module node has a .body attribute -- a Python list containing one AST
# node per top-level statement in the file.  "Top-level" means statements
# that are NOT inside a class or function.  Examples:
#   - import statements              -> ast.Import or ast.ImportFrom
#   - function definitions           -> ast.FunctionDef
#   - class definitions              -> ast.ClassDef
#   - if __name__ == "__main__"      -> ast.If
#   - bare expressions (docstrings)  -> ast.Expr
#
# Each node carries metadata:
#   - node.lineno     : the line number in the original source (1-based)
#   - node.col_offset : the column offset (0-based)
#   - node-specific attributes (e.g. ClassDef has .name, .body, .bases, ...)

print("=" * 70)
print("STEP 1 – Parse source code into an AST")
print("=" * 70)

# Parse the source string into an AST.  This does NOT execute the code.
tree = ast.parse(source_code)

print(f"\nParsed '{SOURCE_FILE.name}' successfully.")
print(f"Root node type : {type(tree).__name__}")  # -> "Module"
print(f"Number of top-level statements: {len(tree.body)}\n")

# Iterate over the top-level statements and print their type and metadata.
# isinstance() tells us what kind of node each statement is; after that we
# can safely read the node-specific attributes (.name for classes/functions,
# .names for imports, etc.).
for i, node in enumerate(tree.body):
    # getattr with a default handles nodes that lack .lineno (rare)
    line = getattr(node, "lineno", "?")

    # type(node).__name__ gives us the AST class name, e.g. "ClassDef"
    print(f" [{i:2d}] Line {line:>3}: {type(node).__name__}", end="")

    # Print extra info depending on the node type
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        # FunctionDef.name is the function name as a string
        print(f" -> function '{node.name}'", end="")

    elif isinstance(node, ast.ClassDef):
        # ClassDef.name is the class name as a string
        print(f" -> class '{node.name}'", end="")

    elif isinstance(node, (ast.Import, ast.ImportFrom)):
        # Import/ImportFrom have a .names list of ast.alias objects.
        # Each alias has .name (the real name) and .asname (the alias or None).
        names = ", ".join(alias.name for alias in node.names)

        if isinstance(node, ast.ImportFrom):
            # ImportFrom also has .module (e.g. "scipy.stats") and .level,
            # the number of leading dots of a relative import.  For
            # "from . import x" the module is None, so testing .module alone
            # would mis-report the statement as a plain "import x".
            origin = "." * (node.level or 0) + (node.module or "")
            print(f" -> from {origin} import {names}", end="")
        else:
            print(f" -> import {names}", end="")

    print()  # newline


# ===========================================================================
# STEP 2: Walk the tree with ast.walk()
# ===========================================================================
#
# ast.walk(node) is a generator that yields EVERY node in the subtree rooted
# at `node` -- the node itself, its children, grandchildren, and so on.
# CPython currently produces them breadth-first, but the documentation does
# not promise any particular order, so do not rely on it.
#
# It is the simplest traversal tool and the right one when you only want to
# count or collect nodes of some type, wherever they occur (e.g. "how many
# function calls are there in total?").
#
# Limitation: ast.walk() does NOT tell you the parent of each node.
# If you need parent information, see Step 6.

print("\n" + "=" * 70, "STEP 2 – Walk the tree: count node types", "=" * 70, sep="\n")

# Tally how often each AST node class occurs in the whole tree, keyed by the
# class name (e.g. "Call", "Name", "Assign").
node_counts: dict[str, int] = {}
for visited in ast.walk(tree):
    kind = type(visited).__name__
    if kind in node_counts:
        node_counts[kind] += 1
    else:
        node_counts[kind] = 1

# Show the 15 most frequent node types.  Typical entries:
#   Name      - every variable reference (e.g. `x`, `self`, `np`)
#   Load      - context marker: the name is being read (not assigned)
#   Store     - context marker: the name is being assigned to
#   Call      - a function/method call (e.g. np.mean(...))
#   Attribute - a dotted access (e.g. self.data, np.array)
#   Constant  - a literal value (number, string, None, True, False)
#   Assign    - an assignment statement (x = ...)
#   Return    - a return statement
print("\nTop 15 most frequent node types:")
# sorted() is stable even with reverse=True, so node types with equal counts
# stay in first-seen order.
ranked = sorted(node_counts.items(), key=lambda pair: pair[1], reverse=True)
for kind, total in ranked[:15]:
    print(f" {kind:25s} {total:4d}")


# ===========================================================================
# STEP 3: Use ast.NodeVisitor for targeted traversal
# ===========================================================================
#
# ast.walk() hands you every node; ast.NodeVisitor instead lets you react
# only to the node types you care about.  Subclass it and define methods
# named after the node class:
#
#     def visit_<NodeType>(self, node):
#         ...                        # your logic
#         self.generic_visit(node)   # keep descending into the children
#
# Calling visitor.visit(tree) walks the tree and dispatches each node to the
# matching visit_* method.
#
# IMPORTANT: a visit_* method that does not call self.generic_visit(node)
# stops the traversal at that node -- its children are never visited.
# Forgetting that call is a common mistake!
#
# Use case here: find every class and list its methods.

print(
    "\n" + "=" * 70,
    "STEP 3 – NodeVisitor: find all classes and their methods",
    "=" * 70,
    sep="\n",
)


class ClassMethodVisitor(ast.NodeVisitor):
    """Visit every ClassDef and collect the names of its methods.

    How it works:
      1. The visitor walks the tree starting from the node passed to visit().
      2. When it encounters a ClassDef node, visit_ClassDef is called.
      3. The direct children in the class body (node.body) that are function
         definitions -- plain ``def`` or ``async def`` -- are its methods.
      4. They are stored in a dict: {class_name: [method_name, ...]}.
      5. generic_visit is called so that nested classes are found as well.

    Note: results are keyed by the bare class name, so if two classes share
    a name the one visited later replaces the earlier entry.
    """

    def __init__(self) -> None:
        # Maps class name -> list of method names, in definition order
        self.classes: dict[str, list[str]] = {}

    def visit_ClassDef(self, node: ast.ClassDef) -> None:
        """Record the methods defined directly in the body of *node*."""
        # node.body is the list of statements inside the class.  Besides
        # methods it can hold Assign nodes (class-level attributes), Expr
        # nodes (docstrings), nested ClassDef nodes, etc.  Coroutine methods
        # are AsyncFunctionDef nodes, a separate class from FunctionDef, so
        # both must be checked or "async def" methods would be missed.
        self.classes[node.name] = [
            item.name
            for item in node.body
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]

        # IMPORTANT: call generic_visit to recurse into nested classes
        # (classes defined inside this class or inside its methods).
        # Without this call they would be silently skipped.
        self.generic_visit(node)


# Run a fresh visitor over the whole tree; afterwards visitor.classes holds
# the {class name: [method names]} mapping.
visitor = ClassMethodVisitor()
visitor.visit(tree)

# Report each class followed by its methods, with a blank line after every
# class (and one before the first).
print()
for owner, method_names in visitor.classes.items():
    report = [f" class {owner}:"]
    report.extend(f" - {method}()" for method in method_names)
    report.append("")  # yields the blank separator line
    print("\n".join(report))


# ===========================================================================
# STEP 4: Find function calls inside a method
# ===========================================================================
#
# Every function or method call is an ast.Call node.  Its .func attribute
# describes WHAT is being called:
#
#   Simple call:     len(data)
#     -> Call(func=Name(id='len'), args=[Name(id='data')])
#
#   Attribute call:  np.mean(data)
#     -> Call(func=Attribute(value=Name(id='np'), attr='mean'), ...)
#
#   Chained call:    self.sections.append(...)
#     -> Call(func=Attribute(
#            value=Attribute(value=Name(id='self'), attr='sections'),
#            attr='append'), ...)
#
# To rebuild the full dotted name (e.g. "self.sections.append") we follow the
# .value links from one Attribute node to the next until we reach the Name
# node at the root of the chain.

print("=" * 70, "STEP 4 – Find function calls inside a specific method", "=" * 70, sep="\n")


class CallFinder(ast.NodeVisitor):
    """Collect the names of all function/method calls inside a node.

    For each ast.Call found, the callable's name is reconstructed and
    appended to ``self.calls`` (outer calls before the calls nested in them):
      - Simple call (Name):         e.g. "len", "print", "DataCleaner"
      - Attribute call (obj.attr):  e.g. "self.remove_nans", "np.mean"
      - Chained attribute call:     e.g. "self.sections.append"
      - Anything else at the root:  rendered as source text via
        ast.unparse, e.g. "get().run" or "handlers[0]"
    """

    def __init__(self) -> None:
        # Callable names in the order the calls are visited
        self.calls: list[str] = []

    def visit_Call(self, node: ast.Call) -> None:
        """Record the name of the callable invoked by *node*, then recurse."""
        self.calls.append(self._callee_name(node.func))

        # Continue visiting children -- a Call can contain other Calls
        # in its arguments.  For example: print(len(data)) has two Calls.
        self.generic_visit(node)

    @staticmethod
    def _callee_name(func: ast.expr) -> str:
        """Return a dotted, human-readable name for a Call's ``func`` node."""
        # Walk the chain of Attribute nodes to reconstruct the full dotted
        # name.  For example, for self.sections.append:
        #   Attribute(attr='append',
        #             value=Attribute(attr='sections',
        #                             value=Name(id='self')))
        # We collect ['append', 'sections', 'self'] and reverse at the end.
        parts: list[str] = []
        current = func
        while isinstance(current, ast.Attribute):
            parts.append(current.attr)  # collect each .attr
            current = current.value  # move to the object it is read from

        if isinstance(current, ast.Name):
            parts.append(current.id)  # the root object name
        else:
            # The chain does not start at a plain name (e.g. get().run() or
            # handlers[0]()).  Render the root expression as source text so
            # the name is not truncated to its last attribute and the call
            # is not dropped.
            parts.append(ast.unparse(current))

        # Reverse to get the natural order: "self.sections.append"
        return ".".join(reversed(parts))


# Locate the first FunctionDef named run_analysis_pipeline anywhere in the
# tree (ast.walk() searches all nesting levels); None means it is absent.
pipeline_def = next(
    (
        candidate
        for candidate in ast.walk(tree)
        if isinstance(candidate, ast.FunctionDef)
        and candidate.name == "run_analysis_pipeline"
    ),
    None,
)

if pipeline_def is not None:
    # Visit only THIS function's subtree, so calls made elsewhere in the
    # file are not collected.
    finder = CallFinder()
    finder.visit(pipeline_def)

    print("\nCalls inside 'run_analysis_pipeline()':")
    for callee in finder.calls:
        print(f" -> {callee}()")


# ===========================================================================
# STEP 5: Inspect function signatures (arguments)
# ===========================================================================
#
# Every FunctionDef node has an .args attribute of type ast.arguments.
# This contains:
#   .args        : list of ast.arg   -- the positional parameters
#   .vararg      : ast.arg or None   -- the *args parameter
#   .kwonlyargs  : list of ast.arg   -- keyword-only parameters
#   .kwarg       : ast.arg or None   -- the **kwargs parameter
#   .defaults    : list of default values for positional args
#   .kw_defaults : list of default values for keyword-only args
#
# Each ast.arg has:
#   .arg         : str               -- the parameter name (e.g. "self", "data", "x")
#   .annotation  : AST node or None  -- the type annotation
#
# FunctionDef also has:
#   .returns     : AST node or None  -- the return type annotation
#
# To convert an annotation node back to readable Python, use:
#   ast.unparse(node)  -- returns a string like "np.ndarray"
#   ast.dump(node)     -- returns the raw AST representation (for debugging)

print("\n" + "=" * 70)
print("STEP 5 – Extract function signatures")
print("=" * 70)

print()
for node in ast.walk(tree):
    if not isinstance(node, ast.FunctionDef):
        continue

    # node.args is the ast.arguments object
    args = node.args

    # Extract parameter names from the list of ast.arg objects
    param_names = [arg.arg for arg in args.args]

    # Collect the type annotation of each annotated parameter.
    # arg.annotation is None if no annotation, otherwise an AST node.
    # We use ast.dump() here for demonstration; in the exercises you
    # will use ast.unparse() for cleaner output.
    annotations = {
        arg.arg: ast.dump(arg.annotation)
        for arg in args.args
        if arg.annotation is not None
    }

    # Check for return type annotation.
    # node.returns is None if there is no "-> ..." annotation.
    ret = ast.dump(node.returns) if node.returns is not None else "None"

    print(f" def {node.name}({', '.join(param_names)}) -> {ret}")

    # List the collected parameter annotations under the signature, so the
    # dict built above is actually shown rather than computed and discarded.
    for param, dumped in annotations.items():
        print(f"     {param}: {dumped}")

print()


# ===========================================================================
# STEP 6: Linking nodes – parent-child relationships
# ===========================================================================
#
# A common question: "Which class does this method belong to?"
# The AST does NOT store a .parent reference on each node by default.
# But we can add one ourselves using a simple pre-processing step:
#
#     for node in ast.walk(tree):
#         for child in ast.iter_child_nodes(node):
#             child._parent = node
#
# ast.iter_child_nodes(node) yields the direct children of a node.
# By setting child._parent = node, we create a back-link from each
# child to its parent.
#
# After this, for any FunctionDef node we can check:
#   - Is its parent a ClassDef?    -> it's a method
#   - Is its parent a Module?      -> it's a module-level function
#   - Is its parent a FunctionDef? -> it's a nested function

print("=" * 70)
print("STEP 6 – Parent-child relationships")
print("=" * 70)

# Pre-processing: annotate every node with a _parent reference.
# We use a private attribute name (_parent) by convention because
# this is not part of the official AST API -- we are adding it ourselves.
# The root Module node is nobody's child, so it gets no _parent attribute.
for node in ast.walk(tree):
    for child in ast.iter_child_nodes(node):
        child._parent = node  # type: ignore[attr-defined]

# Now find every FunctionDef and report whether it is a method (parent is
# ClassDef) or a module-level function (parent is Module).
print("\nMethod -> Parent mapping:")
for node in ast.walk(tree):
    if not isinstance(node, ast.FunctionDef):
        continue

    # Retrieve the parent we stored above
    parent = getattr(node, "_parent", None)

    if isinstance(parent, ast.ClassDef):
        # This FunctionDef is a method of a class
        print(f" {parent.name}.{node.name}() [method of class]")
    elif isinstance(parent, ast.Module):
        # This FunctionDef is a top-level function
        print(f" {node.name}() [module-level function]")
    # (We skip every other case for brevity -- e.g. nested functions, which
    # have a FunctionDef as their parent.)

# Closing banner: blank line, rule, message, rule.
print(
    "\n" + "=" * 70,
    "DEMO COMPLETE – You are now ready to start the exercises!",
    "=" * 70,
    sep="\n",
)