AST ex01 - ex03b

This commit is contained in:
Michael Schären 2026-05-03 20:27:09 +02:00
parent e9bf6f7f95
commit a5a657250d
11 changed files with 3065 additions and 6 deletions

View File

@ -3,6 +3,8 @@
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<excludeFolder url="file://$MODULE_DIR$/.venv" />
<excludeFolder url="file://$MODULE_DIR$/AISE501 LLM Zugang/.venv" />
<excludeFolder url="file://$MODULE_DIR$/Code embeddings/.venv" />
</content>
<orderEntry type="jdk" jdkName="Python 3.12 (aise-501_aise_in_se_i)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />

View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="DataSourcePerFileMappings">
<file url="file://$APPLICATION_CONFIG_DIR$/scratches/scratch_1.sql" value="be9eece5-a8ff-447a-a6a9-4660fffe89da" />
</component>
</project>

410
AST Files/ast_demo.py Normal file
View File

@ -0,0 +1,410 @@
"""
ast_demo.py Walkthrough: How Python's ast Module Works
=========================================================
AISE501 · AST Exercises · Spring Semester 2026
This file demonstrates how to parse Python source code into an Abstract
Syntax Tree (AST) and walk through it to find nodes and relationships.
WHAT IS AN AST?
An AST is a tree representation of source code. Every construct in
Python (import, function definition, class, assignment, call, ...)
becomes a node in the tree. Nodes have children: for example, a
ClassDef node contains FunctionDef children (its methods).
WHY USE THE ast MODULE?
Python ships with a built-in ``ast`` module that can parse any valid
Python source string into an AST. This lets you analyse code
*without* running it (static analysis). Use cases include:
- code linters and formatters
- documentation generators
- dependency analysis
- AI-powered code search (RAG pipelines)
HOW TO RUN THIS FILE:
cd ast_exercises
python ast_demo.py
Read the printed output step by step. Each step introduces one core
technique that you will need in the exercises.
"""
import ast
from pathlib import Path
# ---------------------------------------------------------------------------
# Load the target source file.
# We read it as a plain string -- ast.parse() works on strings, not files.
# ---------------------------------------------------------------------------
# Location of the file analysed by this demo (a sibling of this script).
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Decode explicitly as UTF-8: Python source files are UTF-8 by default, while
# read_text() without an encoding falls back to the platform locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")
# ===========================================================================
# STEP 1: Parse source code into an AST
# ===========================================================================
#
# ast.parse(source_string) takes a string of Python source code and returns
# the root node of the AST. The root is always an ast.Module node.
#
# The Module node has a .body attribute -- a Python list containing one AST
# node per top-level statement in the file. "Top-level" means statements
# that are NOT inside a class or function. Examples:
# - import statements -> ast.Import or ast.ImportFrom
# - function definitions -> ast.FunctionDef
# - class definitions -> ast.ClassDef
# - if __name__ == "__main__" -> ast.If
# - bare expressions (docstrings) -> ast.Expr
#
# Each node carries metadata:
# - node.lineno : the line number in the original source (1-based)
# - node.col_offset : the column offset (0-based)
# - node-specific attributes (e.g. ClassDef has .name, .body, .bases, ...)
print("=" * 70)
print("STEP 1 Parse source code into an AST")
print("=" * 70)

# Parse the source string into an AST.  This does NOT execute the code.
tree = ast.parse(source_code)
print(f"\nParsed '{SOURCE_FILE.name}' successfully.")
print(f"Root node type : {type(tree).__name__}")  # -> "Module"
print(f"Number of top-level statements: {len(tree.body)}\n")

# List every top-level statement with its line number and AST class name,
# plus a short node-specific summary (.name for defs/classes, the imported
# names for imports).
for i, node in enumerate(tree.body):
    # getattr with a default keeps the loop safe for nodes without .lineno.
    line = getattr(node, "lineno", "?")
    # type(node).__name__ is the AST class name, e.g. "ClassDef".
    print(f" [{i:2d}] Line {line:>3}: {type(node).__name__}", end="")

    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        print(f" -> function '{node.name}'", end="")
    elif isinstance(node, ast.ClassDef):
        print(f" -> class '{node.name}'", end="")
    elif isinstance(node, ast.Import):
        # Import has a .names list of ast.alias objects (.name / .asname).
        names = [alias.name for alias in node.names]
        print(f" -> import {', '.join(names)}", end="")
    elif isinstance(node, ast.ImportFrom):
        # ImportFrom additionally has .module (None for "from . import x")
        # and .level (number of leading dots).  Both are needed so that a
        # relative import is not misreported as a plain "import x".
        names = [alias.name for alias in node.names]
        origin = "." * node.level + (node.module or "")
        print(f" -> from {origin} import {', '.join(names)}", end="")
    print()  # finish the line started above
# ===========================================================================
# STEP 2: Walk the tree with ast.walk()
# ===========================================================================
#
# ast.walk(node) is a generator that yields EVERY node in the subtree
# rooted at `node` (including `node` itself): all children, grandchildren,
# etc. CPython currently walks breadth-first, but the documentation
# specifies no particular order, so do not rely on it.
#
# This is the simplest traversal method. It is ideal when you want to
# count or collect nodes of a certain type regardless of where they appear
# in the tree (e.g. "how many function calls are there in total?").
#
# Limitation: ast.walk() does NOT tell you the parent of each node.
# If you need parent information, see Step 6.
print("\n" + "=" * 70)
print("STEP 2 Walk the tree: count node types")
print("=" * 70)

# Tally how many times each AST class (keyed by its class name, e.g. "Call",
# "Name", "Assign") occurs anywhere in the tree.
node_counts: dict[str, int] = {}
for node in ast.walk(tree):
    kind = type(node).__name__
    node_counts.setdefault(kind, 0)
    node_counts[kind] += 1

# Show the 15 most frequent node types.  Typical entries:
#   Name      - a variable reference (x, self, np)
#   Load      - context marker: the name is read
#   Store     - context marker: the name is assigned to
#   Call      - a function/method call
#   Attribute - a dotted access (self.data)
#   Constant  - a literal value
#   Assign    - an assignment statement
#   Return    - a return statement
print("\nTop 15 most frequent node types:")
# A stable descending sort keeps first-seen order among equal counts.
ranking = sorted(node_counts.items(), key=lambda pair: pair[1], reverse=True)
for name, count in ranking[:15]:
    print(f" {name:25s} {count:4d}")
# ===========================================================================
# STEP 3: Use ast.NodeVisitor for targeted traversal
# ===========================================================================
#
# While ast.walk() visits everything, ast.NodeVisitor lets you define
# handler methods for specific node types. The naming convention is:
#
# def visit_<NodeType>(self, node):
# ... # your logic
# self.generic_visit(node) # continue to children
#
# When you call visitor.visit(tree), the visitor walks the tree and
# dispatches to the matching visit_* method for each node.
#
# IMPORTANT: If you forget to call self.generic_visit(node) at the end
# of your visit_* method, the visitor will NOT recurse into that node's
# children. This is a common mistake!
#
# Use case here: find every class and list its methods.
print("\n" + "=" * 70)
print("STEP 3 NodeVisitor: find all classes and their methods")
print("=" * 70)
class ClassMethodVisitor(ast.NodeVisitor):
    """Record the method names of every class definition in a tree.

    After ``visit(tree)`` the ``classes`` attribute maps each class name to
    the names of the functions defined directly in that class body, in source
    order.  Both ``def`` and ``async def`` methods are included.

    Classes are keyed by their bare name, so two classes with the same name
    (e.g. a nested class shadowing an outer one) overwrite each other.
    """

    def __init__(self):
        # Maps class name -> list of method names
        self.classes: dict[str, list[str]] = {}

    def visit_ClassDef(self, node: ast.ClassDef):
        """Store the method names of ``node`` and descend into its body."""
        # node.body also holds Assign nodes (class attributes), Expr nodes
        # (docstrings), etc.; only function definitions count as methods.
        self.classes[node.name] = [
            item.name
            for item in node.body
            if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        # Without generic_visit, classes nested inside this class (or inside
        # its methods) would be silently skipped.
        self.generic_visit(node)
# Run the visitor over the whole module; results accumulate in .classes.
visitor = ClassMethodVisitor()
visitor.visit(tree)

# Report each class followed by one line per method.
print()
for cls_name in visitor.classes:
    print(f" class {cls_name}:")
    for method_name in visitor.classes[cls_name]:
        print(f" - {method_name}()")
    # NOTE(review): the source indentation was lost; this blank line is
    # assumed to be printed once per class -- confirm against the original.
    print()
# ===========================================================================
# STEP 4: Find function calls inside a method
# ===========================================================================
#
# Function calls in the AST are represented by ast.Call nodes.
# A Call node has a .func attribute that tells us WHAT is being called:
#
# Simple call: len(data)
# -> Call(func=Name(id='len'), args=[Name(id='data')])
#
# Attribute call: np.mean(data)
# -> Call(func=Attribute(value=Name(id='np'), attr='mean'), ...)
#
# Chained call: self.sections.append(...)
# -> Call(func=Attribute(
# value=Attribute(value=Name(id='self'), attr='sections'),
# attr='append'), ...)
#
# To reconstruct the full dotted name (e.g. "self.sections.append"), we
# walk up the chain of Attribute nodes until we reach a Name node.
print("=" * 70)
print("STEP 4 Find function calls inside a specific method")
print("=" * 70)
class CallFinder(ast.NodeVisitor):
    """Collect the name of every function/method call inside a node.

    After ``visit(node)`` the ``calls`` list holds one entry per ``ast.Call``
    in the subtree, outer calls before the calls nested in them:

    - Simple call (Name):       "len", "print", "DataCleaner"
    - Attribute call:           "self.remove_nans", "np.mean"
    - Chained attribute call:   "self.sections.append"
    - Anything else (a call on a call result, a subscript, a literal, ...):
      the source text of the callee, e.g. "make().run" or "handlers[0]"
    """

    def __init__(self):
        self.calls: list[str] = []

    def visit_Call(self, node: ast.Call):
        """Record the callee of ``node`` and continue into its children."""
        self.calls.append(self._callable_name(node.func))
        # A Call can contain other Calls in its callee or arguments,
        # e.g. print(len(data)) has two Calls.
        self.generic_visit(node)

    @staticmethod
    def _callable_name(func: ast.expr) -> str:
        """Return a readable name for the expression being called."""
        # Walk the chain of Attribute nodes.  For self.sections.append:
        #   Attribute(attr='append',
        #             value=Attribute(attr='sections',
        #                             value=Name(id='self')))
        # this collects ['append', 'sections', 'self'], which is reversed.
        parts = []
        current = func
        while isinstance(current, ast.Attribute):
            parts.append(current.attr)
            current = current.value
        if isinstance(current, ast.Name):
            parts.append(current.id)
            return ".".join(reversed(parts))
        # The chain is not rooted in a plain name.  Render the whole callee
        # instead of dropping the call or keeping only the attribute tail.
        return ast.unparse(func)
# Locate the first FunctionDef named run_analysis_pipeline anywhere in the
# tree and list the calls made inside it.  Nothing is printed if no such
# function exists.
pipeline_def = next(
    (
        candidate
        for candidate in ast.walk(tree)
        if isinstance(candidate, ast.FunctionDef)
        and candidate.name == "run_analysis_pipeline"
    ),
    None,
)
if pipeline_def is not None:
    finder = CallFinder()
    finder.visit(pipeline_def)  # visit only THIS function's subtree
    print(f"\nCalls inside 'run_analysis_pipeline()':")
    for call in finder.calls:
        print(f" -> {call}()")
# ===========================================================================
# STEP 5: Inspect function signatures (arguments)
# ===========================================================================
#
# Every FunctionDef node has an .args attribute of type ast.arguments.
# This contains:
# .args : list of ast.arg -- the positional parameters
# .vararg : ast.arg or None -- the *args parameter
# .kwonlyargs : list of ast.arg -- keyword-only parameters
# .kwarg : ast.arg or None -- the **kwargs parameter
# .defaults : list of default values for positional args
# .kw_defaults: list of default values for keyword-only args
#
# Each ast.arg has:
# .arg : str -- the parameter name (e.g. "self", "data", "x")
# .annotation : AST node or None -- the type annotation
#
# FunctionDef also has:
# .returns : AST node or None -- the return type annotation
#
# To convert an annotation node back to readable Python, use:
# ast.unparse(node) -- returns a string like "np.ndarray"
# ast.dump(node) -- returns the raw AST representation (for debugging)
print("\n" + "=" * 70)
print("STEP 5 Extract function signatures")
print("=" * 70)
print()

# Print one signature line per FunctionDef, followed by one line for every
# parameter that carries a type annotation.
for node in ast.walk(tree):
    if not isinstance(node, ast.FunctionDef):
        continue
    # node.args is the ast.arguments object; .args lists the positional
    # parameters as ast.arg objects.
    args = node.args
    param_names = [arg.arg for arg in args.args]
    # arg.annotation is None when the parameter has no annotation, otherwise
    # an AST node.  ast.dump() shows the raw node structure here; in the
    # exercises ast.unparse() is used for cleaner output.
    annotations = {
        arg.arg: ast.dump(arg.annotation)
        for arg in args.args
        if arg.annotation is not None
    }
    # node.returns is None when there is no "-> ..." annotation.
    ret = ast.dump(node.returns) if node.returns else "None"
    print(f" def {node.name}({', '.join(param_names)}) -> {ret}")
    # Previously the annotations were collected but never shown.
    for param, annotation in annotations.items():
        print(f"     {param}: {annotation}")
# NOTE(review): the source indentation was lost; this blank line is assumed
# to be printed once after the loop -- confirm against the original.
print()
# ===========================================================================
# STEP 6: Linking nodes parent-child relationships
# ===========================================================================
#
# A common question: "Which class does this method belong to?"
# The AST does NOT store a .parent reference on each node by default.
# But we can add one ourselves using a simple pre-processing step:
#
# for node in ast.walk(tree):
# for child in ast.iter_child_nodes(node):
# child._parent = node
#
# ast.iter_child_nodes(node) yields the direct children of a node.
# By setting child._parent = node, we create a back-link from each
# child to its parent.
#
# After this, for any FunctionDef node we can check:
# - Is its parent a ClassDef? -> it's a method
# - Is its parent a Module? -> it's a module-level function
# - Is its parent a FunctionDef? -> it's a nested function
print("=" * 70)
print("STEP 6 Parent-child relationships")
print("=" * 70)

# Pre-processing: give every node a back-link to its parent.  The attribute
# is named _parent because it is not part of the official AST API.
for node in ast.walk(tree):
    for child in ast.iter_child_nodes(node):
        child._parent = node  # type: ignore[attr-defined]

# Classify every FunctionDef by the type of its parent node.
print("\nMethod -> Parent mapping:")
for node in ast.walk(tree):
    if not isinstance(node, ast.FunctionDef):
        continue
    # Retrieve the parent stored above.
    parent = getattr(node, "_parent", None)
    if isinstance(parent, ast.ClassDef):
        # Defined directly in a class body: a method.
        print(f" {parent.name}.{node.name}() [method of class]")
    elif isinstance(parent, ast.Module):
        # Defined at the top level of the file.
        print(f" {node.name}() [module-level function]")
    # Functions with any other parent (e.g. nested in another function or
    # in an if-block) are skipped for brevity.

print("\n" + "=" * 70)
print("DEMO COMPLETE You are now ready to start the exercises!")
print("=" * 70)
print("\n" + "=" * 70)
print("DEMO COMPLETE You are now ready to start the exercises!")
print("=" * 70)

View File

@ -0,0 +1,868 @@
"""
ast_introduction.py A Guided Tour of Python's ast Module
===========================================================
AISE501 · AST Exercises · Spring Semester 2026
PURPOSE
-------
This file is a **reference and tutorial** that introduces every class and
method from Python's ``ast`` module that you will need in the exercises.
It is organised into seven sections:
1. Parsing: How source code becomes a tree
2. Node hierarchy: The class hierarchy of AST nodes
3. Statement nodes: Import, FunctionDef, ClassDef, Assign, ...
4. Expression nodes: Call, Attribute, Name, Constant, BinOp, ...
5. Traversal methods: ast.walk(), ast.NodeVisitor, ast.iter_child_nodes()
6. Utility functions: ast.dump(), ast.unparse(), ast.get_docstring()
7. Putting it all together: A mini-analysis pipeline
Each section contains:
- A short explanation of the concept
- Live code that runs and prints output
- Inline comments explaining every line
HOW TO USE
----------
Run this file from the ast_exercises/ folder:
python ast_introduction.py
Read the output alongside the source code. This is a *learning* file --
you are encouraged to modify it, add print statements, and experiment.
PREREQUISITES
-------------
- Python 3.9+ (for ast.unparse)
- No external packages required
"""
import ast
import textwrap
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 1: PARSING How Source Code Becomes a Tree ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# The entry point to the ast module is ast.parse(). It takes a string of
# valid Python source code and returns the root node of the AST.
#
# Key function:
# ast.parse(source, filename='<unknown>', mode='exec')
# source : str the Python source code to parse
# filename: str used in error messages only (optional)
# mode : str 'exec' (module), 'eval' (expression), 'single' (statement)
#
# The return value is always an ast.Module object (when mode='exec').
print("=" * 72, "SECTION 1: PARSING", "=" * 72, sep="\n")

# A two-line program used as the parsing example.
example_source = textwrap.dedent("""\
import math
x = math.sqrt(16)
""")

# ast.parse() only analyses the structure of the source; nothing is executed.
tree = ast.parse(example_source)

# In the default mode ('exec') the root node is an ast.Module.
root_type = type(tree).__name__
print(f"\nReturn type of ast.parse(): {root_type}")
print(f" -> This is always 'Module' when mode='exec'")
print(f" -> The Module node represents the entire file/script")

# Module.body holds one node per top-level statement.
statement_total = len(tree.body)
print(f"\nNumber of top-level statements: {statement_total}")
for position, statement in enumerate(tree.body):
    print(f" [{position}] {type(statement).__name__} (line {statement.lineno})")

# mode="eval" parses a single expression and returns an ast.Expression.
expr_tree = ast.parse("3 + 4 * 2", mode="eval")
print(f"\nParsing in 'eval' mode returns: {type(expr_tree).__name__}")
print(f" -> The Expression node wraps a single expression")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 2: NODE HIERARCHY The Class Hierarchy of AST Nodes ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# Every node in the AST is an instance of a class defined in the ast module.
# The hierarchy (simplified):
#
# ast.AST <-- Base class for ALL nodes
# ├── ast.mod <-- Module-level nodes
# │ ├── ast.Module <-- A file/script (mode='exec')
# │ ├── ast.Expression <-- A single expression (mode='eval')
# │ └── ast.Interactive <-- A single statement (mode='single')
# │
# ├── ast.stmt <-- Statement nodes (things that DO something)
# │ ├── ast.FunctionDef <-- def foo(): ...
# │ ├── ast.AsyncFunctionDef <-- async def foo(): ...
# │ ├── ast.ClassDef <-- class Foo: ...
# │ ├── ast.Return <-- return x
# │ ├── ast.Assign <-- x = 42
# │ ├── ast.AnnAssign <-- x: int = 42
# │ ├── ast.AugAssign <-- x += 1
# │ ├── ast.For <-- for x in y: ...
# │ ├── ast.While <-- while cond: ...
# │ ├── ast.If <-- if cond: ...
# │ ├── ast.With <-- with ctx as x: ...
# │ ├── ast.Raise <-- raise ValueError(...)
# │ ├── ast.Try <-- try: ... except: ...
# │ ├── ast.Import <-- import os
# │ ├── ast.ImportFrom <-- from os import path
# │ ├── ast.Expr <-- a bare expression used as a statement
# │ │ (e.g. a docstring, or a function call
# │ │ whose return value is discarded)
# │ └── ast.Pass / Break / Continue
# │
# ├── ast.expr <-- Expression nodes (things that PRODUCE a value)
# │ ├── ast.Name <-- a variable name: x, self, np
# │ ├── ast.Constant <-- a literal: 42, "hello", None, True
# │ ├── ast.Attribute <-- dotted access: self.data, np.array
# │ ├── ast.Call <-- a function call: foo(), np.mean(x)
# │ ├── ast.BinOp <-- binary operation: x + y, a * b
# │ ├── ast.UnaryOp <-- unary operation: -x, not x
# │ ├── ast.Compare <-- comparison: x > 0, a == b
# │ ├── ast.BoolOp <-- boolean: x and y, a or b
# │ ├── ast.Subscript <-- indexing: x[0], data["key"]
# │ ├── ast.List / Tuple / Set / Dict <-- container literals
# │ ├── ast.ListComp <-- list comprehension: [x for x in y]
# │ ├── ast.Lambda <-- lambda x: x + 1
# │ ├── ast.IfExp <-- ternary: a if cond else b
# │ └── ast.JoinedStr <-- f-string: f"Hello {name}"
# │
# └── other <-- Helper / context nodes
# ├── ast.arguments <-- the argument list of a function
# ├── ast.arg <-- a single parameter (name + annotation)
# ├── ast.keyword <-- a keyword argument in a call
# ├── ast.alias <-- an import alias (import X as Y)
# └── ast.Load / Store / Del <-- name context (reading / writing / deleting)
#
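# Statement and expression nodes (subclasses of ast.stmt / ast.expr, plus a
# few helper nodes such as ast.arg) carry these location attributes:
# - lineno : int the source line number (1-based)
# - col_offset : int the column offset (0-based)
# - end_lineno : int where the node ends (line)
# - end_col_offset: int where the node ends (column)
# Other nodes -- e.g. ast.Module, ast.Load / ast.Store, and operator nodes
# such as ast.Add -- do NOT have them.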
#
# Node-specific attributes are documented below in each section.
print("\n" + "=" * 72, "SECTION 2: NODE HIERARCHY", "=" * 72, sep="\n")

# Check on a tiny tree that every node ast.walk() yields is an ast.AST.
code = "x = 1 + 2"
small_tree = ast.parse(code)
print(f"\nAll nodes are subclasses of ast.AST:")
for visited in ast.walk(small_tree):
    type_name = type(visited).__name__
    is_ast_node = isinstance(visited, ast.AST)
    print(f" {type_name:20s} isinstance(node, ast.AST) = {is_ast_node}")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 3: STATEMENT NODES The Building Blocks of a Program ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# Statements are things that DO something: define a function, import a
# module, assign a variable, loop, branch, etc. They form the top-level
# .body list of a Module, and also the .body lists inside classes,
# functions, if-blocks, loops, etc.
#
# Below we explore the most important statement nodes.
print("\n" + "=" * 72)
print("SECTION 3: STATEMENT NODES")
print("=" * 72)
# ── 3a. ast.Import and ast.ImportFrom ──────────────────────────────────────
#
# ast.Import represents: import os, sys
# ast.ImportFrom represents: from os.path import join, exists
#
# Attributes:
# Import:
# .names : list[ast.alias] each alias has .name and .asname
#
# ImportFrom:
# .module : str the module being imported from (e.g. "os.path")
# .names : list[ast.alias] the imported names
# .level : int number of dots for relative imports (0 = absolute)
#
# ast.alias:
# .name : str the real name (e.g. "numpy")
# .asname : str or None the alias (e.g. "np" in "import numpy as np")
print("\n── 3a. Import nodes ──")

# Four import forms: plain, aliased, from-import, and relative from-import.
import_code = textwrap.dedent("""\
    import os
    import numpy as np
    from scipy.stats import ttest_ind, norm
    from . import utils
""")
import_tree = ast.parse(import_code)

for node in import_tree.body:
    if isinstance(node, ast.Import):
        # One ast.alias per imported module; .asname is None without "as".
        for alias in node.names:
            alias_str = f" as {alias.asname}" if alias.asname else ""
            print(f" import {alias.name}{alias_str}")
            print(f" -> ast.Import, alias.name='{alias.name}', alias.asname={alias.asname!r}")
    elif isinstance(node, ast.ImportFrom):
        names = [f"{a.name}" + (f" as {a.asname}" if a.asname else "") for a in node.names]
        dots = "." * node.level  # relative import dots
        print(f" from {dots}{node.module or ''} import {', '.join(names)}")
        # !r prints a real module as 'scipy.stats' and a missing one as None,
        # instead of the misleading string 'None' for "from . import utils".
        print(f" -> ast.ImportFrom, module={node.module!r}, level={node.level}")
# ── 3b. ast.FunctionDef ───────────────────────────────────────────────────
#
# Represents a function or method definition.
#
# Attributes:
# .name : str function name
# .args : ast.arguments the parameter specification
# .body : list[ast.stmt] the function body (list of statements)
# .decorator_list : list[ast.expr] decorators (@staticmethod, etc.)
# .returns : ast.expr or None return type annotation
# .lineno : int line number of 'def'
#
# The .args attribute is an ast.arguments object with:
# .args : list[ast.arg] positional parameters
# .vararg : ast.arg or None *args
# .kwonlyargs : list[ast.arg] keyword-only parameters
# .kwarg : ast.arg or None **kwargs
# .defaults : list[ast.expr] default values (right-aligned)
# .kw_defaults : list defaults for kwonlyargs
#
# Each ast.arg has:
# .arg : str parameter name
# .annotation : ast.expr or None type annotation
print("\n── 3b. FunctionDef ──")

# Sample function with a decorator, annotations, a default and a docstring.
# The body must be indented inside the literal; textwrap.dedent() removes
# only the common margin, so a flat sample would not parse.
func_code = textwrap.dedent("""\
    @staticmethod
    def calculate(data: list[float], threshold: float = 0.05) -> dict:
        \"\"\"Perform a calculation.\"\"\"
        result = sum(data)
        return {"total": result}
""")
func_tree = ast.parse(func_code)
func_node = func_tree.body[0]  # the only top-level statement: the FunctionDef

print(f"\n Function name : {func_node.name}")
print(f" Line number : {func_node.lineno}")
print(f" Decorators : {[ast.unparse(d) for d in func_node.decorator_list]}")
print(f" Return type : {ast.unparse(func_node.returns) if func_node.returns else 'None'}")
print(f" Body length : {len(func_node.body)} statements")

# Inspect the positional parameters and their annotations.
args = func_node.args
print(f"\n Parameters ({len(args.args)} total):")
for arg in args.args:
    ann = ast.unparse(arg.annotation) if arg.annotation else "no annotation"
    print(f" {arg.arg}: {ann}")

# Defaults are right-aligned: with 2 parameters and 1 default, the default
# belongs to the LAST parameter.
print(f" Defaults: {[ast.unparse(d) for d in args.defaults]}")
print(f" -> defaults are right-aligned to parameters")
# ── 3c. ast.ClassDef ─────────────────────────────────────────────────────
#
# Represents a class definition.
#
# Attributes:
# .name : str class name
# .bases : list[ast.expr] base classes
# .keywords : list[ast.keyword] metaclass etc.
# .body : list[ast.stmt] class body (methods, attributes, ...)
# .decorator_list : list[ast.expr] decorators
print("\n── 3c. ClassDef ──")

# Sample class with a base class, a docstring and three methods.  The nested
# structure must be indented inside the literal so that ast.parse() accepts it.
class_code = textwrap.dedent("""\
    class DataProcessor(BaseProcessor):
        \"\"\"Process data for analysis.\"\"\"
        def __init__(self, data: list):
            self.data = data
            self.result = None
        @staticmethod
        def validate(item):
            return item is not None
        def process(self) -> list:
            return [x for x in self.data if self.validate(x)]
""")
class_tree = ast.parse(class_code)
class_node = class_tree.body[0]  # the only top-level statement: the ClassDef

print(f"\n Class name : {class_node.name}")
print(f" Base classes: {[ast.unparse(b) for b in class_node.bases]}")
print(f" Decorators : {[ast.unparse(d) for d in class_node.decorator_list]}")
print(f" Body has {len(class_node.body)} items:")
for item in class_node.body:
    if isinstance(item, ast.FunctionDef):
        print(f" FunctionDef: {item.name}() (line {item.lineno})")
    elif isinstance(item, ast.Expr):
        # Any bare expression statement is labelled as the docstring here;
        # in this sample the only one is the class docstring.
        print(f" Expr: (docstring) (line {item.lineno})")
    else:
        print(f" {type(item).__name__} (line {item.lineno})")
# ── 3d. ast.Assign and ast.AnnAssign ─────────────────────────────────────
#
# ast.Assign: x = 42 (no type annotation)
# .targets : list[ast.expr] what is assigned to (can be multiple: a = b = 1)
# .value : ast.expr the right-hand side
#
# ast.AnnAssign: x: int = 42 (with type annotation)
# .target : ast.expr what is assigned to (single target)
# .annotation: ast.expr the type annotation
# .value : ast.expr or None the value (None if just a declaration)
print("\n── 3d. Assign and AnnAssign ──")

# Plain, chained, annotated-with-value and annotation-only assignments.
assign_code = textwrap.dedent("""\
x = 42
a = b = [1, 2, 3]
name: str = "Alice"
count: int
""")
assign_tree = ast.parse(assign_code)

for statement in assign_tree.body:
    if isinstance(statement, ast.Assign):
        # .targets is a list because "a = b = ..." has several targets.
        target_text = ", ".join(map(ast.unparse, statement.targets))
        value_text = ast.unparse(statement.value)
        print(f" Assign: {target_text} = {value_text}")
    elif isinstance(statement, ast.AnnAssign):
        # .value is None for a bare declaration such as "count: int".
        if statement.value:
            val = ast.unparse(statement.value)
        else:
            val = "(no value)"
        target_text = ast.unparse(statement.target)
        annotation_text = ast.unparse(statement.annotation)
        print(f" AnnAssign: {target_text}: {annotation_text} = {val}")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 4: EXPRESSION NODES Things That Produce Values ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# Expressions are things that PRODUCE a value. They appear inside
# statements: the right-hand side of an assignment, function arguments,
# conditions in if-statements, etc.
print("\n" + "=" * 72)
print("SECTION 4: EXPRESSION NODES")
print("=" * 72)
# ── 4a. ast.Name ─────────────────────────────────────────────────────────
#
# Represents a variable reference.
#
# Attributes:
# .id : str the variable name (e.g. "x", "self", "np")
# .ctx : ast.Load / ast.Store / ast.Del
# context: is the name being read (Load), written to (Store),
# or deleted (Del)?
#
# ast.Name appears everywhere: in assignments (Store), in expressions (Load),
# as function call targets, as import aliases, etc.
print("\n── 4a. ast.Name ──")

# Every identifier in "x = y + z" is an ast.Name; .ctx tells whether it is
# written (Store) or read (Load).
name_code = "x = y + z"
name_tree = ast.parse(name_code)
for visited in ast.walk(name_tree):
    if not isinstance(visited, ast.Name):
        continue
    context_name = type(visited.ctx).__name__  # "Load" or "Store"
    print(f" Name(id='{visited.id}', ctx={context_name})")
# ── 4b. ast.Constant ────────────────────────────────────────────────────
#
# Represents a literal value: number, string, boolean, None, bytes.
#
# Attributes:
# .value : the Python value (int, float, str, bool, None, bytes, ...)
#
# Note: In Python 3.8+, ast.Constant replaces the older ast.Num, ast.Str,
# ast.NameConstant, ast.Bytes, and ast.Ellipsis nodes.
print("\n── 4b. ast.Constant ──")

# One literal of each kind: int, str, bool, None, float.
const_code = 'x = 42; y = "hello"; z = True; w = None; f = 3.14'
const_tree = ast.parse(const_code)
literals = (n for n in ast.walk(const_tree) if isinstance(n, ast.Constant))
for literal in literals:
    payload = literal.value  # the actual Python object behind the literal
    print(f" Constant(value={payload!r}, type={type(payload).__name__})")
# ── 4c. ast.Attribute ───────────────────────────────────────────────────
#
# Represents dotted access: self.data, np.array, os.path.join
#
# Attributes:
# .value : ast.expr the object (e.g. Name(id='self') or another Attribute)
# .attr : str the attribute name (e.g. "data", "array")
# .ctx : Load / Store / Del
#
# IMPORTANT: Chained attributes are nested.
# self.sections.append becomes:
# Attribute(
# value=Attribute(
# value=Name(id='self'),
# attr='sections'
# ),
# attr='append'
# )
print("\n── 4c. ast.Attribute ──")

# Two statements containing simple and chained attribute accesses.
attr_code = textwrap.dedent("""\
self.data = np.array([1, 2, 3])
result = self.data.mean()
""")
attr_tree = ast.parse(attr_code)
for visited in ast.walk(attr_tree):
    if not isinstance(visited, ast.Attribute):
        continue
    # .value is the object left of the last dot; unparsing the node itself
    # gives the complete dotted expression.
    owner_text = ast.unparse(visited.value)
    full_text = ast.unparse(visited)
    print(
        f" Attribute: .attr='{visited.attr}', value={owner_text}",
        f" -> full expression: {full_text}",
        sep="\n",
    )
# ── 4d. ast.Call ────────────────────────────────────────────────────────
#
# Represents a function or method call.
#
# Attributes:
# .func : ast.expr what is being called
# (ast.Name for foo(), ast.Attribute for obj.method())
# .args : list[ast.expr] positional arguments
# .keywords : list[ast.keyword] keyword arguments (each has .arg and .value)
#
# This is the MOST IMPORTANT expression node for code analysis.
# In the exercises, you will use it to build call graphs.
print("\n── 4d. ast.Call ──")

# A simple call, an attribute call with a keyword argument, and a method call.
call_code = textwrap.dedent("""\
print("hello")
result = np.mean(data, axis=0)
self.process(items, verbose=True)
""")
call_tree = ast.parse(call_code)
for visited in ast.walk(call_tree):
    if not isinstance(visited, ast.Call):
        continue
    # Classify the callee: a bare name, a dotted access, or anything else.
    callee = visited.func
    if isinstance(callee, ast.Name):
        callable_name, call_type = callee.id, "simple (Name)"
    else:
        callable_name = ast.unparse(callee)
        call_type = "attribute (obj.method)" if isinstance(callee, ast.Attribute) else "other"
    # Positional arguments live in .args, keyword arguments in .keywords.
    kw_names = [kw.arg for kw in visited.keywords]
    print(f" Call: {callable_name}()")
    print(f" type : {call_type}")
    print(f" positional: {len(visited.args)} args")
    print(f" keyword : {len(visited.keywords)} args {kw_names}")
    # NOTE(review): the source indentation was lost; this blank line is
    # assumed to be printed once per call -- confirm against the original.
    print()
# ── 4e. ast.BinOp ──────────────────────────────────────────────────────
#
# Represents a binary operation: x + y, a * b, etc.
#
# Attributes:
# .left : ast.expr the left operand
# .op : ast.operator the operator (Add, Sub, Mult, Div, Pow, Mod, ...)
# .right : ast.expr the right operand
print("── 4e. ast.BinOp ──")

# An expression mixing four operators; each one becomes its own BinOp node.
binop_code = "result = (a + b) * c - d / e"
binop_tree = ast.parse(binop_code)
for visited in ast.walk(binop_tree):
    if not isinstance(visited, ast.BinOp):
        continue
    op_name = type(visited.op).__name__  # e.g. "Add", "Mult"
    lhs = ast.unparse(visited.left)
    rhs = ast.unparse(visited.right)
    print(f" BinOp: {lhs} {op_name} {rhs}")
    print(f" -> full: {ast.unparse(visited)}")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 5: TRAVERSAL METHODS ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# The ast module provides three ways to traverse a tree:
#
# 1. ast.walk(node)
# - Generator that yields every node in the subtree (breadth-first in
#   CPython; the documented order is unspecified)
# - Simplest approach; no parent information
# - Use when: you need to find/count all nodes of a type
#
# 2. ast.NodeVisitor
# - Subclass it and define visit_<NodeType>() methods
# - The framework dispatches to the correct method automatically
# - Use when: you need different logic for different node types
# - MUST call self.generic_visit(node) to continue to children
#
# 3. ast.iter_child_nodes(node)
# - Generator that yields the direct children of a single node
# - Use when: you need parent-child relationships (see Step 6 of ast_demo.py)
#
# There is also ast.NodeTransformer (subclass of NodeVisitor) for MODIFYING
# the tree, but we don't use it in these exercises.
print("\n" + "=" * 72)
print("SECTION 5: TRAVERSAL METHODS")
print("=" * 72)
# ── 5a. ast.walk() ──────────────────────────────────────────────────────
print("\n── 5a. ast.walk() ──")

# Small class with two methods, used as the traversal example.  The nested
# structure must be indented inside the literal so that ast.parse() accepts it.
sample = textwrap.dedent("""\
    class Foo:
        def bar(self):
            return self.baz()
        def baz(self):
            return 42
""")
sample_tree = ast.parse(sample)

# ast.walk yields every node -- filter with isinstance to keep one type.
function_defs = [n for n in ast.walk(sample_tree) if isinstance(n, ast.FunctionDef)]
print(f" Found {len(function_defs)} FunctionDef nodes via ast.walk():")
for fd in function_defs:
    print(f" - {fd.name}() at line {fd.lineno}")
# ── 5b. ast.NodeVisitor ──────────────────────────────────────────────────
print("\n── 5b. ast.NodeVisitor ──")
print(" (Detailed example in ast_demo.py, Steps 3-4)")
# Quick demonstration: a visitor that counts calls and names.
class CounterVisitor(ast.NodeVisitor):
    """Tally the ``ast.Call`` and ``ast.Name`` nodes of a visited tree.

    After ``visit(tree)`` the totals are available as ``call_count`` and
    ``name_count``.
    """

    def __init__(self):
        self.call_count = self.name_count = 0

    def visit_Call(self, node):
        self.call_count = self.call_count + 1
        # Keep descending: a call's arguments may contain further calls/names.
        self.generic_visit(node)

    def visit_Name(self, node):
        self.name_count = self.name_count + 1
        # Name nodes only have a context child, but recurse for uniformity.
        self.generic_visit(node)
# Run the counting visitor over the sample class and report the totals.
counter = CounterVisitor()
counter.visit(sample_tree)
print(" In the 'Foo' class:")
print(f" Call nodes: {counter.call_count}")
print(f" Name nodes: {counter.name_count}")
print()
print(" REMINDER: If you forget self.generic_visit(node), the visitor")
print(" STOPS recursing into that node's children. This is the #1 mistake.")
# ── 5c. ast.iter_child_nodes() ──────────────────────────────────────────
print("\n── 5c. ast.iter_child_nodes() ──")
# Unlike ast.walk, iter_child_nodes stops after one level, which makes it
# the building block for explicit parent -> child bookkeeping.
class_node = sample_tree.body[0]  # the ClassDef for 'Foo'
print(f" Direct children of ClassDef '{class_node.name}':")
for child in ast.iter_child_nodes(class_node):
    label = f" {type(child).__name__}"
    if hasattr(child, "name"):
        label += f" (name='{child.name}')"
    print(label)
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 6: UTILITY FUNCTIONS ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# ast.dump(node, indent=None)
# - Returns a string representation of the AST (for debugging)
# - With indent=2, prints a nicely formatted multi-line tree
#
# ast.unparse(node) [Python 3.9+]
# - Converts an AST node BACK into Python source code
# - Useful for printing type annotations, expressions, etc.
# - Does NOT reproduce the original formatting (comments are lost)
#
# ast.get_docstring(node)
# - Returns the docstring of a Module, ClassDef, or FunctionDef
# - Returns None if there is no docstring
# - Under the hood: checks if the first statement in .body is an
# Expr containing a Constant with a string value
print(f"\n{'=' * 72}")
print("SECTION 6: UTILITY FUNCTIONS")
print("=" * 72)
# ── 6a. ast.dump() ──────────────────────────────────────────────────────
print("\n── 6a. ast.dump() ──")
tiny_tree = ast.parse("x = 1 + 2")
print(" ast.dump (compact):")
print(f" {ast.dump(tiny_tree)}")
print()
print(" ast.dump (indented):")
print(textwrap.indent(ast.dump(tiny_tree, indent=2), " "))
# ── 6b. ast.unparse() ──────────────────────────────────────────────────
print("\n── 6b. ast.unparse() ──")
# Show the same annotation twice: once as a raw node dump, once as source.
func_code2 = "def greet(name: str, times: int = 1) -> None: pass"
func_tree2 = ast.parse(func_code2).body[0]
print(f" Function: {func_tree2.name}")
for arg in func_tree2.args.args:
    if not arg.annotation:
        continue
    print(f" {arg.arg}: dump={ast.dump(arg.annotation)}")
    print(f" {arg.arg}: unparse={ast.unparse(arg.annotation)}")
if func_tree2.returns:
    print(f" Returns: {ast.unparse(func_tree2.returns)}")
# ── 6c. ast.get_docstring() ────────────────────────────────────────────
print("\n── 6c. ast.get_docstring() ──")
# One class with a docstring, one documented method, one undocumented method.
doc_code = textwrap.dedent("""\
    class MyClass:
        \"\"\"This is the class docstring.\"\"\"
        def my_method(self):
            \"\"\"This is the method docstring.\"\"\"
            pass
        def no_doc(self):
            pass
""")
doc_tree = ast.parse(doc_code)
cls = doc_tree.body[0]
print(f" Class '{cls.name}' docstring: {ast.get_docstring(cls)!r}")
for item in cls.body:
    if not isinstance(item, ast.FunctionDef):
        continue
    doc = ast.get_docstring(item)
    print(f" Method '{item.name}' docstring: {doc!r}")
# ╔═══════════════════════════════════════════════════════════════════════════╗
# ║ SECTION 7: PUTTING IT ALL TOGETHER A Mini Analysis Pipeline ║
# ╚═══════════════════════════════════════════════════════════════════════════╝
#
# This section combines all techniques into a short analysis pipeline
# that extracts a "code inventory" from a source file -- the same kind
# of analysis you will build in the exercises, but simplified.
print(f"\n{'=' * 72}")
print("SECTION 7: PUTTING IT ALL TOGETHER")
print("=" * 72)
# Inline program used as the analysis target: two imports, one module-level
# function and one class with three methods.
analysis_target = textwrap.dedent("""\
    import numpy as np
    from scipy import stats
    def load_data(path: str) -> np.ndarray:
        \"\"\"Load data from a file.\"\"\"
        return np.loadtxt(path)
    class Analyzer:
        \"\"\"Perform statistical analysis.\"\"\"
        def __init__(self, data: np.ndarray):
            self.data = data
            self.results = {}
        def compute_mean(self) -> float:
            return float(np.mean(self.data))
        def run_test(self, threshold: float = 0.05) -> dict:
            stat, p = stats.normaltest(self.data)
            return {"statistic": stat, "p_value": p, "significant": p < threshold}
""")
# Parsing never imports numpy/scipy; the target is only analysed, not run.
target_tree = ast.parse(analysis_target)
class FullInventoryVisitor(ast.NodeVisitor):
    """Comprehensive code inventory visitor.

    Collects:
    - Imports (local alias -> qualified name)
    - Top-level functions (with signatures)
    - Classes (with bases, docstring, methods and instance attributes)

    Functions defined *inside* another function or method are still
    traversed (they may contain classes) but are not reported, because they
    are neither top-level functions nor methods of the enclosing class.
    """

    def __init__(self):
        self.imports = {}            # alias -> qualified module/name
        self.functions = []          # list of {name, params, returns}
        self.classes = []            # list of {name, bases, docstring, methods, attributes}
        self._current_class = None   # class_info dict of the class being visited
        self._inside_function = False  # True while visiting a function body

    def visit_Import(self, node):
        """Record ``import a.b as c`` as ``c -> a.b``."""
        for alias in node.names:
            self.imports[alias.asname or alias.name] = alias.name

    def visit_ImportFrom(self, node):
        """Record ``from m import n as k`` as ``k -> m.n``.

        Relative imports have ``node.module is None`` and/or a non-zero
        ``node.level``; the leading dots are kept so ``from . import x``
        becomes ``.x`` instead of ``None.x``.
        """
        dots = "." * (node.level or 0)
        for alias in node.names:
            if node.module:
                qualified = f"{dots}{node.module}.{alias.name}"
            else:
                qualified = f"{dots}{alias.name}"
            self.imports[alias.asname or alias.name] = qualified

    def visit_FunctionDef(self, node):
        """Record a top-level function or a method of the current class."""
        if not self._inside_function:
            params = [
                {
                    "name": arg.arg,
                    "type": ast.unparse(arg.annotation) if arg.annotation else None,
                }
                for arg in node.args.args
                if arg.arg != "self"
            ]
            info = {
                "name": node.name,
                "params": params,
                "returns": ast.unparse(node.returns) if node.returns else None,
            }
            if self._current_class is not None:
                self._current_class["methods"].append(info)
            else:
                self.functions.append(info)
        # Descend with the flag set so nested defs are not misreported.
        was_inside = self._inside_function
        self._inside_function = True
        self.generic_visit(node)
        self._inside_function = was_inside

    def visit_ClassDef(self, node):
        """Record a class and visit its body with that class as context."""
        class_info = {
            "name": node.name,
            "bases": [ast.unparse(b) for b in node.bases],
            "docstring": ast.get_docstring(node),
            "methods": [],
            "attributes": [],
        }
        # Instance attributes: self.xxx = ... / self.xxx: T = ... in __init__.
        for item in node.body:
            if not (isinstance(item, ast.FunctionDef) and item.name == "__init__"):
                continue
            for sub in ast.walk(item):
                if isinstance(sub, ast.Assign):
                    targets = sub.targets
                elif isinstance(sub, ast.AnnAssign):
                    targets = [sub.target]
                else:
                    continue
                for target in targets:
                    if (isinstance(target, ast.Attribute)
                            and isinstance(target.value, ast.Name)
                            and target.value.id == "self"):
                        class_info["attributes"].append(target.attr)
        # A class body starts a fresh scope: its direct defs are methods even
        # when the class itself is declared inside a function.
        outer_class, outer_flag = self._current_class, self._inside_function
        self._current_class, self._inside_function = class_info, False
        self.generic_visit(node)
        self._current_class, self._inside_function = outer_class, outer_flag
        self.classes.append(class_info)
# Run the analysis
# Run the visitor over the parsed target, then print what it collected.
inventory = FullInventoryVisitor()
inventory.visit(target_tree)
print("\n Imports:")
for alias, module in inventory.imports.items():
    print(f" {alias} -> {module}")
print("\n Top-level functions:")
for func in inventory.functions:
    # Render each parameter as "name: type", or bare "name" when untyped.
    rendered = []
    for p in func["params"]:
        rendered.append(f"{p['name']}: {p['type']}" if p["type"] else p["name"])
    params_str = ", ".join(rendered)
    print(f" def {func['name']}({params_str}) -> {func['returns']}")
print("\n Classes:")
for cls in inventory.classes:
    base_list = ", ".join(cls["bases"])
    print(f" class {cls['name']}({base_list}):")
    print(f" docstring: {cls['docstring']!r}")
    print(f" attributes: {cls['attributes']}")
    for method in cls["methods"]:
        rendered = []
        for p in method["params"]:
            rendered.append(f"{p['name']}: {p['type']}" if p["type"] else p["name"])
        params_str = ", ".join(rendered)
        print(f" def {method['name']}({params_str}) -> {method['returns']}")
# ── Summary ─────────────────────────────────────────────────────────────────
# Closing banner followed by a static cheat sheet of the ast features shown
# in this file; nothing below is computed, it is printed verbatim.
print("\n" + "=" * 72)
print("REFERENCE SUMMARY")
print("=" * 72)
print("""
PARSING
ast.parse(source) Parse string -> ast.Module
KEY NODE TYPES
ast.Module Root: represents the file
ast.Import / ast.ImportFrom Import statements
ast.FunctionDef Function/method definitions
ast.ClassDef Class definitions
ast.Assign / ast.AnnAssign Assignment statements
ast.Call Function/method calls
ast.Name Variable references
ast.Attribute Dotted access (obj.attr)
ast.Constant Literal values (42, "hello", True)
ast.Expr Bare expression (e.g. docstring)
ast.arguments / ast.arg Parameter specifications
TRAVERSAL
ast.walk(node) Yield all nodes (breadth-first)
ast.NodeVisitor Visitor pattern (visit_X methods)
ast.iter_child_nodes(node) Yield direct children only
UTILITIES
ast.dump(node, indent=2) Debug: print AST structure
ast.unparse(node) Convert AST back to Python source
ast.get_docstring(node) Extract docstring from def/class
You are now ready to start the exercises!
Begin with: python ast_demo.py
Then: python ex01_find_classes_functions.py
""")

View File

@ -0,0 +1,133 @@
"""
Exercise 1 Find All Classes and Top-Level Functions
=====================================================
AISE501 · AST Exercises · Spring Semester 2026
Learning goals
--------------
* Parse a Python source file into an AST using ``ast.parse()``.
* Walk the tree with ``ast.walk()`` to discover ``ClassDef`` and
``FunctionDef`` nodes.
* Extract basic metadata: name, line number, docstring.
Tasks
-----
Part A Parse sample_stats.py and list every top-level class (TODO 1-2).
Part B List every top-level (module-level) function (TODO 3-4).
Part C Extract the docstring of each class and function (TODO 5-6).
"""
import ast
from pathlib import Path
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Python source files are UTF-8 by default; decode explicitly so the result
# does not depend on the platform's locale encoding (e.g. cp1252 on Windows).
source_code = SOURCE_FILE.read_text(encoding="utf-8")
# Parse the source code into an AST
tree = ast.parse(source_code)
# ── Part A: Find All Classes ────────────────────────────────────────────────
# Only tree.body is inspected, so nested classes are intentionally excluded.
print("=" * 60)
print("Part A Classes in sample_stats.py")
print("=" * 60)
# TODO 1 (done): every ast.ClassDef among the top-level statements.
classes = [node for node in tree.body if isinstance(node, ast.ClassDef)]
# TODO 2 (done): print each class name, padded to a common width, and its line.
# The width is kept as a string because it is spliced into format specs;
# default=0 keeps max() from raising when the module defines no classes.
max_class_length = str(max((len(cls.name) for cls in classes), default=0))
for cls in classes:
    print(f"{cls.name:{max_class_length}} | Line: {cls.lineno}")
# ── Part B: Find All Top-Level Functions ────────────────────────────────────
# Same approach, but filter for FunctionDef instead of ClassDef.
print("\n" + "=" * 60)
print("Part B Top-level functions in sample_stats.py")
print("=" * 60)
# TODO 3 (done): every ast.FunctionDef among the top-level statements
# (methods live inside ClassDef bodies and are therefore not included).
functions = [node for node in tree.body if isinstance(node, ast.FunctionDef)]
# TODO 4 (done): print each function name, padded to a common width, and its line.
# default=0 keeps max() from raising when the module defines no functions.
max_function_length = str(max((len(func.name) for func in functions), default=0))
for func in functions:
    print(f"{func.name:{max_function_length}} | Line: {func.lineno}")
# ── Part C: Extract Docstrings ──────────────────────────────────────────────
# In Python's AST a docstring is the first statement in the body of a
# class or function, if that statement is an ast.Expr whose value is an
# ast.Constant with a string value.
print("\n" + "=" * 60)
print("Part C Docstrings")
print("=" * 60)
# TODO 5: Write a helper function get_docstring(node) that returns the
# docstring of a ClassDef or FunctionDef, or None if there is none.
#
# Hint: You can also use ast.get_docstring(node) from the standard library,
# but try implementing it manually first to understand the tree structure.
#
# Manual approach:
# 1. Check if node.body is non-empty.
# 2. Check if the first element is an ast.Expr.
# 3. Check if that Expr's .value is an ast.Constant with a str value.
# 4. If all checks pass, return the string; otherwise return None.
def get_docstring(node):
    """Return the docstring of *node*, or None.

    *node* is anything with a ``body`` list (Module, ClassDef, FunctionDef).
    A docstring is present only when the first body statement is a bare
    expression whose value is a *string* constant; other leading constants
    (numbers, bytes, ``...``) are not docstrings. The text is returned
    exactly as written, without the indentation clean-up that
    ``ast.get_docstring`` performs.
    """
    if not node.body:
        return None
    first = node.body[0]
    if (isinstance(first, ast.Expr)
            and isinstance(first.value, ast.Constant)
            # .value is the supported attribute; the old .s alias is deprecated.
            and isinstance(first.value.value, str)):
        return first.value.value
    return None
# TODO 6: For each class and function, print its name and docstring (first line only).
# TODO 6 (done): one line per class/function with the first docstring line.
# An empty docstring is reported the same way as a missing one.
print("\nClasses:")
for cls in classes:
    doc = get_docstring(cls) or "No docstring"
    first_line = doc.split("\n")[0]
    print(f"{cls.name:{max_class_length}} | {first_line}")
print("\nFunctions:")
for func in functions:
    doc = get_docstring(func) or "No docstring"
    first_line = doc.split("\n")[0]
    print(f"{func.name:{max_function_length}} | {first_line}")
# ── Expected Output (abbreviated) ──────────────────────────────────────────
# Part A should list: DataCleaner, DescriptiveStats, HypothesisTester,
# CurveFitter, ReportGenerator
# Part B should list: load_csv, save_json, validate_data, run_analysis_pipeline
# Part C should show the first line of each docstring.

View File

@ -0,0 +1,223 @@
"""
Exercise 2 Analyse Class Methods and Attributes
==================================================
AISE501 · AST Exercises · Spring Semester 2026
Learning goals
--------------
* Use ``ast.NodeVisitor`` to build a targeted traversal.
* Extract method signatures (parameters, return annotations).
* Identify instance attributes assigned in ``__init__``.
* Distinguish between regular methods and static methods.
Tasks
-----
Part A Use a NodeVisitor to collect methods per class (TODOs 1-2).
Part B Extract the parameter list for each method (TODOs 3-4).
Part C Find instance attributes set in __init__ (TODOs 5-6).
Part D Detect @staticmethod decorators (TODOs 7-8).
"""
import ast
from pathlib import Path
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Decode explicitly as UTF-8 (Python's default source encoding) instead of
# relying on the platform's locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")
tree = ast.parse(source_code)
# ── Part A: Collect Methods Per Class with NodeVisitor ──────────────────────
print("=" * 60)
print("Part A Methods per class (NodeVisitor)")
print("=" * 60)
# TODO 1: Complete the ClassMethodVisitor.
# In visit_ClassDef, iterate over node.body and collect every
# FunctionDef into a list. Store the result in self.classes
# as a dict mapping class_name -> list of method names.
#
# Hint: Don't forget to call self.generic_visit(node) at the end
# so that nested classes (if any) are also visited.
class ClassMethodVisitor(ast.NodeVisitor):
    """Map every class name to the FunctionDef nodes defined directly in it.

    The nodes themselves (not just their names) are stored so later parts
    can inspect signatures. Nested classes get their own entry.
    """

    def __init__(self):
        self.classes: dict[str, list[ast.FunctionDef]] = {}

    def visit_ClassDef(self, node: ast.ClassDef):
        direct_methods = []
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                direct_methods.append(stmt)
        self.classes[node.name] = direct_methods
        # Continue into the body so nested classes are registered as well.
        self.generic_visit(node)
# TODO 2: Instantiate the visitor, call visitor.visit(tree),
# and print each class with its methods.
# TODO 2 (done): run the visitor once; Part B reuses its results.
visitor = ClassMethodVisitor()
visitor.visit(tree)
for cls_name, methods in visitor.classes.items():
    print(f"\n class {cls_name}:")
    for m in methods:
        print(f" - {m.name}()")
# ── Part B: Extract Method Signatures ───────────────────────────────────────
# For each method: parameter names (excluding 'self') and type annotations.
print(f"\n{'=' * 60}")
print("Part B Method signatures")
print("=" * 60)
# TODO 3: Write a function `get_signature(func_node)` that returns a string
# representation of the function's parameters.
#
# For each parameter in func_node.args.args:
# - Skip 'self'
# - Get the parameter name from arg.arg
# - If arg.annotation exists, unparse it with ast.unparse()
# - Format as "name: type" or just "name"
#
# Also check func_node.returns for a return annotation.
#
# Example output: "(data: np.ndarray, z_threshold: float) -> np.ndarray"
def get_signature(func_node: ast.FunctionDef) -> str:
    """Render the parameters and return annotation of *func_node*.

    Only plain positional parameters (``args.args``) are listed and ``self``
    is left out. The result has the form ``'(p1: Type, p2) -> ReturnType'``;
    a function without a return annotation renders as ``-> None``.
    """
    rendered = [
        f"{arg.arg}: {ast.unparse(arg.annotation)}" if arg.annotation else arg.arg
        for arg in func_node.args.args
        if arg.arg != "self"
    ]
    return_type = None
    if func_node.returns:
        return_type = ast.unparse(func_node.returns)
    return "(" + ", ".join(rendered) + f") -> {return_type}"
# TODO 4: For each class and method, print the full signature.
# Reuse the visitor results from Part A.
# TODO 4 (done): print the full signature of every method, reusing the
# FunctionDef nodes collected by the Part A visitor.
for cls_name, methods in visitor.classes.items():
    print(f"\n class {cls_name}:")
    for m in methods:
        # get_signature() renders only "(params) -> return"; prefix the method
        # name so each line reads like "remove_outliers(z_threshold: float) -> ...".
        print(f" - {m.name}{get_signature(m)}")
# ── Part C: Find Instance Attributes ───────────────────────────────────────
# Instance attributes are typically assigned in __init__ as self.xxx = ...
print("\n" + "=" * 60)
print("Part C Instance attributes (self.xxx in __init__)")
print("=" * 60)
# TODO 5: Write a function `find_instance_attributes(class_node)` that
# returns a list of attribute names assigned via self.xxx = ...
# in the __init__ method.
#
# Approach:
# 1. Find the __init__ method in class_node.body.
# 2. Walk through the __init__ body looking for ast.Assign or
# ast.AnnAssign nodes.
# 3. For Assign: check if any target is an ast.Attribute where
# target.value is ast.Name(id='self').
# 4. Collect the attribute names (target.attr).
def _iter_self_attributes(target):
    """Yield the name of every ``self.xxx`` contained in an assignment target."""
    if isinstance(target, (ast.Tuple, ast.List)):
        # Unpacking: self.a, self.b = ...
        for element in target.elts:
            yield from _iter_self_attributes(element)
    elif isinstance(target, ast.Starred):
        # Starred unpacking: first, *self.rest = ...
        yield from _iter_self_attributes(target.value)
    elif (isinstance(target, ast.Attribute)
          and isinstance(target.value, ast.Name)
          and target.value.id == "self"):
        yield target.attr


def find_instance_attributes(class_node: ast.ClassDef) -> list[str]:
    """Return attribute names assigned as self.xxx in __init__.

    Handles plain assignments (``self.x = ...``), annotated assignments
    (``self.x: int = ...``) and unpacking targets (``self.a, self.b = ...``).
    Each name appears once, in order of first assignment. Returns an empty
    list when the class defines no ``__init__``.
    """
    init_method = next(
        (item for item in class_node.body
         if isinstance(item, ast.FunctionDef) and item.name == "__init__"),
        None,
    )
    if init_method is None:
        return []
    # A dict keeps insertion order and drops repeated assignments.
    found: dict[str, None] = {}
    for node in ast.walk(init_method):
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign):
            targets = [node.target]
        else:
            continue
        for target in targets:
            for name in _iter_self_attributes(target):
                found.setdefault(name, None)
    return list(found)
# TODO 6: For each class, print its instance attributes.
# TODO 6 (done): list the instance attributes of every class in the tree.
for node in ast.walk(tree):
    if not isinstance(node, ast.ClassDef):
        continue
    attrs = find_instance_attributes(node)
    print(f"\n class {node.name}:")
    for attr in attrs:
        print(f" self.{attr}")
# ── Part D: Detect Static Methods ──────────────────────────────────────────
print(f"\n{'=' * 60}")
print("Part D Static methods")
print("=" * 60)
# TODO 7: Write a function `is_static_method(func_node)` that returns True
# if the function has a @staticmethod decorator.
#
# Hint: Decorators are in func_node.decorator_list.
# Each decorator is an ast.Name node (for simple decorators).
# Check if any decorator has .id == "staticmethod".
def is_static_method(func_node: ast.FunctionDef) -> bool:
    """Tell whether *func_node* carries a bare ``@staticmethod`` decorator.

    Each decorator is compared by its source text, so only the plain name
    ``staticmethod`` matches.
    """
    for decorator in func_node.decorator_list:
        if ast.unparse(decorator) == "staticmethod":
            return True
    return False
# TODO 8: For each class, list its static methods separately.
# TODO 8 (done): split each class's direct methods into static and regular.
for node in ast.walk(tree):
    if not isinstance(node, ast.ClassDef):
        continue
    statics, regulars = [], []
    for m in node.body:
        if isinstance(m, ast.FunctionDef):
            (statics if is_static_method(m) else regulars).append(m.name)
    print(f"\n class {node.name}:")
    print(f" Regular methods : {regulars}")
    print(f" Static methods : {statics}")
# ── Expected Output (abbreviated) ──────────────────────────────────────────
# Part A: Each class with its method list.
# Part B: Full signatures, e.g. "remove_outliers(z_threshold: float) -> np.ndarray"
# Part C: DataCleaner -> self.raw_data, self.cleaned
# DescriptiveStats -> self.data
# etc.
# Part D: CurveFitter has static methods: linear_model, quadratic_model, exponential_model

View File

@ -0,0 +1,234 @@
"""
Exercise 3 Build a Method Call Graph Between Classes
======================================================
AISE501 · AST Exercises · Spring Semester 2026
Learning goals
--------------
* Resolve method calls (``self.method()``, ``obj.method()``) to their
owning class using AST analysis.
* Build a call graph that records which methods call which other methods.
* Detect cross-class calls (e.g. ``ReportGenerator.add_descriptive``
receives a ``DescriptiveStats`` object and calls ``desc.full_report()``).
* Output the call graph as an adjacency list.
Tasks
-----
Part A — Find all method calls within each method (TODOs 1-2).
Part B — Resolve self.method() calls within the same class (TODOs 3-4).
Part C — Detect cross-class calls using annotated method parameters (TODOs 5-7).
Part D — Print the full call graph as an adjacency list (TODO 8).
"""
import ast
from pathlib import Path
from collections import defaultdict
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Decode explicitly as UTF-8 (Python's default source encoding) instead of
# relying on the platform's locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")
tree = ast.parse(source_code)
# ── Helper: collect class information ───────────────────────────────────────
# class_info maps class name -> {"node": ClassDef, "methods": {name: FunctionDef}}
# and is the lookup table used by every part below.
class_info: dict[str, dict] = {}
for node in ast.walk(tree):
    if isinstance(node, ast.ClassDef):
        methods = {}
        for item in node.body:
            if isinstance(item, ast.FunctionDef):
                methods[item.name] = item
        class_info[node.name] = {"node": node, "methods": methods}
# ── Part A: Find All Calls Within Each Method ──────────────────────────────
print("=" * 60)
print("Part A Raw calls inside each method")
print("=" * 60)
# TODO 1: Write a function `extract_calls(func_node)` that returns a list
# of call descriptions found inside *func_node*.
#
# For each ast.Call node found (use ast.walk):
# - If func is ast.Attribute (e.g. self.remove_nans(), np.mean()):
# record {"type": "attribute", "object": <name>, "method": <attr>}
# - If func is ast.Name (e.g. print(), len()):
# record {"type": "name", "name": <id>}
#
# To get the object name for Attribute calls:
# - If node.func.value is ast.Name -> use .id (e.g. "self", "np")
# - Otherwise use ast.unparse(node.func.value) as a fallback
def extract_calls(func_node: ast.FunctionDef) -> list[dict]:
    """Describe every call found anywhere inside *func_node*.

    Attribute calls (``obj.method()``) become
    ``{"type": "attribute", "object": ..., "method": ...}``; plain-name calls
    (``func()``) become ``{"type": "name", "name": ...}``. Calls whose callee
    is any other expression are not reported.
    """
    calls = []
    for node in ast.walk(func_node):
        if not isinstance(node, ast.Call):
            continue
        callee = node.func
        if isinstance(callee, ast.Attribute):
            owner = callee.value
            # A simple name ("self", "np") is used directly; anything more
            # complex is turned back into source text.
            obj_name = owner.id if isinstance(owner, ast.Name) else ast.unparse(owner)
            calls.append(
                {"type": "attribute", "object": obj_name, "method": callee.attr}
            )
        elif isinstance(callee, ast.Name):
            calls.append({"type": "name", "name": callee.id})
    return calls
# TODO 2: For each class and method, print the extracted calls.
# TODO 2 (done): dump the raw calls of every method, grouped by class.
for cls_name, info in class_info.items():
    print(f"\n class {cls_name}:")
    for method_name, method_node in info["methods"].items():
        calls = extract_calls(method_node)
        print(f" {method_name}():")
        for c in calls:
            if c["type"] != "attribute":
                print(f" -> {c['name']}()")
            else:
                print(f" -> {c['object']}.{c['method']}()")
# ── Part B: Resolve self.method() Calls ────────────────────────────────────
# A call written as self.xxx() targets another method of the same class.
print(f"\n{'=' * 60}")
print("Part B Internal calls (self.method())")
print("=" * 60)
# TODO 3: Write a function `find_internal_calls(cls_name, method_name)`
# that returns a list of method names called via self.xxx()
# where xxx is also a method of the same class.
#
# Approach:
# 1. Get the method node from class_info[cls_name]["methods"][method_name]
# 2. Call extract_calls() on it
# 3. Filter for calls where type=="attribute" and object=="self"
# 4. Further filter: check that the called method name exists in the
# same class's method list.
def find_internal_calls(cls_name: str, method_name: str) -> list[str]:
    """List the sibling methods that *method_name* invokes through ``self``.

    Looks the method up in the module-level ``class_info`` table and keeps
    only ``self.xxx()`` calls where ``xxx`` is itself a method of *cls_name*.
    Names are returned in the order ``extract_calls`` reports them and may
    repeat.
    """
    siblings = class_info[cls_name]["methods"]
    internal = []
    for call in extract_calls(siblings[method_name]):
        if call["type"] != "attribute":
            continue
        if call["object"] != "self":
            continue
        if call["method"] in siblings:
            internal.append(call["method"])
    return internal
# TODO 4: Print internal call edges for each class.
# TODO 4 (done): one line per internal edge, grouped by class.
for cls_name, info in class_info.items():
    print(f"\n class {cls_name}:")
    for method_name in info["methods"]:
        internal = find_internal_calls(cls_name, method_name)
        # An empty result simply prints nothing for this method.
        for target in internal:
            print(f" {method_name}() -> self.{target}()")
# ── Part C: Detect Cross-Class Calls ──────────────────────────────────────
# Cross-class calls occur when a method receives an object of another class
# and calls a method on it. For example:
# def add_descriptive(self, desc: DescriptiveStats) -> None:
# self.sections.append({"content": desc.full_report()})
# Here, desc.full_report() is a cross-class call.
print(f"\n{'=' * 60}")
print("Part C Cross-class calls")
print("=" * 60)
# TODO 5: Write a function `get_param_types(func_node)` that returns a dict
# mapping parameter names to their annotation type names (as strings).
#
# For each arg in func_node.args.args (skip 'self'):
# - If arg.annotation is an ast.Name, use .id
# - If arg.annotation is an ast.Attribute, use ast.unparse()
# - Otherwise skip it
#
# Example: for add_descriptive(self, desc: DescriptiveStats)
# return {"desc": "DescriptiveStats"}
def get_param_types(func_node: ast.FunctionDef) -> dict[str, str]:
    """Map parameter name -> annotation type name.

    Only plain positional parameters (``args.args``) are considered and
    ``self`` is skipped. A parameter is included when its annotation is a
    simple name (``DescriptiveStats``) or a dotted name (``np.ndarray``);
    unannotated parameters and any other annotation form (subscripts,
    string annotations, ...) are left out.
    """
    param_types: dict[str, str] = {}
    for arg in func_node.args.args:
        if arg.arg == "self":
            continue
        annotation = arg.annotation
        if isinstance(annotation, ast.Name):
            param_types[arg.arg] = annotation.id
        elif isinstance(annotation, ast.Attribute):
            param_types[arg.arg] = ast.unparse(annotation)
    return param_types
# TODO 6: Write a function `find_cross_class_calls(cls_name, method_name)`
# that returns a list of tuples: (target_class, target_method).
#
# Approach:
# 1. Get parameter types via get_param_types().
# 2. Get all calls via extract_calls().
# 3. For each attribute call where the object name matches a parameter
# whose type is a known class (in class_info), record the edge.
def find_cross_class_calls(cls_name: str, method_name: str) -> list[tuple[str, str]]:
    """Return [(target_class, target_method), ...] for cross-class calls.

    A cross-class call is an attribute call ``param.method()`` where
    ``param`` is a parameter of *method_name* whose annotation names a class
    present in the module-level ``class_info`` table. Edges are returned in
    the order ``extract_calls`` reports them and may repeat.
    """
    method_node = class_info[cls_name]["methods"][method_name]
    param_types = get_param_types(method_node)
    edges: list[tuple[str, str]] = []
    for call in extract_calls(method_node):
        if call["type"] != "attribute":
            continue
        # None (not a typed parameter) is never a key of class_info.
        target_class = param_types.get(call["object"])
        if target_class in class_info:
            edges.append((target_class, call["method"]))
    return edges
# TODO 7: Print cross-class call edges.
# for cls_name, info in class_info.items():
# for method_name in info["methods"]:
# cross = find_cross_class_calls(cls_name, method_name)
# for target_cls, target_method in cross:
# print(f" {cls_name}.{method_name}() -> {target_cls}.{target_method}()")
# ── Part D: Full Call Graph as Adjacency List ──────────────────────────────
# Section banner for Part D.
print(f"\n{'=' * 60}")
print("Part D Full call graph (adjacency list)")
print("=" * 60)
# TODO 8: Combine internal and cross-class calls into a single adjacency
# list (dict mapping "Class.method" -> list of "Class.method").
# Include module-level functions calling class methods too.
#
# Also analyse run_analysis_pipeline() which instantiates classes
# and calls their methods.
# call_graph = defaultdict(list)
#
# # Add internal calls
# for cls_name, info in class_info.items():
# for method_name in info["methods"]:
# source = f"{cls_name}.{method_name}"
# for target in find_internal_calls(cls_name, method_name):
# call_graph[source].append(f"{cls_name}.{target}")
# for target_cls, target_method in find_cross_class_calls(cls_name, method_name):
# call_graph[source].append(f"{target_cls}.{target_method}")
#
# # Print the graph
# for source, targets in sorted(call_graph.items()):
# print(f" {source}")
# for t in targets:
# print(f" -> {t}")
# ── Expected Output (key edges) ────────────────────────────────────────────
# DataCleaner.remove_outliers -> DataCleaner.remove_nans (internal)
# ReportGenerator.add_descriptive -> DescriptiveStats.full_report (cross-class)
# ReportGenerator.add_hypothesis -> HypothesisTester.interpret_result (cross-class)
# ReportGenerator.add_curve_fit -> CurveFitter.r_squared (cross-class)

View File

@ -0,0 +1,285 @@
"""
Exercise 4 Full Dependency Graph with Visualisation
=====================================================
AISE501 · AST Exercises · Spring Semester 2026
Learning goals
--------------
* Combine all previous analyses into a comprehensive dependency graph.
* Track import dependencies (which modules does each class use?).
* Analyse the ``run_analysis_pipeline()`` function to discover how
classes are instantiated and wired together.
* Export the graph in DOT format and render it with Graphviz.
* (Optional) Use ``networkx`` and ``matplotlib`` for interactive display.
Tasks
-----
Part A Map each class to its external library calls (TODOs 1-2).
Part B Analyse run_analysis_pipeline for data flow (TODOs 3-5).
Part C Export to DOT format (TODOs 6-7).
Part D (Optional) Render with networkx + matplotlib (TODOs 8-9).
"""
import ast
from pathlib import Path
from collections import defaultdict
SOURCE_FILE = Path(__file__).parent / "sample_stats.py"
# Decode explicitly as UTF-8 (Python's default source encoding) instead of
# relying on the platform's locale encoding.
source_code = SOURCE_FILE.read_text(encoding="utf-8")
tree = ast.parse(source_code)
# ── Reusable helpers from previous exercises ───────────────────────────────
def extract_calls(func_node: ast.FunctionDef) -> list[dict]:
    """Describe every call found anywhere inside *func_node*.

    ``obj.method()`` yields ``{"type": "attribute", "object": ..., "method": ...}``
    and ``func()`` yields ``{"type": "name", "name": ...}``. Calls whose
    callee is any other kind of expression are omitted.
    """

    def describe(call: ast.Call):
        """Return the description dict for *call*, or None if unsupported."""
        target = call.func
        if isinstance(target, ast.Name):
            return {"type": "name", "name": target.id}
        if isinstance(target, ast.Attribute):
            owner = target.value
            # Simple names are used as-is; other owners fall back to source text.
            obj_name = owner.id if isinstance(owner, ast.Name) else ast.unparse(owner)
            return {"type": "attribute", "object": obj_name, "method": target.attr}
        return None

    described = (
        describe(node) for node in ast.walk(func_node) if isinstance(node, ast.Call)
    )
    return [entry for entry in described if entry is not None]
# Collect class info
# class name -> {"node": ClassDef, "methods": {method name: FunctionDef}}
class_info: dict[str, dict] = {}
for node in ast.walk(tree):
    if not isinstance(node, ast.ClassDef):
        continue
    methods = {}
    for item in node.body:
        if isinstance(item, ast.FunctionDef):
            methods[item.name] = item
    class_info[node.name] = {"node": node, "methods": methods}
# Collect module-level functions (top-level statements only)
module_functions: dict[str, ast.FunctionDef] = {}
for node in tree.body:
    if not isinstance(node, ast.FunctionDef):
        continue
    module_functions[node.name] = node
# ── Part A: External Library Dependencies Per Class ─────────────────────────
print("=" * 60)
print("Part A External library calls per class")
print("=" * 60)
# First, collect all imports to know which names are external modules.
# TODO 1: Walk tree.body and collect all imported module names.
#
# For ast.Import: names like "os", "csv", "json"
# For ast.ImportFrom: the module, e.g. "numpy", "scipy.stats"
#
# Also collect aliases: "import numpy as np" means "np" -> "numpy"
#
# Store in a dict: alias_to_module = {"np": "numpy", "stats": "scipy.stats", ...}
alias_to_module: dict[str, str] = {}
# TODO: iterate over tree.body and fill alias_to_module
# for node in tree.body:
# if isinstance(node, ast.Import):
# for alias in node.names:
# key = alias.asname if alias.asname else alias.name
# alias_to_module[key] = alias.name
# elif isinstance(node, ast.ImportFrom):
# for alias in node.names:
# key = alias.asname if alias.asname else alias.name
# alias_to_module[key] = node.module or ""
# TODO 2: For each class, find which external modules its methods call.
#
# For each method, look at extract_calls() results:
# - For attribute calls: check if the object name is in alias_to_module
# - Record the mapping: class_name -> set of module names
#
# Example: DescriptiveStats calls np.mean, np.var, stats.skew
# -> DescriptiveStats uses {"numpy", "scipy.stats"}
def get_external_deps(cls_name: str) -> set[str]:
    """Return the set of external module names used by *cls_name*.

    Every method recorded for the class in ``class_info`` is scanned with
    ``extract_calls``. A call counts as an external dependency when the name
    it is made on (``np`` in ``np.mean(...)``) or the bare function name
    (``curve_fit`` in ``curve_fit(...)``) is a key of ``alias_to_module``.

    Args:
        cls_name: Name of a class; unknown names yield an empty set.

    Returns:
        The module names (values of ``alias_to_module``) the class calls into.
    """
    deps: set[str] = set()
    info = class_info.get(cls_name)
    if info is None:
        return deps

    for method_node in info["methods"].values():
        # NOTE(review): assumes extract_calls() accepts a function node and
        # returns dicts shaped like {"type": "attribute", "object": ..., ...}
        # or {"type": "name", "name": ...} -- confirm against its definition.
        for call in extract_calls(method_node):
            if call["type"] == "attribute":
                receiver = call["object"]
            else:
                receiver = call["name"]
            if receiver in alias_to_module:
                deps.add(alias_to_module[receiver])
    return deps
# for cls_name in class_info:
# deps = get_external_deps(cls_name)
# print(f"\n {cls_name}: {sorted(deps) if deps else '(none)'}")
# ── Part B: Analyse run_analysis_pipeline for Data Flow ────────────────────
# This function is the "glue" that creates objects and passes them between
# classes. We want to discover:
# - Which classes are instantiated
# - Which methods are called on those instances
# - How data flows: output of one call becomes input of another
print("\n" + "=" * 60)
print("Part B Data flow in run_analysis_pipeline")
print("=" * 60)
# TODO 3: Find the run_analysis_pipeline function node.
pipeline_func = None
# TODO: find it in tree.body
# TODO 4: Walk the pipeline function and find all variable assignments.
# For each ast.Assign where the right-hand side is a Call:
# - Record what variable name receives the result
# - Record what class/function was called
#
# For example: cleaner = DataCleaner(raw_data)
# -> variable "cleaner" is assigned an instance of "DataCleaner"
#
# Store as: var_types = {"cleaner": "DataCleaner", "desc": "DescriptiveStats", ...}
var_types: dict[str, str] = {}
# TODO: implement by walking pipeline_func
# TODO 5: Now trace method calls on those variables.
# For each attribute call (e.g. cleaner.remove_nans()):
# - Look up the variable in var_types to find its class
# - Record the edge: "run_analysis_pipeline" -> "DataCleaner.remove_nans"
#
# Build a list of edges: [(source, target), ...]
pipeline_edges: list[tuple[str, str]] = []
# TODO: implement
# print("\n Data flow edges:")
# for source, target in pipeline_edges:
# print(f" {source} -> {target}")
# ── Part C: Export to DOT Format ────────────────────────────────────────────
# DOT is the graph description language used by Graphviz.
print("\n" + "=" * 60)
print("Part C Export to DOT format")
print("=" * 60)
# TODO 6: Collect ALL edges into a single list:
# - Internal calls (self.method -> self.other_method within a class)
# - Cross-class calls (from Exercise 3 logic)
# - Pipeline edges (from Part B)
# - External dependency edges (Class -> module)
#
# Use the format: (source_label, target_label, edge_type)
# where edge_type is one of: "internal", "cross_class", "pipeline", "external"
all_edges: list[tuple[str, str, str]] = []
# TODO: collect all edges
# TODO 7: Generate a DOT string and write it to "dependency_graph.dot".
#
# DOT format example:
# digraph G {
# rankdir=LR;
# node [shape=box, style=filled, fillcolor=lightblue];
# "DataCleaner.remove_nans" -> "DataCleaner.remove_outliers";
# "ReportGenerator.add_descriptive" -> "DescriptiveStats.full_report";
# }
#
# Use different colors for different edge types:
# internal -> black
# cross_class -> blue
# pipeline -> red
# external -> gray
def generate_dot(edges: list[tuple[str, str, str]]) -> str:
    """Return a DOT-format string for the dependency graph.

    Args:
        edges: ``(source_label, target_label, edge_type)`` triples, where
            ``edge_type`` is "internal", "cross_class", "pipeline" or
            "external". Unknown edge types are drawn in black.

    Returns:
        A complete ``digraph`` description. Each distinct edge appears once,
        in first-seen order, coloured by its edge type.
    """
    edge_colors = {
        "internal": "black",
        "cross_class": "blue",
        "pipeline": "red",
        "external": "gray",
    }

    def quote(label: str) -> str:
        # Backslashes and double quotes would otherwise end the DOT string early.
        escaped = str(label).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    lines = [
        "digraph G {",
        "    rankdir=LR;",
        "    node [shape=box, style=filled, fillcolor=lightblue];",
    ]
    # dict.fromkeys drops repeated edges while keeping their original order.
    for source, target, edge_type in dict.fromkeys(map(tuple, edges)):
        color = edge_colors.get(edge_type, "black")
        lines.append(f"    {quote(source)} -> {quote(target)} [color={color}];")
    lines.append("}")
    return "\n".join(lines) + "\n"
# dot_string = generate_dot(all_edges)
# dot_file = Path(__file__).parent / "dependency_graph.dot"
# dot_file.write_text(dot_string)
# print(f"\n Written to {dot_file}")
# print(f" Render with: dot -Tpng dependency_graph.dot -o dependency_graph.png")
# ── Part D: (Optional) Render with networkx + matplotlib ───────────────────
print("\n" + "=" * 60)
print("Part D (Optional) Visualise with networkx")
print("=" * 60)
# TODO 8: Install networkx and matplotlib if not already available:
# pip install networkx matplotlib
#
# TODO 9: Build a networkx DiGraph from all_edges and render it.
#
# import networkx as nx
# import matplotlib.pyplot as plt
#
# G = nx.DiGraph()
#
# # Color map for edge types
# edge_colors = {
# "internal": "black",
# "cross_class": "blue",
# "pipeline": "red",
# "external": "gray",
# }
#
# # Add edges with colors
# for source, target, etype in all_edges:
# G.add_edge(source, target, color=edge_colors.get(etype, "black"))
#
# # Node colors: classes in light blue, functions in light green, modules in light gray
# node_colors = []
# for n in G.nodes():
# if "." in n and n.split(".")[0] in class_info:
# node_colors.append("lightblue")
# elif n in module_functions:
# node_colors.append("lightgreen")
# else:
# node_colors.append("lightgray")
#
# # Draw
# pos = nx.spring_layout(G, k=2, seed=42)
# colors = [G[u][v]["color"] for u, v in G.edges()]
#
# plt.figure(figsize=(16, 10))
# nx.draw(G, pos,
# with_labels=True,
# node_color=node_colors,
# edge_color=colors,
# node_size=2000,
# font_size=7,
# arrows=True,
# arrowsize=15)
# plt.title("Dependency Graph sample_stats.py")
# plt.tight_layout()
# plt.savefig(Path(__file__).parent / "dependency_graph.png", dpi=150)
# plt.show()
# print(" Saved dependency_graph.png")
print("\n (Uncomment the code above after installing networkx and matplotlib)")
# ── Expected Output ────────────────────────────────────────────────────────
# Part A: Each class maps to numpy, scipy.stats, scipy.optimize, etc.
# Part B: Pipeline shows DataCleaner -> DescriptiveStats -> HypothesisTester
# -> ReportGenerator chain.
# Part C: A .dot file with ~20-30 edges in four colours.
# Part D: A visual graph showing the full architecture of sample_stats.py.

295
AST Files/sample_stats.py Normal file
View File

@ -0,0 +1,295 @@
"""
sample_stats.py - Statistical Analysis Module
==============================================
This module provides classes and functions for loading, cleaning,
analysing, and visualising numerical data. It is used as the target
program for the AST-analysis exercises in AISE501.
Dependencies: numpy, scipy
"""
import os
import csv
import json
from typing import List, Dict, Optional, Tuple
import numpy as np
from scipy import stats
from scipy.optimize import curve_fit
# ── Helper functions ────────────────────────────────────────────────────────
def load_csv(filepath: str) -> List[List[str]]:
    """Return every row of the CSV file at *filepath* as a list of strings."""
    # newline="" lets the csv module handle line endings itself.
    with open(filepath, newline="") as handle:
        return list(csv.reader(handle))
def save_json(data: dict, filepath: str) -> None:
    """Write *data* to *filepath* as JSON indented by two spaces.

    An existing file at *filepath* is overwritten.
    """
    with open(filepath, "w") as out:
        json.dump(obj=data, fp=out, indent=2)
def validate_data(values: List[float]) -> bool:
    """Tell whether every entry of *values* is finite (neither NaN nor +/-inf).

    An empty sequence is considered valid.
    """
    for value in values:
        if not np.isfinite(value):
            return False
    return True
# ── Data cleaning ───────────────────────────────────────────────────────────
class DataCleaner:
    """Handles missing-value imputation and outlier removal.

    The raw input is never modified; each cleaning step stores its result in
    ``cleaned`` and also returns it.
    """

    def __init__(self, raw_data: List[float]):
        self.raw_data = raw_data
        # Result of the most recent cleaning step; None until one has run.
        self.cleaned: Optional[np.ndarray] = None

    def remove_nans(self) -> np.ndarray:
        """Drop NaN entries from the raw data and return the remaining values."""
        arr = np.array(self.raw_data, dtype=float)
        self.cleaned = arr[~np.isnan(arr)]
        return self.cleaned

    def impute_mean(self) -> np.ndarray:
        """Replace NaN entries of the raw data with the mean of the other values."""
        arr = np.array(self.raw_data, dtype=float)
        mean_val = np.nanmean(arr)
        arr[np.isnan(arr)] = mean_val
        self.cleaned = arr
        return self.cleaned

    def remove_outliers(self, z_threshold: float = 3.0) -> np.ndarray:
        """Remove values whose absolute z-score is not below *z_threshold*.

        Works on the current ``cleaned`` data; if no cleaning step has run
        yet, NaNs are removed first. Empty or constant data has no outliers
        and is returned unchanged.
        """
        if self.cleaned is None:
            self.remove_nans()
        # The z-score of constant data is NaN, and ``NaN < threshold`` is
        # False, so without this guard every value would be discarded.
        if self.cleaned.size == 0 or np.all(self.cleaned == self.cleaned[0]):
            return self.cleaned
        z_scores = np.abs(stats.zscore(self.cleaned))
        self.cleaned = self.cleaned[z_scores < z_threshold]
        return self.cleaned

    def get_summary(self) -> Dict[str, float]:
        """Return count, mean, sample std (ddof=1), min and max.

        Uses the cleaned data when available, otherwise the raw data.
        """
        data = self.cleaned if self.cleaned is not None else np.array(self.raw_data)
        return {
            "count": len(data),
            "mean": float(np.mean(data)),
            "std": float(np.std(data, ddof=1)),
            "min": float(np.min(data)),
            "max": float(np.max(data)),
        }
# ── Descriptive statistics ──────────────────────────────────────────────────
class DescriptiveStats:
    """Descriptive statistics over a fixed data set.

    Every statistic is returned as a plain Python ``float``.
    """

    def __init__(self, data: np.ndarray):
        self.data = data

    def mean(self) -> float:
        """Arithmetic mean of the data."""
        value = np.mean(self.data)
        return float(value)

    def median(self) -> float:
        """Median of the data."""
        value = np.median(self.data)
        return float(value)

    def variance(self) -> float:
        """Sample variance (ddof=1)."""
        value = np.var(self.data, ddof=1)
        return float(value)

    def std_dev(self) -> float:
        """Sample standard deviation (ddof=1)."""
        value = np.std(self.data, ddof=1)
        return float(value)

    def skewness(self) -> float:
        """Skewness as computed by ``scipy.stats.skew`` with its defaults."""
        value = stats.skew(self.data)
        return float(value)

    def kurtosis(self) -> float:
        """Kurtosis as computed by ``scipy.stats.kurtosis`` with its defaults."""
        value = stats.kurtosis(self.data)
        return float(value)

    def percentile(self, q: float) -> float:
        """The *q*-th percentile of the data, with *q* passed to ``np.percentile``."""
        value = np.percentile(self.data, q)
        return float(value)

    def full_report(self) -> Dict[str, float]:
        """Collect all statistics of this class into one dictionary.

        Keys, in order: mean, median, variance, std_dev, skewness, kurtosis,
        q25, q75.
        """
        report: Dict[str, float] = {}
        report["mean"] = self.mean()
        report["median"] = self.median()
        report["variance"] = self.variance()
        report["std_dev"] = self.std_dev()
        report["skewness"] = self.skewness()
        report["kurtosis"] = self.kurtosis()
        report["q25"] = self.percentile(25)
        report["q75"] = self.percentile(75)
        return report
# ── Hypothesis testing ──────────────────────────────────────────────────────
class HypothesisTester:
    """Thin wrappers around common ``scipy.stats`` hypothesis tests.

    Each test returns ``(statistic, p_value)`` as plain Python floats.
    """

    def __init__(self, sample_a: np.ndarray, sample_b: Optional[np.ndarray] = None):
        self.sample_a = sample_a
        # Only needed for the two-sample test.
        self.sample_b = sample_b

    def t_test_one_sample(self, pop_mean: float = 0.0) -> Tuple[float, float]:
        """Run a one-sample t-test of ``sample_a`` against *pop_mean*."""
        statistic, p = stats.ttest_1samp(self.sample_a, pop_mean)
        return float(statistic), float(p)

    def t_test_two_sample(self) -> Tuple[float, float]:
        """Run an independent two-sample t-test of ``sample_a`` vs ``sample_b``.

        Raises:
            ValueError: If no second sample was supplied.
        """
        second = self.sample_b
        if second is None:
            raise ValueError("sample_b is required for a two-sample test.")
        statistic, p = stats.ttest_ind(self.sample_a, second)
        return float(statistic), float(p)

    def chi_squared_test(self, observed: np.ndarray,
                         expected: np.ndarray) -> Tuple[float, float]:
        """Run a chi-squared goodness-of-fit test of *observed* vs *expected*."""
        statistic, p = stats.chisquare(observed, f_exp=expected)
        return float(statistic), float(p)

    def normality_test(self) -> Tuple[float, float]:
        """Run the Shapiro-Wilk normality test on ``sample_a``."""
        statistic, p = stats.shapiro(self.sample_a)
        return float(statistic), float(p)

    def interpret_result(self, p_value: float, alpha: float = 0.05) -> str:
        """Describe in plain English whether *p_value* rejects H0 at level *alpha*."""
        if not (p_value < alpha):
            return f"Fail to reject H0 (p={p_value:.4f} >= alpha={alpha})"
        return f"Reject H0 (p={p_value:.4f} < alpha={alpha})"
# ── Curve fitting ───────────────────────────────────────────────────────────
class CurveFitter:
    """Fit parametric models to (x, y) data.

    The model used by the most recent ``fit`` call is remembered, so
    ``predict`` and ``r_squared`` default to it instead of assuming a linear
    model whose parameter count may not match the fitted parameters.
    """

    def __init__(self, x: np.ndarray, y: np.ndarray):
        self.x = x
        self.y = y
        # Both are set by fit(); None until then.
        self.params: Optional[np.ndarray] = None
        self.covariance: Optional[np.ndarray] = None
        # Model passed to the most recent fit() call.
        self._model_func = None

    @staticmethod
    def linear_model(x, a, b):
        """Return ``a * x + b``."""
        return a * x + b

    @staticmethod
    def quadratic_model(x, a, b, c):
        """Return ``a * x**2 + b * x + c``."""
        return a * x**2 + b * x + c

    @staticmethod
    def exponential_model(x, a, b):
        """Return ``a * exp(b * x)``."""
        return a * np.exp(b * x)

    def fit(self, model_func=None) -> Tuple[np.ndarray, np.ndarray]:
        """Fit *model_func* (default: linear) via least squares.

        Returns:
            The fitted parameters and their covariance as given by
            ``scipy.optimize.curve_fit``; both are also stored on the instance.
        """
        if model_func is None:
            model_func = self.linear_model
        self.params, self.covariance = curve_fit(model_func, self.x, self.y)
        self._model_func = model_func
        return self.params, self.covariance

    def predict(self, x_new: np.ndarray, model_func=None) -> np.ndarray:
        """Predict y values for *x_new* using the fitted parameters.

        Args:
            x_new: Points at which to evaluate the model.
            model_func: Model to evaluate; defaults to the model used in the
                last ``fit`` call, or the linear model if none is recorded.

        Raises:
            ValueError: If ``fit`` has not produced parameters yet.
        """
        if self.params is None:
            raise ValueError("Call fit() before predict().")
        if model_func is None:
            # getattr keeps instances working whose params were set manually.
            model_func = getattr(self, "_model_func", None) or self.linear_model
        return model_func(x_new, *self.params)

    def r_squared(self, model_func=None) -> float:
        """Return the coefficient of determination R^2 on the training data.

        *model_func* has the same default as in ``predict``.
        """
        y_pred = self.predict(self.x, model_func)
        ss_res = np.sum((self.y - y_pred) ** 2)
        ss_tot = np.sum((self.y - np.mean(self.y)) ** 2)
        return float(1 - ss_res / ss_tot)
# ── Report generation ──────────────────────────────────────────────────────
class ReportGenerator:
    """Assemble a titled statistical report out of individual sections.

    Each ``add_*`` method appends one dictionary to ``sections``; the order
    of calls is the order of sections in the final report.
    """

    def __init__(self, title: str):
        self.title = title
        self.sections: List[Dict] = []

    def add_descriptive(self, desc: DescriptiveStats) -> None:
        """Add a section holding the full descriptive report of *desc*."""
        section = {
            "type": "descriptive",
            "content": desc.full_report(),
        }
        self.sections.append(section)

    def add_hypothesis(self, tester: HypothesisTester,
                       test_name: str, result: Tuple[float, float]) -> None:
        """Add a section for one hypothesis test.

        Args:
            tester: Used only to word the interpretation of the p-value.
            test_name: Label stored under "test".
            result: ``(statistic, p_value)`` of the test.
        """
        statistic, p_value = result
        section = {
            "type": "hypothesis",
            "test": test_name,
            "statistic": statistic,
            "p_value": p_value,
            "interpretation": tester.interpret_result(p_value),
        }
        self.sections.append(section)

    def add_curve_fit(self, fitter: CurveFitter, model_name: str) -> None:
        """Add a section for a curve fit.

        An unfitted *fitter* yields an empty parameter list and no R^2.
        """
        if fitter.params is not None:
            fitted_params = fitter.params.tolist()
            fit_quality = fitter.r_squared()
        else:
            fitted_params = []
            fit_quality = None
        self.sections.append({
            "type": "curve_fit",
            "model": model_name,
            "params": fitted_params,
            "r_squared": fit_quality,
        })

    def to_dict(self) -> dict:
        """Return the title and all sections as one nested dictionary."""
        return {"title": self.title, "sections": self.sections}

    def save(self, filepath: str) -> None:
        """Write the report to *filepath* as JSON."""
        report = self.to_dict()
        save_json(report, filepath)
# ── Pipeline function ──────────────────────────────────────────────────────
def run_analysis_pipeline(raw_data: List[float],
                          title: str = "Analysis Report") -> dict:
    """Run the whole analysis on *raw_data* and return the report dictionary.

    Stages: drop NaNs and outliers, compute descriptive statistics, run a
    Shapiro-Wilk normality test, and assemble both results into a report
    titled *title*.
    """
    # Cleaning: NaNs first, then z-score outliers.
    cleaner = DataCleaner(raw_data)
    cleaner.remove_nans()
    cleaner.remove_outliers()
    cleaned_values = cleaner.cleaned

    # Descriptive statistics on the cleaned values.
    desc = DescriptiveStats(cleaned_values)
    report_data = desc.full_report()

    # Normality test on the same cleaned values.
    tester = HypothesisTester(cleaned_values)
    normality = tester.normality_test()

    # Report assembly.
    report = ReportGenerator(title)
    report.add_descriptive(desc)
    report.add_hypothesis(tester, "Shapiro-Wilk", normality)
    return report.to_dict()
# ── Main ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Demo run: 200 normally distributed values with two NaNs and one
    # outlier appended.
    np.random.seed(42)
    sample = list(np.random.normal(loc=50, scale=10, size=200))
    sample.extend([float("nan"), float("nan"), 999.0])
    print(json.dumps(run_analysis_pipeline(sample, title="Demo Report"), indent=2))

View File

@ -0,0 +1,589 @@
"""
Exercise 5b -- Build a Basic AI Coding Agent (Guided Version)
==============================================================
AISE501 . Prompting in Coding . Spring Semester 2026
This is a GUIDED version of Exercise 5 with more scaffolding.
It teaches the same concepts but reduces boilerplate so you can
focus on the key insight: how an LLM uses tools.
The key insight
---------------
An LLM cannot run code or read files by itself. But we can give it
"superpowers" through a simple trick:
1. TELL the LLM (via the system prompt) what tools exist.
2. ASK the LLM to respond with JSON saying which tool to call.
3. PARSE the JSON, call the real Python function, and
4. FEED the result back into the conversation as a new message.
This is how ALL AI coding agents work (Claude Code, Cursor, Copilot).
The LLM never actually "runs" code - it just asks us to run it!
What is already provided
------------------------
To let you focus on the interesting parts, the following are PRE-BUILT:
- All 7 tool functions (Part A) read_file, grep_search, etc.
- The tool dispatcher (Part B) maps tool names to functions.
- Helper functions: truncate_result, trim_messages, ask_human.
What you need to build (the interesting parts)
-----------------------------------------------
Part C The SYSTEM PROMPT that teaches the LLM about its tools (TODOs 1-2).
Part D The AGENT LOOP that connects the LLM to the tools (TODOs 3-6).
Part E The INTERACTIVE CHAT interface (TODOs 7-8).
Think of it like wiring a robot:
- Part A+B are the robot's HANDS (already built).
- Part C is the robot's INSTRUCTION MANUAL (you write it).
- Part D is the robot's BRAIN LOOP (you wire it).
- Part E is the ON SWITCH (you connect it).
The conversation flow
---------------------
Here is exactly what happens in one iteration of the agent loop:
messages = [
{"role": "system", "content": "<system prompt>"},
{"role": "user", "content": "Fix the bug in app.py"},
]
        |
        v
  The LLM generates a JSON response:
    {"thought": "I should...",
     "tool": "read_file",
     "arguments": {"path": "app.py"}}
        |
        v
  You parse the JSON and call read_file().
        |
        v
  Append to messages:
    {"role": "assistant", "content": ...}
    {"role": "user", "content": "<tool_result>file contents</tool_result>"}
        |
        v
  Next iteration: the LLM sees the result and picks the next tool.
"""
import ast
import json
import subprocess
import sys
from pathlib import Path
from server_utils import (
chat, chat_json, get_client, print_messages, print_separator,
strip_code_fences,
)
client = get_client()
# ── Agent Configuration ──────────────────────────────────────────────────────
# Directory the file tools below operate on; it sits next to this script.
WORKSPACE = Path(__file__).parent / "workspace"
# Created at import time if it does not exist yet.
WORKSPACE.mkdir(exist_ok=True)
# Upper bound on the number of iterations agent_loop runs for one task.
MAX_ITERATIONS = 50
# Tool results longer than this many characters are shortened by truncate_result.
MAX_RESULT_LENGTH = 8000
# Character budget for the conversation; trim_messages drops older messages above it.
MAX_HISTORY_CHARS = 60000
# ═══════════════════════════════════════════════════════════════════════════════
# PART A -- TOOL FUNCTIONS (pre-built)
# ═══════════════════════════════════════════════════════════════════════════════
#
# These are the tools the agent can use. Each is a normal Python function.
# The LLM will never call these directly — it will OUTPUT JSON saying
# "please call read_file with path='app.py'", and OUR CODE will call it.
def read_file(path: str) -> str:
    """Read a .txt or .py file from the workspace and return its contents.

    Args:
        path: File path relative to the workspace.

    Returns:
        The file contents, or a string starting with "ERROR:" when the path
        leaves the workspace, does not exist, or has another suffix.
    """
    root = WORKSPACE.resolve()
    target = (root / path).resolve()
    # Compare path components, not string prefixes: a sibling directory such
    # as "workspace_evil" shares the string prefix but is outside the workspace.
    if target != root and root not in target.parents:
        return "ERROR: path is outside the workspace."
    if not target.exists():
        return f"ERROR: file '{path}' not found."
    if target.suffix not in (".py", ".txt"):
        return f"ERROR: can only read .py and .txt files, got '{target.suffix}'."
    return target.read_text()
def grep_search(pattern: str, file_glob: str = "*.py") -> str:
    """Search for a pattern in workspace files matching the glob.

    The match is a plain substring test, not a regular expression. Only
    .py and .txt files that really live inside the workspace are searched.

    Args:
        pattern: Text to look for in each line.
        file_glob: Glob, relative to the workspace, selecting the files.

    Returns:
        One "<relative path>:<line number>: <line>" entry per matching line,
        joined by newlines, or a "No matches ..." message.
    """
    root = WORKSPACE.resolve()
    matches = []
    for filepath in sorted(WORKSPACE.glob(file_glob)):
        if filepath.suffix not in (".py", ".txt"):
            continue
        # A glob such as "../*.py" (or a symlink) can point outside the
        # workspace; such files must not be readable through this tool.
        if root not in filepath.resolve().parents:
            continue
        try:
            lines = filepath.read_text().splitlines()
        except (OSError, UnicodeDecodeError):
            # Unreadable or non-text file: skip it, keep searching the rest.
            continue
        rel = filepath.relative_to(WORKSPACE)
        for i, line in enumerate(lines, 1):
            if pattern in line:
                matches.append(f"{rel}:{i}: {line}")
    if not matches:
        return f"No matches for '{pattern}' in {file_glob}."
    return "\n".join(matches)
def list_files(file_glob: str = "*") -> str:
    """List files in the workspace matching the glob pattern.

    Args:
        file_glob: Glob, relative to the workspace.

    Returns:
        The matching file paths relative to the workspace, one per line in
        sorted order, or a "No files matching ..." message. Directories and
        files located outside the workspace are never listed.
    """
    root = WORKSPACE.resolve()
    found = [
        candidate.relative_to(WORKSPACE)
        for candidate in sorted(WORKSPACE.glob(file_glob))
        # The containment test rejects matches reached through ".." or symlinks.
        if candidate.is_file() and root in candidate.resolve().parents
    ]
    if not found:
        return f"No files matching '{file_glob}' in workspace."
    return "\n".join(str(f) for f in found)
def write_file(path: str, content: str) -> str:
    """Write content to a .py or .txt file in the workspace.

    Missing parent directories are created and an existing file is
    overwritten.

    Args:
        path: Destination path relative to the workspace.
        content: Text to write.

    Returns:
        An "OK: ..." confirmation, or a string starting with "ERROR:" when
        the path leaves the workspace or has another suffix.
    """
    root = WORKSPACE.resolve()
    target = (root / path).resolve()
    # Component-wise containment check: a string-prefix test would accept
    # sibling directories such as "workspace_evil".
    if target != root and root not in target.parents:
        return "ERROR: path is outside the workspace."
    if target.suffix not in (".py", ".txt"):
        return f"ERROR: can only write .py and .txt files, got '{target.suffix}'."
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content)
    return f"OK: wrote {len(content)} chars to {path}."
def run_python(path: str) -> str:
    """Execute a Python file in the workspace and return stdout + stderr.

    The script runs with the current interpreter, the workspace as working
    directory, and a 30 second time limit.

    Args:
        path: Script path relative to the workspace.

    Returns:
        The captured output followed by the exit code, or a string starting
        with "ERROR:" when the path leaves the workspace, the file is
        missing, or the time limit is exceeded.
    """
    root = WORKSPACE.resolve()
    target = (root / path).resolve()
    # Component-wise containment check: a string-prefix test would accept
    # sibling directories such as "workspace_evil".
    if target != root and root not in target.parents:
        return "ERROR: path is outside the workspace."
    if not target.exists():
        return f"ERROR: file '{path}' not found."
    try:
        result = subprocess.run(
            [sys.executable, str(target)],
            capture_output=True, text=True, timeout=30,
            cwd=str(WORKSPACE),
        )
    except subprocess.TimeoutExpired:
        # Report the timeout like every other failure of this tool.
        return "ERROR: execution timed out after 30 seconds."
    output = ""
    if result.stdout:
        output += f"STDOUT:\n{result.stdout}"
    if result.stderr:
        output += f"STDERR:\n{result.stderr}"
    output += f"\nExit code: {result.returncode}"
    return output.strip()
def validate_python(path: str) -> str:
    """Check whether a Python file has valid syntax using ast.parse.

    The file is only parsed, never executed.

    Args:
        path: File path relative to the workspace.

    Returns:
        "OK: syntax is valid.", a "SYNTAX ERROR: ..." description, or a
        string starting with "ERROR:" when the path leaves the workspace or
        the file is missing.
    """
    root = WORKSPACE.resolve()
    target = (root / path).resolve()
    # Component-wise containment check: a string-prefix test would accept
    # sibling directories such as "workspace_evil".
    if target != root and root not in target.parents:
        return "ERROR: path is outside the workspace."
    if not target.exists():
        return f"ERROR: file '{path}' not found."
    source = target.read_text()
    try:
        ast.parse(source)
    except SyntaxError as e:
        return f"SYNTAX ERROR: {e}"
    return "OK: syntax is valid."
def done(summary: str) -> str:
    """Return the completion marker for a finished task: "DONE: " plus *summary*."""
    return "DONE: {}".format(summary)
# ═══════════════════════════════════════════════════════════════════════════════
# PART B -- TOOL DISPATCHER (pre-built)
# ═══════════════════════════════════════════════════════════════════════════════
#
# This is the bridge between the LLM's JSON output and Python function calls.
#
# When the LLM says: {"tool": "read_file", "arguments": {"path": "app.py"}}
# The dispatcher does: TOOL_FUNCTIONS["read_file"](path="app.py")
#
# The **arguments syntax means "unpack the dict as keyword arguments":
# {"path": "app.py"} → read_file(path="app.py")
# Registry used by dispatch_tool: the tool name the LLM puts in its JSON
# response maps to the Python function that implements the tool.
TOOL_FUNCTIONS = {
    "read_file": read_file,
    "grep_search": grep_search,
    "list_files": list_files,
    "write_file": write_file,
    "run_python": run_python,
    "validate_python": validate_python,
    "done": done,
}
def dispatch_tool(tool_name: str, arguments: dict) -> str:
    """Call the tool registered as *tool_name* with *arguments* as keywords.

    No exception escapes: an unknown tool, a bad argument list, or a failure
    inside the tool is reported as a string starting with "ERROR".

    Example:
        dispatch_tool("read_file", {"path": "app.py"})
        calls read_file(path="app.py") and returns the file contents.
    """
    func = TOOL_FUNCTIONS.get(tool_name)
    if func is None:
        return f"ERROR: unknown tool '{tool_name}'. Available: {list(TOOL_FUNCTIONS.keys())}"
    try:
        outcome = func(**arguments)
    except Exception as e:
        # A TypeError usually means wrong or missing keyword arguments.
        if isinstance(e, TypeError):
            return f"ERROR calling {tool_name}: {e}"
        return f"ERROR in {tool_name}: {type(e).__name__}: {e}"
    return outcome
# ═══════════════════════════════════════════════════════════════════════════════
# PART C -- SYSTEM PROMPT (TODOs 1-2)
# ═══════════════════════════════════════════════════════════════════════════════
#
# The system prompt is the MOST IMPORTANT part of the agent. It is the only
# way the LLM knows what tools it has and how to use them.
#
# Think about it: the LLM is just a text model. It has no built-in ability
# to read files or run code. The system prompt is where we TELL it:
# "You have these tools. When you want to use one, output this JSON format.
# I (the code) will parse your JSON, run the tool, and give you the result."
#
# The LLM then "plays along" — it outputs JSON that LOOKS LIKE a tool call,
# and our agent loop code makes it ACTUALLY happen.
# TODO 1: Complete the TOOL_DESCRIPTIONS string below.
# This text will be embedded in the system prompt inside a <tools> section.
# The LLM needs to know:
# - The name of each tool (must match the keys in TOOL_FUNCTIONS above!)
# - What arguments each tool takes
# - What each tool does
#
# Four tools are already described for you as examples.
# Add the missing three: write_file, run_python, validate_python.
#
# Follow the same format:
# - tool_name({"param": "<description>"}): What the tool does.
# One line per tool, embedded in the <tools> section of the system prompt.
# The names must match the keys of TOOL_FUNCTIONS, and every entry must show
# a complete JSON object, because the LLM copies this argument shape.
TOOL_DESCRIPTIONS = """\
- read_file({"path": "<relative path>"}): Read a .py or .txt file from the workspace.
- grep_search({"pattern": "<text>", "file_glob": "<glob, default='*.py'>"}): Search for a pattern in files.
- list_files({"file_glob": "<glob, default='*'>"}): List files matching the pattern.
- done({"summary": "<what you accomplished>"}): Signal that you are finished.
- write_file({"path": "<relative path>", "content": "<new content>"}): Write new content to file
- run_python({"path": "<relative path>"}): Run code in python file
- validate_python({"path": "<relative path>"}): Validate code in python file
"""
# TODO 2: Complete the system prompt.
# The structure is provided — fill in the <workflow> and <rules> sections.
#
# For <workflow>, describe these steps:
# 1. PLAN: Think about what steps are needed. List them in "thought".
# 2. ACT: Choose ONE tool to call.
# 3. OBSERVE: Analyse the tool's output carefully.
# 4. REPLAN: If the result was unexpected, revise your plan.
# 5. REPEAT: Go back to ACT if more work is needed.
# 6. DONE: Call the "done" tool when the task is complete.
#
# For <rules>, include at least:
# - Always plan before acting.
# - Call exactly ONE tool per response.
# - After writing code, always validate and run it.
# - If an error occurs, try to fix it (up to 3 retries).
# - Stay within the workspace directory.
# - When finished, call the "done" tool.
#
# IMPORTANT: The JSON example uses {{ and }} because this is an f-string.
# In an f-string, {{ produces a literal { in the output.
# So {{"thought": "..."}} becomes {"thought": "..."} when printed.
# In an f-string "{{" renders as one literal "{". The JSON examples below
# therefore use doubled braces only; quadrupled braces would show the LLM
# "{{ ... }}", which is not valid JSON for it to imitate.
SYSTEM_PROMPT = f"""\
You are a coding agent that helps users with Python programming tasks.
You work inside a workspace directory and have access to tools.
<tools>
Available tools:
{TOOL_DESCRIPTIONS}
</tools>
<workflow>
1. PLAN: Think about what steps are needed. List them in "thought".
2. ACT: Choose ONE tool to call.
3. OBSERVE: Analyse the tool's output carefully.
4. REPLAN: If the result was unexpected, revise your plan.
5. REPEAT: Go back to ACT if more work is needed.
6. DONE: Call the "done" tool when the task is complete.
</workflow>
<response_format>
You MUST respond with a JSON object every time. The format is:
{{
"thought": "<your reasoning about what to do next>",
"tool": "<tool name from the list above>",
"arguments": {{ <arguments for the tool> }}
}}
Example to read a file:
{{
"thought": "I need to read app.py to understand the code.",
"tool": "read_file",
"arguments": {{"path": "app.py"}}
}}
Example to signal completion:
{{
"thought": "I have fixed all the bugs and verified the code runs.",
"tool": "done",
"arguments": {{"summary": "Fixed 3 bugs in app.py and verified all tests pass."}}
}}
</response_format>
<rules>
- Always plan before acting.
- Call exactly ONE tool per response.
- After writing code, always validate and run it.
- If an error occurs, try to fix it (up to 3 retries).
- Stay within the workspace directory.
- When finished, call the "done" tool.
</rules>
"""
# ═══════════════════════════════════════════════════════════════════════════════
# PART D -- AGENT LOOP (TODOs 3-6)
# ═══════════════════════════════════════════════════════════════════════════════
#
# This is where everything comes together. The agent loop:
#
# 1. Sends messages to the LLM (including the system prompt with tools).
# 2. The LLM responds with JSON like: {"tool": "read_file", "arguments": {"path": "app.py"}}
# 3. We parse that JSON and call the real Python function.
# 4. We put the result back into the conversation as a new message.
# 5. We send the updated conversation to the LLM again.
# 6. The LLM sees the result and decides what to do next.
# 7. Repeat until the LLM calls "done" or we hit the iteration limit.
def truncate_result(result: str) -> str:
    """Shorten *result* when it is longer than MAX_RESULT_LENGTH characters.

    A long result keeps its first and last MAX_RESULT_LENGTH // 2 characters,
    joined by a notice stating the original length. Shorter results are
    returned unchanged.
    """
    total = len(result)
    if total > MAX_RESULT_LENGTH:
        half = MAX_RESULT_LENGTH // 2
        notice = f"\n\n... [TRUNCATED — {total} chars total, showing first and last {half}] ...\n\n"
        return result[:half] + notice + result[-half:]
    return result
def trim_messages(messages: list) -> list:
    """Trim older messages if total character count exceeds MAX_HISTORY_CHARS.

    The first two messages (system prompt and original task) are always
    kept. Messages after them are dropped oldest-first until the total fits
    the budget or none are left, and a reminder of the original task is
    inserted where the dropped messages were.

    Args:
        messages: Chat messages, each a dict with a "content" string.

    Returns:
        *messages* itself when it is within the budget, otherwise a new
        list; the input list is never modified.
    """
    def chars(items: list) -> int:
        return sum(len(m["content"]) for m in items)

    if chars(messages) <= MAX_HISTORY_CHARS:
        return messages

    head = messages[:2]
    tail = messages[2:]
    original_task = messages[1]["content"] if len(messages) > 1 else ""

    # Keep a running total instead of re-summing the whole history and
    # shifting the list on every drop, which made trimming quadratic.
    budget = MAX_HISTORY_CHARS - chars(head)
    tail_chars = chars(tail)
    dropped = 0
    while dropped < len(tail) and tail_chars > budget:
        tail_chars -= len(tail[dropped]["content"])
        dropped += 1

    reminder = {
        "role": "user",
        "content": (
            "<system_note>Earlier conversation history was trimmed. "
            f"REMINDER — your original task was:\n{original_task}\n"
            "Continue from where you left off.</system_note>"
        ),
    }
    return head + [reminder] + tail[dropped:]
def ask_human() -> str:
    """Prompt the user before an action and return the stripped reply.

    An empty reply means "continue". End-of-input or Ctrl+C is reported as
    "stop".
    """
    prompt = "\n [Enter]=continue, or type a comment (stop to abort): "
    try:
        answer = input(prompt)
    except (EOFError, KeyboardInterrupt):
        return "stop"
    return answer.strip()
def agent_loop(user_task: str) -> None:
    """Run the agent loop: plan -> user review -> act -> observe -> repeat.

    Each iteration asks the LLM for a JSON action, shows it to the user for
    approval, executes the chosen tool and feeds the result back into the
    conversation. The loop ends when the LLM calls "done" and the user
    approves, when the user types "stop", or after MAX_ITERATIONS iterations.

    Args:
        user_task: The task description given by the user.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_task},
    ]

    for iteration in range(1, MAX_ITERATIONS + 1):
        print_separator(f"Agent Iteration {iteration}")
        messages = trim_messages(messages)

        # Ask the LLM for its next action.
        raw = chat_json(client, messages, temperature=0.2, max_tokens=4096)
        try:
            action = json.loads(raw)
            thought = action["thought"]
            tool_name = action["tool"]
            arguments = action["arguments"]
        except (json.JSONDecodeError, KeyError, TypeError) as err:
            # Malformed or incomplete response: show the LLM what it sent
            # and ask again instead of crashing the whole session.
            print(f"  Invalid response from the LLM ({type(err).__name__}: {err}); retrying.")
            messages.append({"role": "assistant", "content": raw})
            messages.append({
                "role": "user",
                "content": (
                    "Your last response could not be used. Respond with a single "
                    'valid JSON object containing the keys "thought", "tool" and '
                    '"arguments".'
                ),
            })
            continue

        print(f"  Thought:   {thought}")
        print(f"  Tool:      {tool_name}")
        print(f"  Arguments: {arguments}")
        if tool_name == "done":
            summary = arguments.get("summary", "") if isinstance(arguments, dict) else arguments
            print(f"  Summary:   {summary}")

        # Human-in-the-loop: nothing is executed without the user's consent.
        human_response = ask_human()
        if human_response == "stop":
            print("stop interactive chat")
            # Return rather than break, so the "max iterations" banner below
            # is not printed for a user-requested stop.
            return
        if human_response:
            # The user commented: skip the tool and let the LLM revise its plan.
            messages.append({"role": "assistant", "content": raw})
            messages.append({
                "role": "user",
                "content": f"<human_message>{human_response}</human_message>\n"
                           "Please revise your plan based on this feedback.",
            })
            continue

        if tool_name == "done":
            print_separator("Agent finished")
            return

        # Execute the tool and hand the result back as a user message; the
        # <tool_result> tags mark it as tool output rather than human input.
        result = truncate_result(dispatch_tool(tool_name, arguments))
        messages.append({"role": "assistant", "content": raw})
        messages.append({
            "role": "user",
            "content": f'<tool_result tool="{tool_name}">\n{result}\n</tool_result>',
        })
        print(f"  Result:\n{result}")

    print_separator("Agent stopped (max iterations reached)")
# ═══════════════════════════════════════════════════════════════════════════════
# PART E -- INTERACTIVE CHAT (TODOs 7-8)
# ═══════════════════════════════════════════════════════════════════════════════
# TODO 7: Implement the input loop.
# - Read input with: user_input = input("You> ").strip()
# - Handle EOFError and KeyboardInterrupt (Ctrl+C)
# - Skip empty input
# - Exit on "quit" or "exit"
# - Otherwise call agent_loop(user_input)
def interactive_chat():
    """Read tasks from the user in a loop and hand each one to the agent.

    Prints the workspace location and its top-level files first. Empty input
    is ignored; "quit", "exit", Ctrl+C or end-of-input terminates the
    program with exit status 0.
    """
    print_separator("AI Coding Agent -- Interactive Mode")
    print("Type your task and press Enter. Type 'quit' or 'exit' to stop.")
    print(f"Workspace: {WORKSPACE.resolve()}\n")

    # Overview of what the agent can currently see.
    workspace_files = [entry for entry in sorted(WORKSPACE.glob("*")) if entry.is_file()]
    if not workspace_files:
        print("Workspace is empty.")
    else:
        print("Files in workspace:")
        for entry in workspace_files:
            print(f" {entry.name}")
    print()

    try:
        while True:
            user_input = input("You> ").strip()
            if not user_input:
                continue
            if user_input in ("quit", "exit"):
                sys.exit(0)
            agent_loop(user_input)
    except (KeyboardInterrupt, EOFError):
        sys.exit(0)
# ═══════════════════════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════════════════════
if __name__ == "__main__":
    # Seed the workspace with the sample file the exercise works on. An
    # existing copy is left alone so edits made by the agent survive restarts.
    source = Path(__file__).parent / "analyze_me.py"
    dest = WORKSPACE / "analyze_me.py"
    if source.is_file() and not dest.exists():
        dest.write_bytes(source.read_bytes())
    interactive_chat()

View File

@ -0,0 +1,26 @@
from typing import Union
# Numeric types accepted by divide().
Number = Union[int, float, complex]


def divide(divisor: Number, dividend: Number) -> float:
    """
    Divide the dividend by the divisor with robust error handling.

    Note the argument order: the divisor comes first, so ``divide(2, 10)``
    computes ``10 / 2``.

    Args:
        divisor: The number to divide by
        dividend: The number to be divided

    Returns:
        The result of ``dividend / divisor`` (a complex number if either
        input is complex)

    Raises:
        ZeroDivisionError: If divisor is zero; this is checked first, so it
            takes precedence over a non-numeric dividend
        TypeError: If the two inputs cannot be divided
    """
    if divisor == 0:
        raise ZeroDivisionError("Cannot divide by zero")
    try:
        return dividend / divisor
    except TypeError as err:
        # Chain explicitly so the interpreter's original message stays
        # attached as the cause of the friendlier error.
        raise TypeError(
            "Both divisor and dividend must be numeric types. "
            f"Got: {type(divisor).__name__} and {type(dividend).__name__}"
        ) from err