Update Readme and cleanup

parent 52d087ce90
commit b6789cb5f7

README.md (new file, 50 lines added)

@@ -0,0 +1,50 @@
# Atlas Librarian

A comprehensive content processing and management system for extracting, chunking, and vectorizing information from various sources.

## Overview

Atlas Librarian is a modular system designed to process, organize, and make searchable large amounts of content through web scraping, content extraction, chunking, and vector embeddings.

## Project Structure

```
atlas/
├── librarian/
│   ├── atlas-librarian/     # Main application
│   ├── librarian-core/      # Core functionality and storage
│   └── plugins/
│       ├── librarian-chunker/    # Content chunking
│       ├── librarian-extractor/  # Content extraction with AI
│       ├── librarian-scraper/    # Web scraping and crawling
│       └── librarian-vspace/     # Vector space operations
```

## Components

- **Atlas Librarian**: Main application with API, web app, and recipe management
- **Librarian Core**: Shared utilities, storage, and Supabase integration
- **Chunker Plugin**: Splits content into processable chunks
- **Extractor Plugin**: Extracts and sanitizes content using AI
- **Scraper Plugin**: Crawls and downloads web content
- **VSpace Plugin**: Vector embeddings and similarity search

## Getting Started

1. Clone the repository
2. Install dependencies for each component
3. Configure environment variables
4. Run the main application

## Features

- Web content scraping and crawling
- AI-powered content extraction and sanitization
- Intelligent content chunking
- Vector embeddings for semantic search
- Supabase integration for data storage
- Modular plugin architecture

---

*For detailed documentation, see the individual component directories.*

@@ -3,18 +3,13 @@ import importlib

__all__ = []

# Iterate over all modules in this package
for finder, module_name, is_pkg in pkgutil.iter_modules(__path__):
    # import the sub-module
    module = importlib.import_module(f"{__name__}.{module_name}")

    # decide which names to re-export:
    #  use module.__all__ if it exists, otherwise every non-private attribute
    public_names = getattr(
        module, "__all__", [n for n in dir(module) if not n.startswith("_")]
    )

    # bring each name into the package namespace
    for name in public_names:
        globals()[name] = getattr(module, name)
        __all__.append(name) # type: ignore
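
The effect is that whatever each submodule exports becomes importable from the package itself. A small usage sketch, assuming this hunk belongs to the `atlas_librarian.api` package `__init__` (the file header was not preserved in this view):

```python
# Hypothetical consumer; works because the package __init__ re-exports submodule names.
from atlas_librarian.api import recipes_router, runs_router, worker_router
```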

@@ -1,6 +1,3 @@
# ------------------------------------------------------ #
# Workers have to be imported here to be discovered by the worker loader
# ------------------------------------------------------ #
from librarian_chunker.chunker import Chunker
from librarian_extractor.ai_sanitizer import AISanitizer
from librarian_extractor.extractor import Extractor

@@ -32,12 +32,9 @@ from atlas_librarian.stores.workers import WORKERS
router = APIRouter(tags=["recipes"])


# --------------------------------------------------------------------------- #
# Pydantic models                                                             #
# --------------------------------------------------------------------------- #
class RecipeRequest(BaseModel):
    workers: List[str] = Field(min_length=1)
    payload: dict  # input of the first worker
    payload: dict


class RecipeMetadata(BaseModel):

@@ -47,17 +44,11 @@ class RecipeMetadata(BaseModel):
    flow_run_id: str | None = None


# in-memory “DB”
_RECIPES: dict[str, RecipeMetadata] = {}


# --------------------------------------------------------------------------- #
# routes                                                                      #
# --------------------------------------------------------------------------- #
@router.post("/run", status_code=202, response_model=list[FlowArtifact])
def run_recipe(req: RecipeRequest) -> list[FlowArtifact]:
    # validation of worker chain
    for w in req.workers:
        if w not in WORKERS:
            raise HTTPException(400, f"Unknown worker: {w}")

@@ -71,7 +62,6 @@ def run_recipe(req: RecipeRequest) -> list[FlowArtifact]:
    async def _run_worker(worker: type[Worker], art: FlowArtifact) -> FlowArtifact:
        return await worker.flow()(art)

    # Kick off the first worker
    art: FlowArtifact = anyio.run(_run_worker, start_worker, FlowArtifact(data=payload, run_id=str(uuid.uuid4()), dir=Path(".")))
    artifacts.append(art)
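
A hedged sketch of how a client could call this endpoint once the gateway is up. The host, port, and the resolved value of API_PREFIX are assumptions (they are defined elsewhere); the worker names come from the recipe definitions later in this commit, and the payload shape is whatever the first worker's input_model expects:

```python
# Hypothetical client call; base URL and API prefix are assumptions.
import httpx

resp = httpx.post(
    "http://localhost:8000/api/v1/recipes/run",  # assumed host/port and API_PREFIX
    json={
        "workers": ["crawl", "extract"],         # must be keys of WORKERS
        "payload": {},                           # input of the first worker (worker-specific)
    },
    timeout=None,                                # the recipe runs before the response returns
)
resp.raise_for_status()
artifacts = resp.json()                          # serialized list[FlowArtifact]
```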

@@ -16,9 +16,6 @@ from pydantic import BaseModel
router = APIRouter(tags=["runs"])


# --------------------------------------------------------------------------- #
# response model                                                              #
# --------------------------------------------------------------------------- #
class RunInfo(BaseModel):
    run_id: str
    worker: str

@@ -27,9 +24,6 @@ class RunInfo(BaseModel):
    data: dict | None = None


# --------------------------------------------------------------------------- #
# helper                                                                      #
# --------------------------------------------------------------------------- #
def _open_store(run_id: str) -> WorkerStore:
    try:
        return WorkerStore.open(run_id)

@@ -37,16 +31,11 @@ def _open_store(run_id: str) -> WorkerStore:
        raise HTTPException(status_code=404, detail="Run-id not found") from exc


# --------------------------------------------------------------------------- #
# routes                                                                      #
# --------------------------------------------------------------------------- #
@router.get("/{run_id}", response_model=RunInfo)
def get_run(run_id: str) -> RunInfo:
    """
    Return coarse-grained information about a single flow run.

    For the web-UI we expose only minimal metadata plus the local directory
    where files were written; clients can read further details from disk.
    Returns coarse-grained info for a flow run, including local data directory.
    Web UI uses this for minimal metadata display.
    """
    store = _open_store(run_id)
    meta = store.metadata  # {'worker_name': …, 'state': …, …}

@@ -79,7 +68,6 @@ def get_latest_run(worker_name: str) -> RunInfo | None:
@router.get("/{run_id}/artifact")
def get_artifact(run_id: str) -> str:
    store = _open_store(run_id)
    # Check if the artifact.md file exists
    if not store._run_dir.joinpath("artifact.md").exists():
        raise HTTPException(status_code=404, detail="Artifact not found")
    return store._run_dir.joinpath("artifact.md").read_text()

@@ -22,21 +22,15 @@ from atlas_librarian.stores.workers import WORKERS
router = APIRouter(tags=["workers"])


# --------------------------------------------------------------------------- #
# response schema                                                             #
# --------------------------------------------------------------------------- #

class Order(BaseModel):
    order_id: str
    worker_name: str
    payload: dict

    # Job is accepted and will be moved in the Job Pool
    def accept(self):
        _ORDER_POOL[self.order_id] = self
        return self.order_id

    # Job is completed and will be removed from the Job Pool
    def complete(self):
        del _ORDER_POOL[self.order_id]

@@ -44,9 +38,6 @@ class Order(BaseModel):
_ORDER_POOL: dict[str, Order] = {}


# --------------------------------------------------------------------------- #
# helpers                                                                     #
# --------------------------------------------------------------------------- #
def _get_worker_or_404(name: str):
    cls = WORKERS.get(name)
    if cls is None:

@@ -60,13 +51,8 @@ def _get_artifact_from_payload(
    run_id = payload.get("run_id") or None
    input_data = cls.input_model.model_validate(payload["data"])

    # Making sure the payload is valid
    return FlowArtifact.new(data=input_data, dir=dir, run_id=run_id)

# --------------------------------------------------------------------------- #
# GET Routes                                                                  #
# --------------------------------------------------------------------------- #

class WorkerMeta(BaseModel):
    name: str
    input: str

@@ -89,11 +75,6 @@ def list_orders() -> list[Order]:
    return list(_ORDER_POOL.values())


# --------------------------------------------------------------------------- #
# POST Routes                                                                 #
# --------------------------------------------------------------------------- #

# ---------- Submit and get the result ----------------------------------------
@router.post("/{worker_name}/submit", response_model=FlowArtifact, status_code=202)
def submit_worker(worker_name: str, payload: dict[str, Any]) -> FlowArtifact:
    cls = _get_worker_or_404(worker_name)

@@ -108,7 +89,6 @@ def submit_worker(worker_name: str, payload: dict[str, Any]) -> FlowArtifact:
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Submit on existing run ----------------------------------------------
@router.post(
    "/{worker_name}/submit/{prev_run_id}/chain",
    response_model=FlowArtifact,

@@ -136,7 +116,6 @@ def submit_chain(worker_name: str, prev_run_id: str) -> FlowArtifact:
        raise HTTPException(status_code=500, detail=str(e))


# Submit and chain, with the latest output of a worker
@router.post(
    "/{worker_name}/submit/{prev_worker_name}/chain/latest",
    response_model=FlowArtifact | None,

@@ -162,14 +141,12 @@ def submit_chain_latest(worker_name: str, prev_worker_name: str) -> FlowArtifact
    return cls.submit(art)


# ---------- Place an Order and get a receipt ----------------------------------------------------
@router.post("/{worker_name}/order", response_model=Order, status_code=202)
def place_order(worker_name: str, payload: dict[str, Any]) -> Order:
    cls = _get_worker_or_404(worker_name)

    try:
        art = _get_artifact_from_payload(payload, cls)
        # order_id = str(uuid.uuid4())
        order_id = "731ce6ef-ccdc-44bd-b152-da126f104db1"
        order = Order(order_id=order_id, worker_name=worker_name, payload=art.model_dump())
        order.accept()

@@ -1,4 +1,3 @@
# atlas_librarian/app.py
import logging

from fastapi import FastAPI

@@ -8,7 +7,6 @@ from librarian_core.utils.secrets_loader import load_env
from atlas_librarian.api import recipes_router, runs_router, worker_router
from atlas_librarian.stores import discover_workers

# Application description for OpenAPI docs
APP_DESCRIPTION = """
Atlas Librarian API Gateway 🚀

@@ -32,11 +30,10 @@ def create_app() -> FastAPI:

    app = FastAPI(
        title="Atlas Librarian API",
        version="0.1.0",  # Use semantic versioning
        version="0.1.0",
        description=APP_DESCRIPTION,
    )

    # Configure CORS
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],

@@ -45,23 +42,18 @@ def create_app() -> FastAPI:
        allow_headers=["*"],
    )

    # Include all API routers
    app.include_router(worker_router, prefix=f"{API_PREFIX}/worker")
    app.include_router(runs_router, prefix=f"{API_PREFIX}/runs")
    app.include_router(recipes_router, prefix=f"{API_PREFIX}/recipes")

    return app

# Create the app instance
app = create_app()


@app.get("/", tags=["Root"], summary="API Root/Health Check")
async def read_root():
    """
    Provides a basic health check endpoint.
    Returns a welcome message indicating the API is running.
    """
    """Basic health check endpoint."""
    return {"message": "Welcome to Atlas Librarian API"}
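
For local development the gateway can be served with uvicorn. Only the `atlas_librarian.app:app` import path comes from this file; the host, port, and reload flag below are assumptions:

```python
# Hypothetical dev entry point; host and port are assumptions.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("atlas_librarian.app:app", host="127.0.0.1", port=8000, reload=True)
```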

@@ -71,14 +71,12 @@ recipes = [
            # "download",
            "extract",
        ],
        # Default-Parameters
        "parameters": {
            "crawl": example_study_program,
            "download": example_moodle_index,
            "extract": example_downloaded_courses,
        },
    },
    # All steps in one go
    {
        "name": "quick-all",
        "steps": ["crawl", "download", "extract"],

@@ -1,4 +1,3 @@
# atlas_librarian/stores/worker_store.py
"""
Auto-discovers every third-party Worker package that exposes an entry-point

@@ -26,32 +25,24 @@ except ImportError:  # Py < 3.10 → fall back to back-port

from librarian_core.workers.base import Worker

# --------------------------------------------------------------------------------------

WORKERS: Dict[str, Type[Worker]] = {}  # key = Worker.worker_name
WORKERS: Dict[str, Type[Worker]] = {}


# --------------------------------------------------------------------------------------
def _register_worker_class(obj: object) -> None:
    """
    Inspect *obj* and register it if it looks like a Worker subclass
    produced by the metaclass in *librarian_core*.
    """
    """Registers valid Worker subclasses."""
    if (
        inspect.isclass(obj)
        and issubclass(obj, Worker)
        and obj is not Worker  # not the abstract base
        and obj is not Worker
    ):
        WORKERS[obj.worker_name] = obj  # type: ignore[arg-type]


# --------------------------------------------------------------------------------------
def _import_ep(ep: EntryPoint):
    """Load the object referenced by an entry-point."""
    return ep.load()


# --------------------------------------------------------------------------------------
def discover_workers(group: str = "librarian.worker") -> None:
    """
    Discover all entry-points of *group* and populate ``WORKERS``.

@@ -70,13 +61,11 @@ def discover_workers(group: str = "librarian.worker") -> None:
            print(f"[Worker-Loader] Failed to load entry-point {ep!r}: {exc}")
            continue

        # If a module was loaded, inspect its attributes; else try registering directly
        if isinstance(loaded, ModuleType):
            for attr in loaded.__dict__.values():
                _register_worker_class(attr)
        else:
            _register_worker_class(loaded)

    # Register any Worker subclasses imported directly (e.g., loaded via atlas_librarian/api/__init__)
    for cls in Worker.__subclasses__():
        _register_worker_class(cls)
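
Plugins therefore only need to advertise themselves in the `librarian.worker` entry-point group. A quick way to see what the loader would pick up (a sketch; it needs Python 3.10+ or the importlib_metadata back-port mentioned above):

```python
# List the entry points the worker loader would consider.
from importlib.metadata import entry_points

for ep in entry_points(group="librarian.worker"):
    print(ep.name, "->", ep.value)
```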

@@ -27,9 +27,6 @@ from librarian_core.utils import path_utils
class WorkerStore:
    """Never exposed to worker code – all access is via helper methods."""

    # ------------------------------------------------------------------ #
    # constructors                                                       #
    # ------------------------------------------------------------------ #
    @classmethod
    def new(cls, *, worker_name: str, flow_id: str) -> "WorkerStore":
        run_dir = path_utils.get_run_dir(worker_name, flow_id, create=True)

@@ -53,9 +50,6 @@ class WorkerStore:
                return cls(candidate, meta["worker_name"], run_id)
        raise FileNotFoundError(run_id)

    # ------------------------------------------------------------------ #
    # life-cycle                                                         #
    # ------------------------------------------------------------------ #
    def __init__(self, run_dir: Path, worker_name: str, flow_id: str):
        self._run_dir = run_dir
        self._worker_name = worker_name

@@ -71,9 +65,6 @@ class WorkerStore:
        self._entry_dir.mkdir(parents=True, exist_ok=True)
        self._exit_dir.mkdir(parents=True, exist_ok=True)

    # ------------------------------------------------------------------ #
    # entry / exit handling                                              #
    # ------------------------------------------------------------------ #
    @property
    def entry_dir(self) -> Path:
        return self._entry_dir

@@ -115,9 +106,6 @@ class WorkerStore:
                shutil.copy2(src_path, dst)
        return dst

    # ------------------------------------------------------------------ #
    # result persistence                                                 #
    # ------------------------------------------------------------------ #
    def save_model(
        self,
        model: BaseModel,

@@ -145,9 +133,6 @@ class WorkerStore:
    def cleanup(self) -> None:
        shutil.rmtree(self._work_dir, ignore_errors=True)

    # ------------------------------------------------------------------ #
    # public helpers (API needs these)                                   #
    # ------------------------------------------------------------------ #
    @property
    def data_dir(self) -> Path:
        return self._run_dir / "data"

@@ -201,16 +186,12 @@ class WorkerStore:

        latest_run_dir = sorted_runs[-1][1]

        # Load the model
        return {  # That is a FlowArtifact
        return {
            "run_id": latest_run_dir.name,
            "dir": latest_run_dir / "data",
            "data": WorkerStore.open(latest_run_dir.name).load_model(as_dict=True),  # type: ignore
        }

    # ------------------------------------------------------------------ #
    # internals                                                          #
    # ------------------------------------------------------------------ #
    def _write_meta(self, *, state: str) -> None:
        meta = {
            "worker_name": self._worker_name,

@@ -233,9 +214,6 @@ class WorkerStore:
        except Exception:
            return None

    # ------------------------------------------------------------------ #
    # clean-up                                                           #
    # ------------------------------------------------------------------ #
    def __del__(self) -> None:
        try:
            shutil.rmtree(self._work_dir, ignore_errors=True)

@@ -45,7 +45,6 @@ class SupabaseGateway:
        self.client = get_client()
        self.schema = _cfg.db_schema if _cfg else "library"

    # ---------- internal ----------
    def _rpc(self, fn: str, payload: Dict[str, Any] | None = None):
        resp = (
            self.client.schema(self.schema)

@@ -15,16 +15,10 @@ from prefect.artifacts import acreate_markdown_artifact

from librarian_core.storage.worker_store import WorkerStore

# --------------------------------------------------------------------------- #
# type parameters                                                             #
# --------------------------------------------------------------------------- #
InT = TypeVar("InT", bound=BaseModel)
OutT = TypeVar("OutT", bound=BaseModel)


# --------------------------------------------------------------------------- #
# envelope returned by every worker flow                                      #
# --------------------------------------------------------------------------- #
class FlowArtifact(BaseModel, Generic[OutT]):
    run_id: str | None = None
    dir: Path | None = None

@@ -34,23 +28,17 @@ class FlowArtifact(BaseModel, Generic[OutT]):
    def new(cls, run_id: str | None = None, dir: Path | None = None, data: OutT | None = None) -> FlowArtifact:
        if not data:
            raise ValueError("data is required")
        # Intermediate Worker
        if run_id and dir:
            return FlowArtifact(run_id=run_id, dir=dir, data=data)

        # Initial Worker
        else:
            return FlowArtifact(data=data)

# --------------------------------------------------------------------------- #
# metaclass: adds a Prefect flow + envelope to each Worker                    #
# --------------------------------------------------------------------------- #
class _WorkerMeta(type):
    def __new__(mcls, name, bases, ns, **kw):
        cls = super().__new__(mcls, name, bases, dict(ns))

        if name == "Worker" and cls.__module__ == __name__:
            return cls  # abstract base
            return cls

        if not (hasattr(cls, "input_model") and hasattr(cls, "output_model")):
            raise TypeError(f"{name}: declare 'input_model' / 'output_model'.")

@@ -62,10 +50,9 @@ class _WorkerMeta(type):
        cls._prefect_flow = mcls._build_prefect_flow(cls) # type: ignore
        return cls

    # --------------------------------------------------------------------- #
    @staticmethod
    def _build_prefect_flow(cls_ref):
        """Create the Prefect flow and return it."""
        """Builds the Prefect flow for the worker."""
        InArt = cls_ref.input_artifact  # noqa: F841
        OutModel: type[BaseModel] = cls_ref.output_model  # noqa: F841
        worker_name: str = cls_ref.worker_name

@@ -82,9 +69,7 @@ class _WorkerMeta(type):

            inst = cls_ref()
            inst._inject_store(store)
            # run worker ------------------------------------------------
            run_res = inst.__run__(in_art.data)
            # allow sync or async implementations
            if inspect.iscoroutine(run_res):
                result = await run_res
            else:

@@ -104,7 +89,6 @@ class _WorkerMeta(type):
                description=f"{worker_name} output"
            )

            # save the markdown artifact in the flow directory
            md_file = store._run_dir / "artifact.md"
            md_file.write_text(md_table)

@@ -112,9 +96,8 @@ class _WorkerMeta(type):

        return flow(name=worker_name, log_prints=True)(_core)

    # --------------------------------------------------------------------- #
    def _create_input_artifact(cls):
        """Create & attach a pydantic model ‹InputArtifact› = {dir?, data}."""
        """Creates the `InputArtifact` model for the worker."""
        DirField = (Path | None, None)
        DataField = (cls.input_model, ...)  # type: ignore # required
        art_name = f"{cls.__name__}InputArtifact"

@@ -124,35 +107,27 @@ class _WorkerMeta(type):
        cls.input_artifact = artifact  # type: ignore[attr-defined]


# --------------------------------------------------------------------------- #
# public Worker base                                                          #
# --------------------------------------------------------------------------- #
class Worker(Generic[InT, OutT], metaclass=_WorkerMeta):
    """
    Derive from this class, set *input_model* / *output_model*, and implement
    an **async** ``__run__(payload: input_model)``.
    Base class for workers. Subclasses must define `input_model`,
    `output_model`, and implement `__run__`.
    """

    input_model: ClassVar[type[BaseModel]]
    output_model: ClassVar[type[BaseModel]]
    input_artifact: ClassVar[type[BaseModel]]  # injected by metaclass
    input_artifact: ClassVar[type[BaseModel]]
    worker_name: ClassVar[str]
    _prefect_flow: ClassVar[Callable[[FlowArtifact[InT]], Awaitable[FlowArtifact[OutT]]]]

    # injected at runtime
    entry: Path
    _store: WorkerStore
    entry: Path  # The entry directory for the worker's temporary files
    _store: WorkerStore # The WorkerStore instance for this run

    # ------------------------------------------------------------------ #
    # internal wiring                                                    #
    # ------------------------------------------------------------------ #
    def _inject_store(self, store: WorkerStore) -> None:
        """Injects WorkerStore and sets entry directory."""
        self._store = store
        self.entry = store.entry_dir

    # ------------------------------------------------------------------ #
    # developer helper                                                   #
    # ------------------------------------------------------------------ #
    # Helper method for staging files
    def stage(
        self,
        src: Path | str,

@@ -163,30 +138,26 @@ class Worker(Generic[InT, OutT], metaclass=_WorkerMeta):
    ) -> Path:
        return self._store.stage(src, new_name=new_name, sanitize=sanitize, move=move)

    # ------------------------------------------------------------------ #
    # convenience wrappers                                               #
    # ------------------------------------------------------------------ #
    @classmethod
    def flow(cls):
        """Return the auto-generated Prefect flow."""
        """Returns the Prefect flow."""
        return cls._prefect_flow

    # submit variants --------------------------------------------------- #
    @classmethod
    def submit(cls, payload: FlowArtifact[InT]) -> FlowArtifact[OutT]:
        """Submits payload to the Prefect flow."""
        async def _runner():
            art = await cls._prefect_flow(payload)  # type: ignore[arg-type]
            return art

        return anyio.run(_runner)

    # ------------------------------------------------------------------ #
    # abstract                                                           #
    # ------------------------------------------------------------------ #
    async def __run__(self, payload: InT) -> OutT: ...
    async def __run__(self, payload: InT) -> OutT:
        """Core logic, implemented by subclasses."""
        ...


    # Should be overridden by the worker
    async def _to_markdown(self, data: OutT) -> str:
        """Converts output to Markdown. Override for custom format."""
        md_table = pd.DataFrame([data.dict()]).to_markdown(index=False)
        return md_table
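
Taken together, a minimal worker would look roughly like the sketch below. The models and the worker itself are hypothetical; the metaclass only enforces `input_model` / `output_model`, and the sketch assumes `worker_name` is declared on the subclass, as the class attributes above suggest:

```python
# Hypothetical minimal worker, sketched against the base class above.
from librarian_core.workers.base import FlowArtifact, Worker
from pydantic import BaseModel

class EchoIn(BaseModel):
    text: str

class EchoOut(BaseModel):
    text: str

class EchoWorker(Worker[EchoIn, EchoOut]):
    worker_name = "echo"      # assumed to be set by the subclass
    input_model = EchoIn
    output_model = EchoOut

    async def __run__(self, payload: EchoIn) -> EchoOut:
        # Files staged for this run are available under self.entry.
        return EchoOut(text=payload.text)

# Synchronous invocation through the generated Prefect flow:
# artifact = EchoWorker.submit(FlowArtifact.new(data=EchoIn(text="hi")))
```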

@@ -34,18 +34,15 @@ class Chunker(Worker[ExtractData, ChunkData]):

        working_dir = get_temp_path()

        # load NLP and tokenizer
        Chunker.nlp = spacy.load("xx_ent_wiki_sm")
        Chunker.nlp.add_pipe("sentencizer")
        Chunker.enc = tiktoken.get_encoding("cl100k_base")

        # chunk parameters
        Chunker.max_tokens = MAX_TOKENS
        Chunker.overlap_tokens = OVERLAP_TOKENS

        result = ChunkData(terms=[])

        # Loading files
        for term in payload.terms:
            chunked_term = ChunkedTerm(id=term.id, name=term.name)
            in_term_dir = self.entry / term.name

@@ -79,7 +76,6 @@ class Chunker(Worker[ExtractData, ChunkData]):

                chunked_term.courses.append(chunked_course)

            # Add the chunked term to the result
            result.terms.append(chunked_term)
            self.stage(out_term_dir)

@@ -94,16 +90,11 @@ class Chunker(Worker[ExtractData, ChunkData]):
        lg.info(f"Chapter path: {chapter_path}")
        lg.info(f"Out course dir: {out_course_dir}")

        # Extract the Text
        file_text = Chunker._extract_text(chapter_path / f.name)

        # Chunk the Text
        chunks = Chunker._chunk_text(file_text, f.name, out_course_dir)

        images_dir = out_course_dir / "images"
        images_dir.mkdir(parents=True, exist_ok=True)

        # Extract the Images
        images = Chunker._extract_images(chapter_path / f.name, images_dir)

        return chunks, images

@@ -128,28 +119,22 @@ class Chunker(Worker[ExtractData, ChunkData]):
    def _chunk_text(text: str, f_name: str, out_course_dir: Path) -> list[Chunk]:
        lg = get_run_logger()
        lg.info(f"Chunking text for file {f_name}")
        # split text into sentences and get tokens
        nlp_doc = Chunker.nlp(text)
        sentences = [sent.text.strip() for sent in nlp_doc.sents]
        sentence_token_counts = [len(Chunker.enc.encode(s)) for s in sentences]
        lg.info(f"Extracted {len(sentences)} sentences with token counts: {sentence_token_counts}")

        # Buffers
        chunks: list[Chunk] = []
        current_chunk = []
        current_token_total = 0

        chunk_id = 0

        for s, tc in zip(sentences, sentence_token_counts):  # Pair sentences and tokens
            if tc + current_token_total <= MAX_TOKENS:  # Check Token limit
                # Add sentences to chunk
        for s, tc in zip(sentences, sentence_token_counts):
            if tc + current_token_total <= MAX_TOKENS:
                current_chunk.append(s)
                current_token_total += tc
            else:
                # Flush Chunk
                chunk_text = "\n\n".join(current_chunk)

                chunk_name = f"{f_name}_{chunk_id}"
                with open(
                    out_course_dir / f"{chunk_name}.md", "w", encoding="utf-8"

@@ -165,14 +150,12 @@ class Chunker(Worker[ExtractData, ChunkData]):
                    )
                )

                # Get Overlap from Chunk
                token_ids = Chunker.enc.encode(chunk_text)
                overlap_ids = token_ids[-OVERLAP_TOKENS :]
                overlap_text = Chunker.enc.decode(overlap_ids)
                overlap_doc = Chunker.nlp(overlap_text)
                overlap_sents = [sent.text for sent in overlap_doc.sents]

                # Start new Chunk
                current_chunk = overlap_sents + [s]
                current_token_total = sum(
                    len(Chunker.enc.encode(s)) for s in current_chunk
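
The overlap logic above can be exercised in isolation. A simplified sketch with placeholder budgets and pre-split sentences instead of spaCy (the MAX_TOKENS / OVERLAP_TOKENS values here are assumptions, not the project's constants):

```python
# Simplified token-budgeted chunking with overlap; budgets and splitting are placeholders.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
MAX_TOKENS, OVERLAP_TOKENS = 512, 64  # placeholder budgets

def chunk_sentences(sentences: list[str]) -> list[str]:
    chunks: list[str] = []
    current: list[str] = []
    total = 0
    for s in sentences:
        tc = len(enc.encode(s))
        if total + tc <= MAX_TOKENS:
            current.append(s)
            total += tc
        else:
            if current:
                chunk = "\n\n".join(current)
                chunks.append(chunk)
                # Carry the last OVERLAP_TOKENS tokens into the next chunk.
                overlap = enc.decode(enc.encode(chunk)[-OVERLAP_TOKENS:])
                current = [overlap, s]
            else:
                current = [s]
            total = sum(len(enc.encode(x)) for x in current)
    if current:
        chunks.append("\n\n".join(current))
    return chunks
```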

@@ -2,10 +2,6 @@ from typing import List

from pydantic import BaseModel, Field

# --------------------------------------------------------------------------- #
# Output models                                                               #
# --------------------------------------------------------------------------- #


class Chunk(BaseModel):
    id: str

@@ -32,9 +32,6 @@ from librarian_extractor.models.extract_data import (
)


# --------------------------------------------------------------------------- #
# helpers                                                                     #
# --------------------------------------------------------------------------- #
def _clean_json(txt: str) -> str:
    txt = txt.strip()
    if txt.startswith("```"):

@@ -51,7 +48,7 @@ def _safe_json_load(txt: str) -> dict:


def _merge_with_original(src: ExtractedCourse, patch: dict, lg) -> ExtractedCourse:
    """Return *patch* merged with *src* so every id is preserved."""
    """Merges LLM patch with source, preserving IDs."""
    try:
        tgt = ExtractedCourse.model_validate(patch)
    except ValidationError as err:

@@ -73,9 +70,6 @@ def _merge_with_original(src: ExtractedCourse, patch: dict, lg) -> ExtractedCour
    return tgt


# --------------------------------------------------------------------------- #
# OpenAI call – Prefect task                                                  #
# --------------------------------------------------------------------------- #
@task(
    name="sanitize_course_json",
    retries=2,

@@ -100,9 +94,6 @@ def sanitize_course_json(course_json: str, model: str, temperature: float) -> di
    return _safe_json_load(rsp.choices[0].message.content or "{}")


# --------------------------------------------------------------------------- #
# Worker                                                                      #
# --------------------------------------------------------------------------- #
class AISanitizer(Worker[ExtractData, ExtractData]):
    input_model = ExtractData
    output_model = ExtractData

@@ -112,14 +103,12 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
        self.model_name = model_name or os.getenv("OPENAI_MODEL", "gpt-4o-mini")
        self.temperature = temperature

    # ------------------------------------------------------------------ #
    def __run__(self, data: ExtractData) -> ExtractData:
        lg = get_run_logger()

        futures: List[PrefectFuture] = []
        originals: List[ExtractedCourse] = []

        # 1) submit all courses to the LLM
        for term in data.terms:
            for course in term.courses:
                futures.append(

@@ -133,7 +122,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):

        wait(futures)

        # 2) build new graph with merged results
        terms_out: List[ExtractedTerm] = []
        idx = 0
        for term in data.terms:

@@ -149,16 +137,12 @@ class AISanitizer(Worker[ExtractData, ExtractData]):

        renamed = ExtractData(terms=terms_out)

        # 3) stage files with their new names
        self._export_with_new_names(data, renamed, lg)

        return renamed

    # ------------------------------------------------------------------ #
    # staging helpers                                                    #
    # ------------------------------------------------------------------ #
    def _stage_or_warn(self, src: Path, dst: Path, lg):
        """Copy *src* → *dst* (via self.stage). Warn if src missing."""
        """Stages file, warns if source missing."""
        if not src.exists():
            lg.warning("Source missing – skipped %s", src)
            return

@@ -175,7 +159,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):

        for term_old, term_new in zip(original.terms, renamed.terms):
            for course_old, course_new in zip(term_old.courses, term_new.courses):
                # ---------- content files (per chapter) -----------------
                for chap_old, chap_new in zip(course_old.chapters, course_new.chapters):
                    n = min(len(chap_old.content_files), len(chap_new.content_files))
                    for i in range(n):

@@ -196,7 +179,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
                        )
                        self._stage_or_warn(src, dst, lg)

                # ---------- media files (course-level “media”) ----------
                src_media_dir = (
                    entry / term_old.name / course_old.name / "media"
                )  # <─ fixed!

@@ -204,7 +186,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
                if not src_media_dir.is_dir():
                    continue

                # build a flat list of (old, new) media filenames
                media_pairs: List[tuple[ExtractedFile, ExtractedFile]] = []
                for ch_o, ch_n in zip(course_old.chapters, course_new.chapters):
                    media_pairs.extend(zip(ch_o.media_files, ch_n.media_files))

@@ -2,9 +2,6 @@
Shared lists and prompts
"""

# -------------------------------------------------------------------- #
# file selection – keep only real documents we can show / convert      #
# -------------------------------------------------------------------- #
CONTENT_FILE_EXTENSIONS = [
    "*.pdf",
    "*.doc",

@@ -26,9 +23,6 @@ MEDIA_FILE_EXTENSIONS = [
    "*.mp3",
]

# -------------------------------------------------------------------- #
# naming rules                                                         #
# -------------------------------------------------------------------- #
SANITIZE_REGEX = {
    "base": [r"\s*\(\d+\)$"],
    "course": [

@@ -51,9 +51,6 @@ ALL_EXTS = CONTENT_EXTS | MEDIA_EXTS
_id_rx = re.compile(r"\.(\d{4,})[./]")  # 1172180 from “..._.1172180/index.html”


# --------------------------------------------------------------------------- #
# helpers                                                                     #
# --------------------------------------------------------------------------- #
def _hash_id(fname: str) -> str:
    return hashlib.sha1(fname.encode()).hexdigest()[:10]

@@ -80,17 +77,14 @@ def _best_payload(node: Path) -> Path | None:  # noqa: C901
    • File_xxx/dir      → search inside /content or dir itself
    • File_xxx/index.html stub → parse to find linked file
    """
    # 1) immediate hit
    if node.is_file() and node.suffix.lower() in ALL_EXTS:
        return node

    # 2) if html stub try to parse inner link
    if node.is_file() and node.suffix.lower() in {".html", ".htm"}:
        hinted = _html_stub_target(node)
        if hinted:
            return _best_payload(hinted)  # recurse
            return _best_payload(hinted)

    # 3) directories to search
    roots: list[Path] = []
    if node.is_dir():
        roots.append(node)

@@ -120,9 +114,6 @@ def task_(**kw):
    return task(**kw)


# --------------------------------------------------------------------------- #
# Worker                                                                      #
# --------------------------------------------------------------------------- #
class Extractor(Worker[DownloadData, ExtractData]):
    input_model = DownloadData
    output_model = ExtractData

@@ -158,7 +149,6 @@ class Extractor(Worker[DownloadData, ExtractData]):
        lg.info("Extractor finished – %d terms", len(result.terms))
        return result

    # ------------------------------------------------------------------ #
    @staticmethod
    @task_()
    def _extract_course(  # noqa: C901

@@ -238,9 +228,6 @@ class Extractor(Worker[DownloadData, ExtractData]):
        finally:
            shutil.rmtree(tmp, ignore_errors=True)

    # ------------------------------------------------------------------ #
    # internal helpers                                                   #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _copy_all(
        root: Path, dst_root: Path, c_meta: ExtractedCourse, media_dir: Path, lg

@@ -5,24 +5,24 @@ from pydantic import BaseModel, Field

class ExtractedFile(BaseModel):
    id: str
    name: str  # Name of the file, relative to ExtractedChapter.name
    name: str


class ExtractedChapter(BaseModel):
    name: str  # Name of the chapter directory, relative to ExtractedCourse.name
    name: str
    content_files: List[ExtractedFile] = Field(default_factory=list)
    media_files: List[ExtractedFile] = Field(default_factory=list)


class ExtractedCourse(BaseModel):
    id: str
    name: str  # Name of the course directory, relative to ExtractedTerm.name
    name: str
    chapters: List[ExtractedChapter] = Field(default_factory=list)


class ExtractedTerm(BaseModel):
    id: str
    name: str  # Name of the term directory, relative to ExtractMeta.dir
    name: str
    courses: List[ExtractedCourse] = Field(default_factory=list)

@@ -1,7 +1,3 @@
# -------------------------------------------------------------------- #
# LLM prompts                                                          #
# -------------------------------------------------------------------- #

PROMPT_COURSE = """
General naming rules
====================

@@ -1,8 +1,4 @@
"""
URLs used by the scraper.
Functions marked as PUBLIC can be accessed without authentication.
Functions marked as PRIVATE require authentication.
"""
"""Scraper URLs. PUBLIC/PRIVATE indicates auth requirement."""

BASE_URL = "https://moodle.fhgr.ch"

@@ -12,20 +12,8 @@ from librarian_scraper.constants import PRIVATE_URLS, PUBLIC_URLS


class CookieCrawler:
    """
    Retrieve Moodle session cookies + sesskey via Playwright.
    """Retrieve Moodle session cookies + sesskey via Playwright."""

    Usage
    -----
    >>> crawler = CookieCrawler()
    >>> cookies, sesskey = await crawler.crawl()          # inside async code
    # or
    >>> cookies, sesskey = CookieCrawler.crawl_sync()     # plain scripts
    """

    # ------------------------------------------------------------------ #
    # construction                                                       #
    # ------------------------------------------------------------------ #
    def __init__(self, *, headless: bool = True) -> None:
        self.headless = headless
        self.cookies: Optional[List[Cookie]] = None

@@ -38,13 +26,8 @@ class CookieCrawler:
                "Set MOODLE_USERNAME and MOODLE_PASSWORD as environment variables."
            )

    # ------------------------------------------------------------------ #
    # public API                                                         #
    # ------------------------------------------------------------------ #
    async def crawl(self) -> tuple[Cookies, str]:
        """
        Async entry-point – await this inside FastAPI / Prefect etc.
        """
        """Async method to crawl cookies and sesskey."""
        async with async_playwright() as p:
            browser: Browser = await p.chromium.launch(headless=self.headless)
            page = await browser.new_page()

@@ -61,51 +44,34 @@ class CookieCrawler:

    @classmethod
    def crawl_sync(cls, **kwargs) -> tuple[Cookies, str]:
        """
        Synchronous helper for CLI / notebooks.

        Detects whether an event loop is already running.  If so, it
        schedules the coroutine and waits; otherwise it starts a fresh loop.
        """
        """Synchronous version of crawl. Handles event loop."""
        self = cls(**kwargs)

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:  # no loop running → safe to create one
        except RuntimeError:
            return asyncio.run(self.crawl())

        # An event loop exists – schedule coroutine
        return loop.run_until_complete(self.crawl())
    # ------------------------------------------------------------------ #
 | 
			
		||||
    # internal helpers                                                   #
 | 
			
		||||
    # ------------------------------------------------------------------ #
 | 
			
		||||
    async def _login(self, page: Page) -> None:
 | 
			
		||||
        """Fill the SSO form and extract cookies + sesskey."""
 | 
			
		||||
 | 
			
		||||
        # Select organisation / IdP
 | 
			
		||||
        await page.click("#wayf_submit_button")
 | 
			
		||||
 | 
			
		||||
        # Wait for the credential form
 | 
			
		||||
        await page.wait_for_selector("form[method='post']", state="visible")
 | 
			
		||||
 | 
			
		||||
        # Credentials
 | 
			
		||||
        await page.fill("input[id='username']", self.username)
 | 
			
		||||
        await page.fill("input[id='password']", self.password)
 | 
			
		||||
        await page.click("button[class='aai_login_button']")
 | 
			
		||||
 | 
			
		||||
        # Wait for redirect to /my/ page (dashboard), this means the login is complete
 | 
			
		||||
        await page.wait_for_url(PRIVATE_URLS.dashboard)
 | 
			
		||||
        await page.wait_for_selector("body", state="attached")
 | 
			
		||||
 | 
			
		||||
        # Navigate to personal course overview
 | 
			
		||||
        await page.goto(PRIVATE_URLS.user_courses)
 | 
			
		||||
        await page.wait_for_selector("body", state="attached")
 | 
			
		||||
 | 
			
		||||
        # Collect session cookies
 | 
			
		||||
        self.cookies = await page.context.cookies()
 | 
			
		||||
 | 
			
		||||
        # Extract sesskey from injected Moodle config
 | 
			
		||||
        try:
 | 
			
		||||
            self.sesskey = await page.evaluate(
 | 
			
		||||
                "() => window.M && M.cfg && M.cfg.sesskey"
 | 
			
		||||
@ -119,13 +85,9 @@ class CookieCrawler:
 | 
			
		||||
        logging.debug("sesskey: %s", self.sesskey)
 | 
			
		||||
        logging.debug("cookies: %s", self.cookies)
 | 
			
		||||
 | 
			
		||||
        # Dev convenience
 | 
			
		||||
        if not self.headless:
 | 
			
		||||
            await page.wait_for_timeout(5000)
 | 
			
		||||
 | 
			
		||||
    # ------------------------------------------------------------------ #
 | 
			
		||||
    # cookie conversion                                                  #
 | 
			
		||||
    # ------------------------------------------------------------------ #
 | 
			
		||||
    def _to_cookiejar(self, raw: List[Cookie]) -> Cookies:
 | 
			
		||||
        jar = Cookies()
 | 
			
		||||
        for c in raw:
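The hunk is truncated inside `_to_cookiejar`. A minimal sketch of the conversion it performs, assuming Playwright-style cookie dicts with `name`, `value`, `domain`, and `path` keys:

```python
import httpx

def to_cookiejar(raw: list[dict]) -> httpx.Cookies:
    """Copy Playwright cookie dicts into an httpx cookie jar."""
    jar = httpx.Cookies()
    for c in raw:
        jar.set(c["name"], c["value"], domain=c.get("domain", ""), path=c.get("path", "/"))
    return jar
```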

@@ -39,18 +39,12 @@ from librarian_scraper.models.crawl_data import (
    CrawlTerm,
)

# --------------------------------------------------------------------------- #
# module-level shared items for static task                                   #
# --------------------------------------------------------------------------- #
_COOKIE_JAR: httpx.Cookies | None = None
_DELAY: float = 0.0

CACHE_FILE = get_cache_root() / "librarian_no_access_cache.json"


# --------------------------------------------------------------------------- #
# utility                                                                     #
# --------------------------------------------------------------------------- #
def looks_like_enrol(resp: httpx.Response) -> bool:
    txt = resp.text.lower()
    return (
@@ -60,21 +54,14 @@ def looks_like_enrol(resp: httpx.Response) -> bool:
    )


# --------------------------------------------------------------------------- #
# main worker                                                                 #
# --------------------------------------------------------------------------- #
class Crawler(Worker[CrawlProgram, CrawlData]):
    input_model = CrawlProgram
    output_model = CrawlData

    # toggles (env overrides)
    RELAXED: bool
    USER_SPECIFIC: bool
    CLEAR_CACHE: bool

    # ------------------------------------------------------------------ #
    # flow entry-point                                                   #
    # ------------------------------------------------------------------ #
    async def __run__(self, program: CrawlProgram) -> CrawlData:
        global _COOKIE_JAR, _DELAY
        lg = get_run_logger()
@@ -93,7 +80,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
            batch,
        )

        # --------------------------- login
        cookies, _ = await CookieCrawler().crawl()
        _COOKIE_JAR = cookies
        self._client = httpx.Client(cookies=cookies, follow_redirects=True)
@@ -102,14 +88,11 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
            lg.error("Guest session detected – aborting crawl.")
            raise RuntimeError("Login failed")

        # --------------------------- cache
        no_access: set[str] = set() if self.CLEAR_CACHE else self._load_cache()

        # --------------------------- scrape terms (first two for dev)
        terms = self._crawl_terms(program.id)[:2]
        lg.info("Terms discovered: %d", len(terms))

        # --------------------------- scrape courses
        for term in terms:
            courses = self._crawl_courses(term.id)
            lg.info("[%s] raw courses: %d", term.name, len(courses))
@@ -137,7 +120,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
            )
            lg.info("[%s] kept: %d", term.name, len(term.courses))

        # --------------------------- persist cache
        self._save_cache(no_access)

        return CrawlData(
@@ -148,9 +130,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
            )
        )

    # ------------------------------------------------------------------ #
    # static task inside class                                           #
    # ------------------------------------------------------------------ #
    @staticmethod
    @task(
        name="crawl_course",
@@ -198,9 +177,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):

        return course_id, href.split("=")[-1]

    # ------------------------------------------------------------------ #
    # helpers                                                            #
    # ------------------------------------------------------------------ #
    def _logged_in(self) -> bool:
        html = self._get_html(PUBLIC_URLS.index)
        return not parsel.Selector(text=html).css("div.usermenu span.login a")
@@ -246,9 +222,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
            get_run_logger().warning("GET %s failed (%s)", url, exc)
            return ""

    # ------------------------------------------------------------------ #
    # cache helpers                                                      #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _load_cache() -> set[str]:
        try:
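`_load_cache` is cut off by the hunk. A minimal sketch of a JSON-backed set cache of the kind the cache helpers describe; the file name mirrors `CACHE_FILE` above, and the error handling is an assumption:

```python
import json
from pathlib import Path

CACHE_FILE = Path("librarian_no_access_cache.json")  # stand-in for the real cache path

def load_cache() -> set[str]:
    try:
        return set(json.loads(CACHE_FILE.read_text()))
    except (FileNotFoundError, json.JSONDecodeError):
        return set()

def save_cache(no_access: set[str]) -> None:
    CACHE_FILE.write_text(json.dumps(sorted(no_access)))
```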

@@ -22,11 +22,8 @@ class IndexCrawler:
        self.debug = debug
        self.client = httpx.Client(cookies=cookies, follow_redirects=True)
        self.max_workers = max_workers

        # When True the cached “no-access” set is ignored for this run
        self._ignore_cache: bool = False

        # Load persisted cache of course-IDs the user cannot access
        if NO_ACCESS_CACHE_FILE.exists():
            try:
                self._no_access_cache: set[str] = set(json.loads(NO_ACCESS_CACHE_FILE.read_text()))
@@ -54,7 +51,7 @@ class IndexCrawler:

    def crawl_index(self, userSpecific: bool = True, *, use_cache: bool = True) -> MoodleIndex:
        """
        Build and return a `MoodleIndex`.
        Builds and returns a `MoodleIndex`.

        Parameters
        ----------
@@ -65,21 +62,17 @@ class IndexCrawler:
            afresh. Newly discovered “no-access” courses are still written back to the
            cache at the end of the crawl.
        """
        # Set runtime flag for has_user_access()
        self._ignore_cache = not use_cache

        semesters = []
        # Get all courses for each semester and the courseid and name for each course.
        semesters = self.crawl_semesters()
        # Crawl only the latest two semesters to reduce load (remove once caching is implemented)
        for semester in semesters[:2]:
            courses = self.crawl_courses(semester)

            # Crawl courses in parallel to speed things up
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as pool:
                list(pool.map(self.crawl_course, courses))

            # Filter courses once all have been processed
            for course in courses:
                if userSpecific:
                    if course.content_ressource_id:
@@ -87,8 +80,6 @@ class IndexCrawler:
                else:
                    semester.courses.append(course)

        # Only add semesters that have at least one course
        # Filter out semesters that ended up with no courses after crawling
        semesters: list[Semester] = [
            semester for semester in semesters if semester.courses
        ]
@@ -100,21 +91,13 @@ class IndexCrawler:
                semesters=semesters,
            ),
        )
        # Persist any newly discovered no-access courses
        self._save_no_access_cache()

        # Restore default behaviour for subsequent calls
        self._ignore_cache = False

        return created_index

    # --------------------------------------------------------------------- #
    # High-level crawling helpers
    # --------------------------------------------------------------------- #
    def crawl_semesters(self) -> list[Semester]:
        """
        Crawl the semesters from the Moodle index page.
        """
        """Crawls semester data."""
        url = URLs.get_degree_program_url(self.degree_program.id)
        res = self.get_with_retries(url)

@@ -126,9 +109,7 @@ class IndexCrawler:
        return []

    def crawl_courses(self, semester: Semester) -> list[Course]:
        """
        Crawl the courses from the Moodle index page.
        """
        """Crawls course data for a semester."""
        url = URLs.get_semester_url(semester_id=semester.id)
        res = self.get_with_retries(url)

@@ -140,10 +121,7 @@ class IndexCrawler:
        return []

    def crawl_course(self, course: Course) -> None:
        """
        Crawl a single Moodle course page.
        """

        """Crawls details for a single course."""
        hasAccess = self.has_user_access(course)

        if not hasAccess:
@@ -154,13 +132,8 @@ class IndexCrawler:
        course.content_ressource_id = self.crawl_content_ressource_id(course)
        course.files = self.crawl_course_files(course)

    # --------------------------------------------------------------------- #
    # Networking utilities
    # --------------------------------------------------------------------- #
    def get_with_retries(self, url: str, retries: int = 3, delay: int = 1) -> httpx.Response:
        """
        Simple GET with retries and exponential back-off.
        """
        """Simple GET with retries and exponential back-off."""
        for attempt in range(1, retries + 1):
            try:
                response = self.client.get(url)
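The retry loop is truncated by the hunk. A sketch of the shape the docstring describes, exponential back-off between attempts and a re-raise after the final failure, assuming a plain `httpx.Client`:

```python
import time
import httpx

def get_with_retries(client: httpx.Client, url: str, retries: int = 3, delay: float = 1.0) -> httpx.Response:
    for attempt in range(1, retries + 1):
        try:
            response = client.get(url)
            response.raise_for_status()
            return response
        except httpx.HTTPError:
            if attempt == retries:
                raise
            time.sleep(delay * 2 ** (attempt - 1))  # 1 s, 2 s, 4 s, ...
```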
@@ -183,17 +156,11 @@ class IndexCrawler:
            f.write(response.text)
        logging.info(f"Saved HTML to {filename}")

    # --------------------------------------------------------------------- #
    # Extractors
    # --------------------------------------------------------------------- #
    def extract_semesters(self, html: str) -> list[Semester]:
        """Extracts semester names and IDs from HTML."""
        selector = parsel.Selector(text=html)

        logging.info("Extracting semesters from the HTML content.")

        semesters: list[Semester] = []

        # Each semester sits in a collapsed container
        semester_containers = selector.css("div.category.notloaded.with_children.collapsed")

        for container in semester_containers:
@@ -207,7 +174,6 @@ class IndexCrawler:
            )
            semester_id = anchor.attrib.get("href", "").split("=")[-1]

            # Only keep semesters labeled FS or HS
            if "FS" not in semester_name and "HS" not in semester_name:
                continue

@@ -250,12 +216,9 @@ class IndexCrawler:
            )
            course_id = anchor.attrib.get("href", "").split("=")[-1]

            # Remove trailing semester tag and code patterns
            course_name = re.sub(r"\s*(FS|HS)\d{2}\s*", "", course_name)
            course_name = re.sub(r"\s*\(.*?\)\s*", "", course_name).strip()

            # Try to locate a hero/overview image that belongs to this course box
            # Traverse up to the containing course box, then look for <div class="courseimage"><img ...>
            course_container = header.xpath('./ancestor::*[contains(@class,"coursebox")][1]')
            hero_src = (
                course_container.css("div.courseimage img::attr(src)").get("")
@@ -275,10 +238,7 @@ class IndexCrawler:
        return courses
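For reference, a self-contained parsel snippet in the spirit of the extractors above: pull the course name and the trailing `id=` value out of a course-box anchor. The HTML is made up for illustration.

```python
import parsel

html = '<div class="coursebox"><h3><a href="https://moodle.example/course/view.php?id=42">Algorithms (HS24)</a></h3></div>'
sel = parsel.Selector(text=html)
for anchor in sel.css("div.coursebox h3 a"):
    name = anchor.css("::text").get("").strip()
    course_id = anchor.attrib.get("href", "").split("=")[-1]
    print(course_id, name)  # 42 Algorithms (HS24)
```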

    def has_user_access(self, course: Course) -> bool:
        """
        Return True only if the authenticated user can access the course (result cached).
        (i.e. the response is HTTP 200 **and** is not a redirected login/enrol page).
        """
        """Checks if user can access course (caches negative results)."""
        if not self._ignore_cache and course.id in self._no_access_cache:
            return False

@@ -289,21 +249,19 @@ class IndexCrawler:
            self._no_access_cache.add(course.id)
            return False

        # Detect Moodle redirection to a login or enrolment page
        final_url = str(res.url).lower()
        if "login" in final_url or "enrol" in final_url:
            self._no_access_cache.add(course.id)
            return False

        # Some enrolment pages still return 200; look for HTML markers
        if "#page-enrol" in res.text or "you need to enrol" in res.text.lower():
            self._no_access_cache.add(course.id)
            return False

        # If we got here the user has access; otherwise cache the deny
        return True

    def crawl_content_ressource_id(self, course: Course) -> str:
        """Extracts content resource ID for a course."""
        course_id = course.id
        url = URLs.get_course_url(course_id)
        res = self.get_with_retries(url)
@@ -311,12 +269,10 @@ class IndexCrawler:

        try:
            logging.info("Searching for 'Download course content' link.")
            # Use parsel CSS selector to find the anchor tag with the specific data attribute
            download_link_selector = psl.css('a[data-downloadcourse="1"]')
            if not download_link_selector:
                raise ValueError("Download link not found.")

            # Extract the href attribute from the first matching element
            href = download_link_selector[0].attrib.get("href")
            if not href:
                raise ValueError("Href attribute not found on the download link.")
@@ -334,9 +290,7 @@ class IndexCrawler:
            return ''

    def crawl_course_files(self, course: Course) -> list[FileEntry]:
        """
        Crawl the course files from the Moodle course page.
        """
        """Crawls file entries for a course."""
        url = URLs.get_course_url(course.id)
        res = self.get_with_retries(url)

@@ -347,10 +301,8 @@ class IndexCrawler:

        return []

    # ----------------------------------------------------------------- #
    # Cache persistence helpers
    # ----------------------------------------------------------------- #
    def _save_no_access_cache(self) -> None:
        """Saves course IDs the user cannot access to a cache file."""
        try:
            NO_ACCESS_CACHE_FILE.write_text(json.dumps(sorted(self._no_access_cache)))
        except Exception as exc:

@@ -1,9 +1,5 @@
# TODO: Move to librarian-core
"""
All URLs used in the crawler.
Functions marked as PUBLIC can be accessed without authentication.
Functions marked as PRIVATE require authentication.
"""
"""Moodle URLs. PUBLIC/PRIVATE indicates auth requirement."""
class URLs:
    base_url = "https://moodle.fhgr.ch"

@@ -12,7 +8,6 @@ class URLs:
        """PUBLIC"""
        return cls.base_url

    # ------------------------- Moodle URLs -------------------------
    @classmethod
    def get_login_url(cls):
        """PUBLIC"""

@@ -35,9 +35,6 @@ from librarian_scraper.models.download_data import (
)


# --------------------------------------------------------------------------- #
# helper decorator                                                            #
# --------------------------------------------------------------------------- #
def task_(**kw):
    kw.setdefault("log_prints", True)
    kw.setdefault("retries", 2)
@@ -45,9 +42,6 @@ def task_(**kw):
    return task(**kw)
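A hedged usage sketch for the `task_` helper above: explicit keyword arguments win, and the `setdefault` calls fill in the rest. The task name and body are illustrative, and running it of course requires Prefect.

```python
@task_(name="download_course", retries=3)  # retries overrides the default of 2
def download_course(course_id: str) -> None:
    print(f"downloading {course_id}")  # log_prints=True routes this to the Prefect logger
```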


# --------------------------------------------------------------------------- #
# shared state for static task                                                #
# --------------------------------------------------------------------------- #
_COOKIE_JAR: httpx.Cookies | None = None
_SESSKEY: str = ""
_LIMIT: int = 2
@@ -57,27 +51,22 @@ _DELAY: float = 0.0
class Downloader(Worker[CrawlData, DownloadData]):
    DOWNLOAD_URL = "https://moodle.fhgr.ch/course/downloadcontent.php"

    # tuning
    CONCURRENCY = 8
    RELAXED = True  # False → faster
    RELAXED = True

    input_model = CrawlData
    output_model = DownloadData

    # ------------------------------------------------------------------ #
    async def __run__(self, crawl: CrawlData) -> DownloadData:
        global _COOKIE_JAR, _SESSKEY, _LIMIT, _DELAY
        lg = get_run_logger()

        # ------------ login
        cookies, sesskey = await CookieCrawler().crawl()
        _COOKIE_JAR, _SESSKEY = cookies, sesskey

        # ------------ tuning
        _LIMIT = 1 if self.RELAXED else max(1, min(self.CONCURRENCY, 8))
        _DELAY = CRAWLER["DELAY_SLOW"] if self.RELAXED else CRAWLER["DELAY_FAST"]

        # ------------ working dir
        work_root = Path(get_temp_path()) / f"dl_{int(time.time())}"
        work_root.mkdir(parents=True, exist_ok=True)

@@ -85,7 +74,6 @@ class Downloader(Worker[CrawlData, DownloadData]):
        futures = []
        term_dirs: List[Tuple[str, Path]] = []

        # schedule downloads
        for term in crawl.degree_program.terms:
            term_dir = work_root / term.name
            term_dir.mkdir(parents=True, exist_ok=True)
@@ -101,18 +89,14 @@ class Downloader(Worker[CrawlData, DownloadData]):
                    self._download_task.submit(course.content_ressource_id, dest)
                )

        wait(futures)  # block for all downloads
        wait(futures)

        # stage term directories
        for name, dir_path in term_dirs:
            self.stage(dir_path, new_name=name, sanitize=False, move=True)

        lg.info("Downloader finished – staged %d term folders", len(term_dirs))
        return result

    # ------------------------------------------------------------------ #
    # static task                                                        #
    # ------------------------------------------------------------------ #
    @staticmethod
    @task_()
    def _download_task(context_id: str, dest: Path) -> None:

@@ -135,9 +135,6 @@ MoodleIndex: {
"""


# ---------------------------------------------------------------------------
# Base Model
# ---------------------------------------------------------------------------
class CrawlData(BaseModel):
    degree_program: CrawlProgram = Field(
        default_factory=lambda: CrawlProgram(id="", name="")
@@ -147,18 +144,12 @@ class CrawlData(BaseModel):
    )


# ---------------------------------------------------------------------------
#  Degree Program
# ---------------------------------------------------------------------------
class CrawlProgram(BaseModel):
    id: str = Field("1157", description="Unique identifier for the degree program.")
    name: str = Field("Computational and Data Science", description="Name of the degree program.")
    terms: list[CrawlTerm] = Field(default_factory=list)


# ---------------------------------------------------------------------------
#  Term
# ---------------------------------------------------------------------------
_TERM_RE = re.compile(r"^(HS|FS)\d{2}$")  # HS24 / FS25 …
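A quick illustration of what `_TERM_RE` accepts: a two-letter semester prefix plus a two-digit year, and nothing else.

```python
import re

_TERM_RE = re.compile(r"^(HS|FS)\d{2}$")
print(bool(_TERM_RE.match("HS24")))    # True
print(bool(_TERM_RE.match("FS2024")))  # False (four-digit year)
print(bool(_TERM_RE.match("Sommer")))  # False
```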


@@ -168,9 +159,6 @@ class CrawlTerm(BaseModel):
    courses: list[CrawlCourse] = Field(default_factory=list)


# ---------------------------------------------------------------------------
#  Course
# ---------------------------------------------------------------------------
class CrawlCourse(BaseModel):
    id: str
    name: str
@@ -179,9 +167,6 @@ class CrawlCourse(BaseModel):
    files: list[CrawlFile] = Field(default_factory=list)


# ---------------------------------------------------------------------------
#  Files
# ---------------------------------------------------------------------------
class CrawlFile(BaseModel):
    id: str
    res_id: str

@@ -62,9 +62,6 @@ def _create_hnsw_index(
    except Exception:
        logger.exception("Failed to run create_or_reindex_hnsw")

# --------------------------------------------------------------------------- #
# single file                                                                 #
# --------------------------------------------------------------------------- #
def embed_single_file(
    *,
    course_id: str,
@@ -104,7 +101,6 @@ def embed_single_file(
    wf.process()
    return chunk_path

# --------------------------------------------------------------------------- #
async def run_embedder(
    course: ChunkCourse,
    concat_path: Union[str, Path],

@@ -39,7 +39,6 @@ class EmbeddingWorkflow:

        # No need to store db_schema/db_function here if inserter handles it

    # ---------------- helpers ----------------
    def _load_chunk(self) -> Optional[str]:
        try:
            text = self.chunk_path.read_text(encoding="utf-8").strip()
@@ -85,7 +84,7 @@ class EmbeddingWorkflow:
            return False

        logger.debug(f"Successfully processed and inserted {self.chunk_path}")
        return True # Indicate success
        return True


# Keep __all__ if needed

@@ -22,7 +22,6 @@ from librarian_core.workers.base import Worker

from librarian_vspace.vecview.vecview import get_tsne_json

# ------------------------------------------------------------------ #
def _safe_get_logger(name: str):
    try:
        return get_run_logger()
@@ -30,9 +29,7 @@ def _safe_get_logger(name: str):
        return logging.getLogger(name)


# ------------------------------------------------------------------ #
# Pydantic payloads
# ------------------------------------------------------------------ #
class TsneExportInput(BaseModel):
    course_id: int
    limit: Optional[int] = None
@@ -48,7 +45,6 @@ class TsneExportOutput(BaseModel):
    json_path: Path


# ------------------------------------------------------------------ #
class TsneExportWorker(Worker[TsneExportInput, TsneExportOutput]):
    """Runs the t‑SNE export inside a Prefect worker."""  # noqa: D401


@@ -20,9 +20,6 @@ logger = logging.getLogger(__name__)
DEFAULT_N_CLUSTERS = 8


# --------------------------------------------------------------------- #
# Internal helpers (kept minimal – no extra bells & whistles)
# --------------------------------------------------------------------- #
def _run_kmeans(df: pd.DataFrame, *, embedding_column: str, k: int = DEFAULT_N_CLUSTERS) -> pd.DataFrame:
    """Adds a 'cluster' column using K‑means (string labels)."""
    if df.empty or embedding_column not in df.columns:
@@ -61,9 +58,6 @@ def _add_hover(df: pd.DataFrame) -> pd.DataFrame:
    return df


# --------------------------------------------------------------------- #
# Public helpers
# --------------------------------------------------------------------- #
def get_tsne_dataframe(
    db_schema: str,
    db_function: str,

@@ -58,9 +58,6 @@ import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

# ---------------------------------------------------------------------------
# Map Vectorbase credential names → Supabase names expected by loader code
# ---------------------------------------------------------------------------
_ALIAS_ENV_MAP = {
    "VECTORBASE_URL": "SUPABASE_URL",
    "VECTORBASE_API_KEY": "SUPABASE_KEY",
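The alias dict is cut off by the hunk. A hedged sketch of how such a map is typically applied, with a hypothetical helper name: copy each Vectorbase-style variable to the Supabase name the loader expects, without overwriting values that are already set.

```python
import os

def apply_env_aliases(alias_map: dict[str, str]) -> None:
    for src, dst in alias_map.items():
        if src in os.environ and dst not in os.environ:
            os.environ[dst] = os.environ[src]

apply_env_aliases({"VECTORBASE_URL": "SUPABASE_URL", "VECTORBASE_API_KEY": "SUPABASE_KEY"})
```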
@@ -86,11 +83,8 @@ except ImportError as e:
    raise ImportError(f"Could not import VectorQueryLoader: {e}") from e


# ---------------------------------------------------------------------------
# Logging setup (used by both script and callable function)
# ---------------------------------------------------------------------------
# This basicConfig runs when the module is imported.
# Callers might want to configure logging before importing.
# Callers might want to configure logging before importing.
# If logging is already configured, basicConfig does nothing.
logging.basicConfig(
    level=logging.INFO,
@@ -99,17 +93,11 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__) # Use __name__ for module-specific logger

# ---------------------------------------------------------------------------
# Helper – JSON dump for centroid in YAML front‑matter
# ---------------------------------------------------------------------------
def centroid_to_json(vec: np.ndarray) -> str:
    """Converts a numpy vector to a JSON string suitable for YAML frontmatter."""
    return json.dumps([float(x) for x in vec], ensure_ascii=False)
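Example round-trip for the helper above:

```python
import numpy as np

print(centroid_to_json(np.array([0.25, 0.5, 0.75])))  # -> [0.25, 0.5, 0.75]
```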


# ---------------------------------------------------------------------------
# Main clustering and export logic as a callable function
# ---------------------------------------------------------------------------
def run_cluster_export_job(
    course_id: Optional[int] = None, # Added course_id parameter
    output_dir: Union[str, Path] = "./cluster_md", # Output directory parameter
@@ -147,9 +135,7 @@ def run_cluster_export_job(
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info("Writing Markdown files to %s", output_path)

    # ---------------------------------------------------------------------------
    # Fetch embeddings - Now using VectorQueryLoader with filtering
    # ---------------------------------------------------------------------------
    try:
        # Use parameters for loader config
        # --- FIX: Instantiate VectorQueryLoader ---
@@ -249,9 +235,7 @@ def run_cluster_export_job(
    # -------------------------------------------------------------


    # ---------------------------------------------------------------------------
    # Prepare training sample and determine effective k
    # ---------------------------------------------------------------------------
    # Use the parameter train_sample_size
    train_vecs = embeddings[:train_sample_size]

@@ -302,9 +286,7 @@ def run_cluster_export_job(
        raise RuntimeError(f"K-means clustering failed: {e}") from e


    # ---------------------------------------------------------------------------
    # Assign every vector to its nearest centroid (full table)
    # ---------------------------------------------------------------------------
    logger.info("Assigning vectors to centroids...")
    try:
        # Use the determined embedding column for assignment as well
@@ -315,9 +297,7 @@ def run_cluster_export_job(
        logger.exception("Failed to assign vectors to centroids.")
        raise RuntimeError(f"Failed to assign vectors to centroids: {e}") from e
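A minimal, self-contained sketch of the cluster-then-assign pattern used here; random vectors stand in for the real embeddings, and the column handling of the actual job is omitted.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

embeddings = np.random.rand(1000, 384)                # stand-in for real vectors
train_vecs = embeddings[:256]                         # training sample
kmeans = KMeans(n_clusters=8, n_init=10, random_state=0).fit(train_vecs)
labels, dists = pairwise_distances_argmin_min(embeddings, kmeans.cluster_centers_)
print(labels.shape, dists.shape)                      # (1000,) (1000,)
```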

    # ---------------------------------------------------------------------------
    # Write one Markdown file per cluster
    # ---------------------------------------------------------------------------
    files_written_count = 0

    logger.info("Writing cluster Markdown files to %s", output_path)
@@ -385,9 +365,7 @@ def run_cluster_export_job(
    return output_path


# ---------------------------------------------------------------------------
# Script entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Configuration via environment for script
    script_output_dir = Path(os.environ.get("OUTPUT_DIR", "./cluster_md")).expanduser()

@@ -37,15 +37,9 @@ from librarian_vspace.models.query_model import (
logger = logging.getLogger(__name__)


# --------------------------------------------------------------------- #
# Main helper
# --------------------------------------------------------------------- #
class VectorQuery(BaseVectorOperator):
    """High‑level helper for vector searches via Supabase RPC."""

    # -----------------------------------------------------------------
    # Public – modern API
    # -----------------------------------------------------------------
    def search(self, request: VectorSearchRequest) -> VectorSearchResponse:
        """Perform a similarity search and return structured results."""

@@ -100,9 +94,6 @@ class VectorQuery(BaseVectorOperator):
            logger.exception("RPC 'vector_search' failed: %s", exc)
            return VectorSearchResponse(total=0, results=[])

    # -----------------------------------------------------------------
    # Public – legacy compatibility
    # -----------------------------------------------------------------
    def get_chucklets_by_vector(
        self,
        *,

@@ -38,9 +38,6 @@ import sys
from pathlib import Path
from typing import Any, Dict, Literal, Optional

# --------------------------------------------------------------------------- #
# Optional dependencies                                                       #
# --------------------------------------------------------------------------- #
try:
    import psutil  # type: ignore
except ModuleNotFoundError:  # pragma: no cover
@@ -51,10 +48,6 @@ try:
except ModuleNotFoundError:  # pragma: no cover
    torch = None  # type: ignore

# --------------------------------------------------------------------------- #
# Hardware discovery helpers                                                  #
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=None)
def logical_cores() -> int:
@@ -124,11 +117,6 @@ def cgroup_cpu_limit() -> Optional[int]:
    return None


# --------------------------------------------------------------------------- #
# GPU helpers (CUDA only for now)                                             #
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=None)
def gpu_info() -> Optional[Dict[str, Any]]:
    """Return basic info for the first CUDA device via *torch* (or ``None``)."""
@@ -147,11 +135,6 @@ def gpu_info() -> Optional[Dict[str, Any]]:
    }


# --------------------------------------------------------------------------- #
# Recommendation logic                                                        #
# --------------------------------------------------------------------------- #


def recommended_workers(
    *,
    kind: Literal["cpu", "io", "gpu"] = "cpu",
@@ -194,11 +177,6 @@ def recommended_workers(
    return max(1, base)
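A hedged usage sketch for the advisor: the exact numbers depend on the host, and the calls below only rely on the `kind` keyword shown above, assuming the remaining parameters keep their defaults.

```python
if __name__ == "__main__":
    for kind in ("cpu", "io", "gpu"):
        print(kind, recommended_workers(kind=kind))
```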


# --------------------------------------------------------------------------- #
# Convenience snapshot of system info                                         #
# --------------------------------------------------------------------------- #


def system_snapshot() -> Dict[str, Any]:
    """Return a JSON‑serialisable snapshot of parallelism‑related facts."""
    return {
@@ -211,11 +189,6 @@ def system_snapshot() -> Dict[str, Any]:
    }


# --------------------------------------------------------------------------- #
# CLI                                                                         #
# --------------------------------------------------------------------------- #


def _cli() -> None:  # pragma: no cover
    parser = argparse.ArgumentParser(
        prog="parallelism_advisor", description="Rule‑of‑thumb worker estimator"

@@ -24,13 +24,11 @@ class BaseVectorOperator:
        self.table: Optional[str] = None
        self._resolve_ids()

    # ---------------- public helpers ----------------
    def table_fqn(self) -> str:
        if not self.table:
            raise RuntimeError("VectorOperator not initialised – no table")
        return f"{self.schema}.{self.table}"
    # ---------------- internals ----------------
    def _resolve_ids(self) -> None:
        self.model_id = self._rpc_get_model_id(self.model)
        if self.model_id is None: