Update Readme and cleanup

parent 52d087ce90
commit b6789cb5f7

README.md (new file, 50 lines added)
@@ -0,0 +1,50 @@
# Atlas Librarian

A comprehensive content processing and management system for extracting, chunking, and vectorizing information from various sources.

## Overview

Atlas Librarian is a modular system designed to process, organize, and make searchable large amounts of content through web scraping, content extraction, chunking, and vector embeddings.

## Project Structure

```
atlas/
├── librarian/
│   ├── atlas-librarian/           # Main application
│   ├── librarian-core/            # Core functionality and storage
│   └── plugins/
│       ├── librarian-chunker/     # Content chunking
│       ├── librarian-extractor/   # Content extraction with AI
│       ├── librarian-scraper/     # Web scraping and crawling
│       └── librarian-vspace/      # Vector space operations
```

## Components

- **Atlas Librarian**: Main application with API, web app, and recipe management
- **Librarian Core**: Shared utilities, storage, and Supabase integration
- **Chunker Plugin**: Splits content into processable chunks
- **Extractor Plugin**: Extracts and sanitizes content using AI
- **Scraper Plugin**: Crawls and downloads web content
- **VSpace Plugin**: Vector embeddings and similarity search
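
Each plugin contributes one or more *workers*. The sketch below illustrates the plugin contract defined by the `Worker` base class in `librarian-core`; the `Echo*` names are invented for this example, and details such as how `worker_name` is registered may differ slightly from the actual code.

```python
# Illustrative worker plugin (hypothetical EchoWorker, not part of the repo).
from pydantic import BaseModel

from librarian_core.workers.base import Worker


class EchoInput(BaseModel):
    text: str


class EchoOutput(BaseModel):
    text: str


class EchoWorker(Worker[EchoInput, EchoOutput]):
    """Pass-through worker that illustrates the plugin contract."""

    worker_name = "echo"        # key used by the worker registry (assumed to be set manually)
    input_model = EchoInput     # validated input payload
    output_model = EchoOutput   # model returned by __run__

    async def __run__(self, payload: EchoInput) -> EchoOutput:
        # Real plugins scrape, extract, chunk, or embed here.
        return EchoOutput(text=payload.text)
```

Plugins expose their worker classes through the `librarian.worker` entry-point group so the application's worker loader can discover them at startup.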
## Getting Started

1. Clone the repository
2. Install dependencies for each component
3. Configure environment variables
4. Run the main application (see the example below)
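
A minimal sketch of step 4, assuming the `atlas-librarian` package and `uvicorn` are installed in the active environment (both are assumptions; adjust to your setup):

```python
# Run the Atlas Librarian API locally (sketch; host and port are illustrative).
import uvicorn

if __name__ == "__main__":
    # "atlas_librarian.app:app" points at the FastAPI instance created in app.py
    uvicorn.run("atlas_librarian.app:app", host="127.0.0.1", port=8000, reload=True)
```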
## Features

- Web content scraping and crawling
- AI-powered content extraction and sanitization
- Intelligent content chunking
- Vector embeddings for semantic search
- Supabase integration for data storage
- Modular plugin architecture
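
Because all plugins share the worker interface, they can be chained over the HTTP API. The sketch below posts a recipe to the recipes router; the base URL, the `/api/v1` prefix, and the example payload are assumptions, so check the configured `API_PREFIX` and the input model of the first worker in the chain.

```python
# Chain workers through the recipes endpoint (sketch; URL prefix and payload are placeholders).
import httpx

recipe = {
    "workers": ["crawl", "download", "extract"],   # worker names known to the registry
    "payload": {"id": "1157", "name": "Example"},  # input of the first worker (here a CrawlProgram)
}

response = httpx.post("http://127.0.0.1:8000/api/v1/recipes/run", json=recipe, timeout=None)
response.raise_for_status()
print(response.json())  # one FlowArtifact per executed worker
```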
---

*For detailed documentation, see the individual component directories.*
@@ -3,18 +3,13 @@ import importlib

__all__ = []

# Iterate over all modules in this package
for finder, module_name, is_pkg in pkgutil.iter_modules(__path__):
    # import the sub-module
    module = importlib.import_module(f"{__name__}.{module_name}")

    # decide which names to re-export:
    # use module.__all__ if it exists, otherwise every non-private attribute
    public_names = getattr(
        module, "__all__", [n for n in dir(module) if not n.startswith("_")]
    )

    # bring each name into the package namespace
    for name in public_names:
        globals()[name] = getattr(module, name)
        __all__.append(name)  # type: ignore
@@ -1,6 +1,3 @@
# ------------------------------------------------------ #
# Workers have to be imported here to be discovered by the worker loader
# ------------------------------------------------------ #
from librarian_chunker.chunker import Chunker
from librarian_extractor.ai_sanitizer import AISanitizer
from librarian_extractor.extractor import Extractor
@@ -32,12 +32,9 @@ from atlas_librarian.stores.workers import WORKERS

router = APIRouter(tags=["recipes"])


# --------------------------------------------------------------------------- #
# Pydantic models                                                             #
# --------------------------------------------------------------------------- #
class RecipeRequest(BaseModel):
    workers: List[str] = Field(min_length=1)
    payload: dict  # input of the first worker
    payload: dict


class RecipeMetadata(BaseModel):
@@ -47,17 +44,11 @@ class RecipeMetadata(BaseModel):
    flow_run_id: str | None = None


# in-memory “DB”
_RECIPES: dict[str, RecipeMetadata] = {}


# --------------------------------------------------------------------------- #
# routes                                                                      #
# --------------------------------------------------------------------------- #
@router.post("/run", status_code=202, response_model=list[FlowArtifact])
def run_recipe(req: RecipeRequest) -> list[FlowArtifact]:
    # validation of worker chain
    for w in req.workers:
        if w not in WORKERS:
            raise HTTPException(400, f"Unknown worker: {w}")
@@ -71,7 +62,6 @@ def run_recipe(req: RecipeRequest) -> list[FlowArtifact]:
    async def _run_worker(worker: type[Worker], art: FlowArtifact) -> FlowArtifact:
        return await worker.flow()(art)

    # Kick off the first worker
    art: FlowArtifact = anyio.run(_run_worker, start_worker, FlowArtifact(data=payload, run_id=str(uuid.uuid4()), dir=Path(".")))
    artifacts.append(art)
@ -16,9 +16,6 @@ from pydantic import BaseModel
|
||||
router = APIRouter(tags=["runs"])
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# response model #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class RunInfo(BaseModel):
|
||||
run_id: str
|
||||
worker: str
|
||||
@ -27,9 +24,6 @@ class RunInfo(BaseModel):
|
||||
data: dict | None = None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helper #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _open_store(run_id: str) -> WorkerStore:
|
||||
try:
|
||||
return WorkerStore.open(run_id)
|
||||
@ -37,16 +31,11 @@ def _open_store(run_id: str) -> WorkerStore:
|
||||
raise HTTPException(status_code=404, detail="Run-id not found") from exc
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# routes #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@router.get("/{run_id}", response_model=RunInfo)
|
||||
def get_run(run_id: str) -> RunInfo:
|
||||
"""
|
||||
Return coarse-grained information about a single flow run.
|
||||
|
||||
For the web-UI we expose only minimal metadata plus the local directory
|
||||
where files were written; clients can read further details from disk.
|
||||
Returns coarse-grained info for a flow run, including local data directory.
|
||||
Web UI uses this for minimal metadata display.
|
||||
"""
|
||||
store = _open_store(run_id)
|
||||
meta = store.metadata # {'worker_name': …, 'state': …, …}
|
||||
@ -79,7 +68,6 @@ def get_latest_run(worker_name: str) -> RunInfo | None:
|
||||
@router.get("/{run_id}/artifact")
|
||||
def get_artifact(run_id: str) -> str:
|
||||
store = _open_store(run_id)
|
||||
# Check if the artifact.md file exists
|
||||
if not store._run_dir.joinpath("artifact.md").exists():
|
||||
raise HTTPException(status_code=404, detail="Artifact not found")
|
||||
return store._run_dir.joinpath("artifact.md").read_text()
|
||||
|
@ -22,21 +22,15 @@ from atlas_librarian.stores.workers import WORKERS
|
||||
router = APIRouter(tags=["workers"])
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# response schema #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
class Order(BaseModel):
|
||||
order_id: str
|
||||
worker_name: str
|
||||
payload: dict
|
||||
|
||||
# Job is accepted and will be moved in the Job Pool
|
||||
def accept(self):
|
||||
_ORDER_POOL[self.order_id] = self
|
||||
return self.order_id
|
||||
|
||||
# Job is completed and will be removed from the Job Pool
|
||||
def complete(self):
|
||||
del _ORDER_POOL[self.order_id]
|
||||
|
||||
@ -44,9 +38,6 @@ class Order(BaseModel):
|
||||
_ORDER_POOL: dict[str, Order] = {}
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _get_worker_or_404(name: str):
|
||||
cls = WORKERS.get(name)
|
||||
if cls is None:
|
||||
@ -60,13 +51,8 @@ def _get_artifact_from_payload(
|
||||
run_id = payload.get("run_id") or None
|
||||
input_data = cls.input_model.model_validate(payload["data"])
|
||||
|
||||
# Making sure the payload is valid
|
||||
return FlowArtifact.new(data=input_data, dir=dir, run_id=run_id)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# GET Routes #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
class WorkerMeta(BaseModel):
|
||||
name: str
|
||||
input: str
|
||||
@ -89,11 +75,6 @@ def list_orders() -> list[Order]:
|
||||
return list(_ORDER_POOL.values())
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# POST Routes #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
# ---------- Submit and get the result ----------------------------------------
|
||||
@router.post("/{worker_name}/submit", response_model=FlowArtifact, status_code=202)
|
||||
def submit_worker(worker_name: str, payload: dict[str, Any]) -> FlowArtifact:
|
||||
cls = _get_worker_or_404(worker_name)
|
||||
@ -108,7 +89,6 @@ def submit_worker(worker_name: str, payload: dict[str, Any]) -> FlowArtifact:
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# Submit on existing run ----------------------------------------------
|
||||
@router.post(
|
||||
"/{worker_name}/submit/{prev_run_id}/chain",
|
||||
response_model=FlowArtifact,
|
||||
@ -136,7 +116,6 @@ def submit_chain(worker_name: str, prev_run_id: str) -> FlowArtifact:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# Submit and chain, with the latest output of a worker
|
||||
@router.post(
|
||||
"/{worker_name}/submit/{prev_worker_name}/chain/latest",
|
||||
response_model=FlowArtifact | None,
|
||||
@ -162,14 +141,12 @@ def submit_chain_latest(worker_name: str, prev_worker_name: str) -> FlowArtifact
|
||||
return cls.submit(art)
|
||||
|
||||
|
||||
# ---------- Place an Order and get a receipt ----------------------------------------------------
|
||||
@router.post("/{worker_name}/order", response_model=Order, status_code=202)
|
||||
def place_order(worker_name: str, payload: dict[str, Any]) -> Order:
|
||||
cls = _get_worker_or_404(worker_name)
|
||||
|
||||
try:
|
||||
art = _get_artifact_from_payload(payload, cls)
|
||||
order_id = str(uuid.uuid4())
|
||||
order = Order(order_id=order_id, worker_name=worker_name, payload=art.model_dump())
|
||||
order.accept()
|
||||
|
@ -1,4 +1,3 @@
|
||||
# atlas_librarian/app.py
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
@ -8,7 +7,6 @@ from librarian_core.utils.secrets_loader import load_env
|
||||
from atlas_librarian.api import recipes_router, runs_router, worker_router
|
||||
from atlas_librarian.stores import discover_workers
|
||||
|
||||
# Application description for OpenAPI docs
|
||||
APP_DESCRIPTION = """
|
||||
Atlas Librarian API Gateway 🚀
|
||||
|
||||
@ -32,11 +30,10 @@ def create_app() -> FastAPI:
|
||||
|
||||
app = FastAPI(
|
||||
title="Atlas Librarian API",
|
||||
version="0.1.0", # Use semantic versioning
|
||||
version="0.1.0",
|
||||
description=APP_DESCRIPTION,
|
||||
)
|
||||
|
||||
# Configure CORS
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@ -45,23 +42,18 @@ def create_app() -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include all API routers
|
||||
app.include_router(worker_router, prefix=f"{API_PREFIX}/worker")
|
||||
app.include_router(runs_router, prefix=f"{API_PREFIX}/runs")
|
||||
app.include_router(recipes_router, prefix=f"{API_PREFIX}/recipes")
|
||||
|
||||
return app
|
||||
|
||||
# Create the app instance
|
||||
app = create_app()
|
||||
|
||||
|
||||
@app.get("/", tags=["Root"], summary="API Root/Health Check")
|
||||
async def read_root():
|
||||
"""
|
||||
Provides a basic health check endpoint.
|
||||
Returns a welcome message indicating the API is running.
|
||||
"""
|
||||
"""Basic health check endpoint."""
|
||||
return {"message": "Welcome to Atlas Librarian API"}
|
||||
|
||||
|
||||
|
@@ -71,14 +71,12 @@ recipes = [
            # "download",
            "extract",
        ],
        # Default-Parameters
        "parameters": {
            "crawl": example_study_program,
            "download": example_moodle_index,
            "extract": example_downloaded_courses,
        },
    },
    # All steps in one go
    {
        "name": "quick-all",
        "steps": ["crawl", "download", "extract"],
@ -1,4 +1,3 @@
|
||||
# atlas_librarian/stores/worker_store.py
|
||||
"""
|
||||
Auto-discovers every third-party Worker package that exposes an entry-point
|
||||
|
||||
@ -26,32 +25,24 @@ except ImportError: # Py < 3.10 → fall back to back-port
|
||||
|
||||
from librarian_core.workers.base import Worker
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
WORKERS: Dict[str, Type[Worker]] = {} # key = Worker.worker_name
|
||||
WORKERS: Dict[str, Type[Worker]] = {}
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
def _register_worker_class(obj: object) -> None:
|
||||
"""
|
||||
Inspect *obj* and register it if it looks like a Worker subclass
|
||||
produced by the metaclass in *librarian_core*.
|
||||
"""
|
||||
"""Registers valid Worker subclasses."""
|
||||
if (
|
||||
inspect.isclass(obj)
|
||||
and issubclass(obj, Worker)
|
||||
and obj is not Worker # not the abstract base
|
||||
and obj is not Worker
|
||||
):
|
||||
WORKERS[obj.worker_name] = obj # type: ignore[arg-type]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
def _import_ep(ep: EntryPoint):
|
||||
"""Load the object referenced by an entry-point."""
|
||||
return ep.load()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
def discover_workers(group: str = "librarian.worker") -> None:
|
||||
"""
|
||||
Discover all entry-points of *group* and populate ``WORKERS``.
|
||||
@ -70,13 +61,11 @@ def discover_workers(group: str = "librarian.worker") -> None:
|
||||
print(f"[Worker-Loader] Failed to load entry-point {ep!r}: {exc}")
|
||||
continue
|
||||
|
||||
# If a module was loaded, inspect its attributes; else try registering directly
|
||||
if isinstance(loaded, ModuleType):
|
||||
for attr in loaded.__dict__.values():
|
||||
_register_worker_class(attr)
|
||||
else:
|
||||
_register_worker_class(loaded)
|
||||
|
||||
# Register any Worker subclasses imported directly (e.g., loaded via atlas_librarian/api/__init__)
|
||||
for cls in Worker.__subclasses__():
|
||||
_register_worker_class(cls)
|
||||
|
@ -27,9 +27,6 @@ from librarian_core.utils import path_utils
|
||||
class WorkerStore:
|
||||
"""Never exposed to worker code – all access is via helper methods."""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# constructors #
|
||||
# ------------------------------------------------------------------ #
|
||||
@classmethod
|
||||
def new(cls, *, worker_name: str, flow_id: str) -> "WorkerStore":
|
||||
run_dir = path_utils.get_run_dir(worker_name, flow_id, create=True)
|
||||
@ -53,9 +50,6 @@ class WorkerStore:
|
||||
return cls(candidate, meta["worker_name"], run_id)
|
||||
raise FileNotFoundError(run_id)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# life-cycle #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __init__(self, run_dir: Path, worker_name: str, flow_id: str):
|
||||
self._run_dir = run_dir
|
||||
self._worker_name = worker_name
|
||||
@ -71,9 +65,6 @@ class WorkerStore:
|
||||
self._entry_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._exit_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# entry / exit handling #
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def entry_dir(self) -> Path:
|
||||
return self._entry_dir
|
||||
@ -115,9 +106,6 @@ class WorkerStore:
|
||||
shutil.copy2(src_path, dst)
|
||||
return dst
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# result persistence #
|
||||
# ------------------------------------------------------------------ #
|
||||
def save_model(
|
||||
self,
|
||||
model: BaseModel,
|
||||
@ -145,9 +133,6 @@ class WorkerStore:
|
||||
def cleanup(self) -> None:
|
||||
shutil.rmtree(self._work_dir, ignore_errors=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# public helpers (API needs these) #
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def data_dir(self) -> Path:
|
||||
return self._run_dir / "data"
|
||||
@ -201,16 +186,12 @@ class WorkerStore:
|
||||
|
||||
latest_run_dir = sorted_runs[-1][1]
|
||||
|
||||
# Load the model
|
||||
return { # That is a FlowArtifact
|
||||
return {
|
||||
"run_id": latest_run_dir.name,
|
||||
"dir": latest_run_dir / "data",
|
||||
"data": WorkerStore.open(latest_run_dir.name).load_model(as_dict=True), # type: ignore
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internals #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _write_meta(self, *, state: str) -> None:
|
||||
meta = {
|
||||
"worker_name": self._worker_name,
|
||||
@ -233,9 +214,6 @@ class WorkerStore:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# clean-up #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
shutil.rmtree(self._work_dir, ignore_errors=True)
|
||||
|
@ -45,7 +45,6 @@ class SupabaseGateway:
|
||||
self.client = get_client()
|
||||
self.schema = _cfg.db_schema if _cfg else "library"
|
||||
|
||||
# ---------- internal ----------
|
||||
def _rpc(self, fn: str, payload: Dict[str, Any] | None = None):
|
||||
resp = (
|
||||
self.client.schema(self.schema)
|
||||
|
@ -15,16 +15,10 @@ from prefect.artifacts import acreate_markdown_artifact
|
||||
|
||||
from librarian_core.storage.worker_store import WorkerStore
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# type parameters #
|
||||
# --------------------------------------------------------------------------- #
|
||||
InT = TypeVar("InT", bound=BaseModel)
|
||||
OutT = TypeVar("OutT", bound=BaseModel)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# envelope returned by every worker flow #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class FlowArtifact(BaseModel, Generic[OutT]):
|
||||
run_id: str | None = None
|
||||
dir: Path | None = None
|
||||
@ -34,23 +28,17 @@ class FlowArtifact(BaseModel, Generic[OutT]):
|
||||
def new(cls, run_id: str | None = None, dir: Path | None = None, data: OutT | None = None) -> FlowArtifact:
|
||||
if not data:
|
||||
raise ValueError("data is required")
|
||||
# Intermediate Worker
|
||||
if run_id and dir:
|
||||
return FlowArtifact(run_id=run_id, dir=dir, data=data)
|
||||
|
||||
# Initial Worker
|
||||
else:
|
||||
return FlowArtifact(data=data)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# metaclass: adds a Prefect flow + envelope to each Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class _WorkerMeta(type):
|
||||
def __new__(mcls, name, bases, ns, **kw):
|
||||
cls = super().__new__(mcls, name, bases, dict(ns))
|
||||
|
||||
if name == "Worker" and cls.__module__ == __name__:
|
||||
return cls # abstract base
|
||||
return cls
|
||||
|
||||
if not (hasattr(cls, "input_model") and hasattr(cls, "output_model")):
|
||||
raise TypeError(f"{name}: declare 'input_model' / 'output_model'.")
|
||||
@ -62,10 +50,9 @@ class _WorkerMeta(type):
|
||||
cls._prefect_flow = mcls._build_prefect_flow(cls) # type: ignore
|
||||
return cls
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
@staticmethod
|
||||
def _build_prefect_flow(cls_ref):
|
||||
"""Create the Prefect flow and return it."""
|
||||
"""Builds the Prefect flow for the worker."""
|
||||
InArt = cls_ref.input_artifact # noqa: F841
|
||||
OutModel: type[BaseModel] = cls_ref.output_model # noqa: F841
|
||||
worker_name: str = cls_ref.worker_name
|
||||
@ -82,9 +69,7 @@ class _WorkerMeta(type):
|
||||
|
||||
inst = cls_ref()
|
||||
inst._inject_store(store)
|
||||
# run worker ------------------------------------------------
|
||||
run_res = inst.__run__(in_art.data)
|
||||
# allow sync or async implementations
|
||||
if inspect.iscoroutine(run_res):
|
||||
result = await run_res
|
||||
else:
|
||||
@ -104,7 +89,6 @@ class _WorkerMeta(type):
|
||||
description=f"{worker_name} output"
|
||||
)
|
||||
|
||||
# save the markdown artifact in the flow directory
|
||||
md_file = store._run_dir / "artifact.md"
|
||||
md_file.write_text(md_table)
|
||||
|
||||
@ -112,9 +96,8 @@ class _WorkerMeta(type):
|
||||
|
||||
return flow(name=worker_name, log_prints=True)(_core)
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
def _create_input_artifact(cls):
|
||||
"""Create & attach a pydantic model ‹InputArtifact› = {dir?, data}."""
|
||||
"""Creates the `InputArtifact` model for the worker."""
|
||||
DirField = (Path | None, None)
|
||||
DataField = (cls.input_model, ...) # type: ignore # required
|
||||
art_name = f"{cls.__name__}InputArtifact"
|
||||
@ -124,35 +107,27 @@ class _WorkerMeta(type):
|
||||
cls.input_artifact = artifact # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# public Worker base #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Worker(Generic[InT, OutT], metaclass=_WorkerMeta):
|
||||
"""
|
||||
Derive from this class, set *input_model* / *output_model*, and implement
|
||||
an **async** ``__run__(payload: input_model)``.
|
||||
Base class for workers. Subclasses must define `input_model`,
|
||||
`output_model`, and implement `__run__`.
|
||||
"""
|
||||
|
||||
input_model: ClassVar[type[BaseModel]]
|
||||
output_model: ClassVar[type[BaseModel]]
|
||||
input_artifact: ClassVar[type[BaseModel]] # injected by metaclass
|
||||
input_artifact: ClassVar[type[BaseModel]]
|
||||
worker_name: ClassVar[str]
|
||||
_prefect_flow: ClassVar[Callable[[FlowArtifact[InT]], Awaitable[FlowArtifact[OutT]]]]
|
||||
|
||||
# injected at runtime
|
||||
entry: Path
|
||||
_store: WorkerStore
|
||||
entry: Path # The entry directory for the worker's temporary files
|
||||
_store: WorkerStore # The WorkerStore instance for this run
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal wiring #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _inject_store(self, store: WorkerStore) -> None:
|
||||
"""Injects WorkerStore and sets entry directory."""
|
||||
self._store = store
|
||||
self.entry = store.entry_dir
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# developer helper #
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helper method for staging files
|
||||
def stage(
|
||||
self,
|
||||
src: Path | str,
|
||||
@ -163,30 +138,26 @@ class Worker(Generic[InT, OutT], metaclass=_WorkerMeta):
|
||||
) -> Path:
|
||||
return self._store.stage(src, new_name=new_name, sanitize=sanitize, move=move)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# convenience wrappers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@classmethod
|
||||
def flow(cls):
|
||||
"""Return the auto-generated Prefect flow."""
|
||||
"""Returns the Prefect flow."""
|
||||
return cls._prefect_flow
|
||||
|
||||
# submit variants --------------------------------------------------- #
|
||||
@classmethod
|
||||
def submit(cls, payload: FlowArtifact[InT]) -> FlowArtifact[OutT]:
|
||||
"""Submits payload to the Prefect flow."""
|
||||
async def _runner():
|
||||
art = await cls._prefect_flow(payload) # type: ignore[arg-type]
|
||||
return art
|
||||
|
||||
return anyio.run(_runner)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# abstract #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def __run__(self, payload: InT) -> OutT: ...
|
||||
async def __run__(self, payload: InT) -> OutT:
|
||||
"""Core logic, implemented by subclasses."""
|
||||
...
|
||||
|
||||
|
||||
# Should be overridden by the worker
|
||||
async def _to_markdown(self, data: OutT) -> str:
|
||||
"""Converts output to Markdown. Override for custom format."""
|
||||
md_table = pd.DataFrame([data.dict()]).to_markdown(index=False)
|
||||
return md_table
|
||||
|
@ -34,18 +34,15 @@ class Chunker(Worker[ExtractData, ChunkData]):
|
||||
|
||||
working_dir = get_temp_path()
|
||||
|
||||
# load NLP and tokenizer
|
||||
Chunker.nlp = spacy.load("xx_ent_wiki_sm")
|
||||
Chunker.nlp.add_pipe("sentencizer")
|
||||
Chunker.enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# chunk parameters
|
||||
Chunker.max_tokens = MAX_TOKENS
|
||||
Chunker.overlap_tokens = OVERLAP_TOKENS
|
||||
|
||||
result = ChunkData(terms=[])
|
||||
|
||||
# Loading files
|
||||
for term in payload.terms:
|
||||
chunked_term = ChunkedTerm(id=term.id, name=term.name)
|
||||
in_term_dir = self.entry / term.name
|
||||
@ -79,7 +76,6 @@ class Chunker(Worker[ExtractData, ChunkData]):
|
||||
|
||||
chunked_term.courses.append(chunked_course)
|
||||
|
||||
# Add the chunked term to the result
|
||||
result.terms.append(chunked_term)
|
||||
self.stage(out_term_dir)
|
||||
|
||||
@ -94,16 +90,11 @@ class Chunker(Worker[ExtractData, ChunkData]):
|
||||
lg.info(f"Chapter path: {chapter_path}")
|
||||
lg.info(f"Out course dir: {out_course_dir}")
|
||||
|
||||
# Extract the Text
|
||||
file_text = Chunker._extract_text(chapter_path / f.name)
|
||||
|
||||
# Chunk the Text
|
||||
chunks = Chunker._chunk_text(file_text, f.name, out_course_dir)
|
||||
|
||||
images_dir = out_course_dir / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extract the Images
|
||||
images = Chunker._extract_images(chapter_path / f.name, images_dir)
|
||||
|
||||
return chunks, images
|
||||
@ -128,28 +119,22 @@ class Chunker(Worker[ExtractData, ChunkData]):
|
||||
def _chunk_text(text: str, f_name: str, out_course_dir: Path) -> list[Chunk]:
|
||||
lg = get_run_logger()
|
||||
lg.info(f"Chunking text for file {f_name}")
|
||||
# split text into sentences and get tokens
|
||||
nlp_doc = Chunker.nlp(text)
|
||||
sentences = [sent.text.strip() for sent in nlp_doc.sents]
|
||||
sentence_token_counts = [len(Chunker.enc.encode(s)) for s in sentences]
|
||||
lg.info(f"Extracted {len(sentences)} sentences with token counts: {sentence_token_counts}")
|
||||
|
||||
# Buffers
|
||||
chunks: list[Chunk] = []
|
||||
current_chunk = []
|
||||
current_token_total = 0
|
||||
|
||||
chunk_id = 0
|
||||
|
||||
for s, tc in zip(sentences, sentence_token_counts): # Pair sentences and tokens
|
||||
if tc + current_token_total <= MAX_TOKENS: # Check Token limit
|
||||
# Add sentences to chunk
|
||||
for s, tc in zip(sentences, sentence_token_counts):
|
||||
if tc + current_token_total <= MAX_TOKENS:
|
||||
current_chunk.append(s)
|
||||
current_token_total += tc
|
||||
else:
|
||||
# Flush Chunk
|
||||
chunk_text = "\n\n".join(current_chunk)
|
||||
|
||||
chunk_name = f"{f_name}_{chunk_id}"
|
||||
with open(
|
||||
out_course_dir / f"{chunk_name}.md", "w", encoding="utf-8"
|
||||
@ -165,14 +150,12 @@ class Chunker(Worker[ExtractData, ChunkData]):
|
||||
)
|
||||
)
|
||||
|
||||
# Get Overlap from Chunk
|
||||
token_ids = Chunker.enc.encode(chunk_text)
|
||||
overlap_ids = token_ids[-OVERLAP_TOKENS :]
|
||||
overlap_text = Chunker.enc.decode(overlap_ids)
|
||||
overlap_doc = Chunker.nlp(overlap_text)
|
||||
overlap_sents = [sent.text for sent in overlap_doc.sents]
|
||||
|
||||
# Start new Chunk
|
||||
current_chunk = overlap_sents + [s]
|
||||
current_token_total = sum(
|
||||
len(Chunker.enc.encode(s)) for s in current_chunk
|
||||
|
@ -2,10 +2,6 @@ from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Output models #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
id: str
|
||||
|
@ -32,9 +32,6 @@ from librarian_extractor.models.extract_data import (
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _clean_json(txt: str) -> str:
|
||||
txt = txt.strip()
|
||||
if txt.startswith("```"):
|
||||
@ -51,7 +48,7 @@ def _safe_json_load(txt: str) -> dict:
|
||||
|
||||
|
||||
def _merge_with_original(src: ExtractedCourse, patch: dict, lg) -> ExtractedCourse:
|
||||
"""Return *patch* merged with *src* so every id is preserved."""
|
||||
"""Merges LLM patch with source, preserving IDs."""
|
||||
try:
|
||||
tgt = ExtractedCourse.model_validate(patch)
|
||||
except ValidationError as err:
|
||||
@ -73,9 +70,6 @@ def _merge_with_original(src: ExtractedCourse, patch: dict, lg) -> ExtractedCour
|
||||
return tgt
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# OpenAI call – Prefect task #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@task(
|
||||
name="sanitize_course_json",
|
||||
retries=2,
|
||||
@ -100,9 +94,6 @@ def sanitize_course_json(course_json: str, model: str, temperature: float) -> di
|
||||
return _safe_json_load(rsp.choices[0].message.content or "{}")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
input_model = ExtractData
|
||||
output_model = ExtractData
|
||||
@ -112,14 +103,12 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
self.model_name = model_name or os.getenv("OPENAI_MODEL", "gpt-4o-mini")
|
||||
self.temperature = temperature
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
def __run__(self, data: ExtractData) -> ExtractData:
|
||||
lg = get_run_logger()
|
||||
|
||||
futures: List[PrefectFuture] = []
|
||||
originals: List[ExtractedCourse] = []
|
||||
|
||||
# 1) submit all courses to the LLM
|
||||
for term in data.terms:
|
||||
for course in term.courses:
|
||||
futures.append(
|
||||
@ -133,7 +122,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
|
||||
wait(futures)
|
||||
|
||||
# 2) build new graph with merged results
|
||||
terms_out: List[ExtractedTerm] = []
|
||||
idx = 0
|
||||
for term in data.terms:
|
||||
@ -149,16 +137,12 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
|
||||
renamed = ExtractData(terms=terms_out)
|
||||
|
||||
# 3) stage files with their new names
|
||||
self._export_with_new_names(data, renamed, lg)
|
||||
|
||||
return renamed
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# staging helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _stage_or_warn(self, src: Path, dst: Path, lg):
|
||||
"""Copy *src* → *dst* (via self.stage). Warn if src missing."""
|
||||
"""Stages file, warns if source missing."""
|
||||
if not src.exists():
|
||||
lg.warning("Source missing – skipped %s", src)
|
||||
return
|
||||
@ -175,7 +159,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
|
||||
for term_old, term_new in zip(original.terms, renamed.terms):
|
||||
for course_old, course_new in zip(term_old.courses, term_new.courses):
|
||||
# ---------- content files (per chapter) -----------------
|
||||
for chap_old, chap_new in zip(course_old.chapters, course_new.chapters):
|
||||
n = min(len(chap_old.content_files), len(chap_new.content_files))
|
||||
for i in range(n):
|
||||
@ -196,7 +179,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
)
|
||||
self._stage_or_warn(src, dst, lg)
|
||||
|
||||
# ---------- media files (course-level “media”) ----------
|
||||
src_media_dir = (
|
||||
entry / term_old.name / course_old.name / "media"
|
||||
) # <─ fixed!
|
||||
@ -204,7 +186,6 @@ class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
if not src_media_dir.is_dir():
|
||||
continue
|
||||
|
||||
# build a flat list of (old, new) media filenames
|
||||
media_pairs: List[tuple[ExtractedFile, ExtractedFile]] = []
|
||||
for ch_o, ch_n in zip(course_old.chapters, course_new.chapters):
|
||||
media_pairs.extend(zip(ch_o.media_files, ch_n.media_files))
|
||||
|
@ -2,9 +2,6 @@
|
||||
Shared lists and prompts
|
||||
"""
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# file selection – keep only real documents we can show / convert #
|
||||
# -------------------------------------------------------------------- #
|
||||
CONTENT_FILE_EXTENSIONS = [
|
||||
"*.pdf",
|
||||
"*.doc",
|
||||
@ -26,9 +23,6 @@ MEDIA_FILE_EXTENSIONS = [
|
||||
"*.mp3",
|
||||
]
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# naming rules #
|
||||
# -------------------------------------------------------------------- #
|
||||
SANITIZE_REGEX = {
|
||||
"base": [r"\s*\(\d+\)$"],
|
||||
"course": [
|
||||
|
@ -51,9 +51,6 @@ ALL_EXTS = CONTENT_EXTS | MEDIA_EXTS
|
||||
_id_rx = re.compile(r"\.(\d{4,})[./]") # 1172180 from “..._.1172180/index.html”
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _hash_id(fname: str) -> str:
|
||||
return hashlib.sha1(fname.encode()).hexdigest()[:10]
|
||||
|
||||
@ -80,17 +77,14 @@ def _best_payload(node: Path) -> Path | None: # noqa: C901
|
||||
• File_xxx/dir → search inside /content or dir itself
|
||||
• File_xxx/index.html stub → parse to find linked file
|
||||
"""
|
||||
# 1) immediate hit
|
||||
if node.is_file() and node.suffix.lower() in ALL_EXTS:
|
||||
return node
|
||||
|
||||
# 2) if html stub try to parse inner link
|
||||
if node.is_file() and node.suffix.lower() in {".html", ".htm"}:
|
||||
hinted = _html_stub_target(node)
|
||||
if hinted:
|
||||
return _best_payload(hinted) # recurse
|
||||
return _best_payload(hinted)
|
||||
|
||||
# 3) directories to search
|
||||
roots: list[Path] = []
|
||||
if node.is_dir():
|
||||
roots.append(node)
|
||||
@ -120,9 +114,6 @@ def task_(**kw):
|
||||
return task(**kw)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Extractor(Worker[DownloadData, ExtractData]):
|
||||
input_model = DownloadData
|
||||
output_model = ExtractData
|
||||
@ -158,7 +149,6 @@ class Extractor(Worker[DownloadData, ExtractData]):
|
||||
lg.info("Extractor finished – %d terms", len(result.terms))
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
@task_()
|
||||
def _extract_course( # noqa: C901
|
||||
@ -238,9 +228,6 @@ class Extractor(Worker[DownloadData, ExtractData]):
|
||||
finally:
|
||||
shutil.rmtree(tmp, ignore_errors=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def _copy_all(
|
||||
root: Path, dst_root: Path, c_meta: ExtractedCourse, media_dir: Path, lg
|
||||
|
@ -5,24 +5,24 @@ from pydantic import BaseModel, Field
|
||||
|
||||
class ExtractedFile(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the file, relative to ExtractedChapter.name
|
||||
name: str
|
||||
|
||||
|
||||
class ExtractedChapter(BaseModel):
|
||||
name: str # Name of the chapter directory, relative to ExtractedCourse.name
|
||||
name: str
|
||||
content_files: List[ExtractedFile] = Field(default_factory=list)
|
||||
media_files: List[ExtractedFile] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractedCourse(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the course directory, relative to ExtractedTerm.name
|
||||
name: str
|
||||
chapters: List[ExtractedChapter] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractedTerm(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the term directory, relative to ExtractMeta.dir
|
||||
name: str
|
||||
courses: List[ExtractedCourse] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
@ -1,7 +1,3 @@
|
||||
# -------------------------------------------------------------------- #
|
||||
# LLM prompts #
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
PROMPT_COURSE = """
|
||||
General naming rules
|
||||
====================
|
||||
|
@@ -1,8 +1,4 @@
"""
URLs used by the scraper.
Functions marked as PUBLIC can be accessed without authentication.
Functions marked as PRIVATE require authentication.
"""
"""Scraper URLs. PUBLIC/PRIVATE indicates auth requirement."""

BASE_URL = "https://moodle.fhgr.ch"
@ -12,20 +12,8 @@ from librarian_scraper.constants import PRIVATE_URLS, PUBLIC_URLS
|
||||
|
||||
|
||||
class CookieCrawler:
|
||||
"""
|
||||
Retrieve Moodle session cookies + sesskey via Playwright.
|
||||
"""Retrieve Moodle session cookies + sesskey via Playwright."""
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> crawler = CookieCrawler()
|
||||
>>> cookies, sesskey = await crawler.crawl() # inside async code
|
||||
# or
|
||||
>>> cookies, sesskey = CookieCrawler.crawl_sync() # plain scripts
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# construction #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __init__(self, *, headless: bool = True) -> None:
|
||||
self.headless = headless
|
||||
self.cookies: Optional[List[Cookie]] = None
|
||||
@ -38,13 +26,8 @@ class CookieCrawler:
|
||||
"Set MOODLE_USERNAME and MOODLE_PASSWORD as environment variables."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# public API #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def crawl(self) -> tuple[Cookies, str]:
|
||||
"""
|
||||
Async entry-point – await this inside FastAPI / Prefect etc.
|
||||
"""
|
||||
"""Async method to crawl cookies and sesskey."""
|
||||
async with async_playwright() as p:
|
||||
browser: Browser = await p.chromium.launch(headless=self.headless)
|
||||
page = await browser.new_page()
|
||||
@ -61,51 +44,34 @@ class CookieCrawler:
|
||||
|
||||
@classmethod
|
||||
def crawl_sync(cls, **kwargs) -> tuple[Cookies, str]:
|
||||
"""
|
||||
Synchronous helper for CLI / notebooks.
|
||||
|
||||
Detects whether an event loop is already running. If so, it
|
||||
schedules the coroutine and waits; otherwise it starts a fresh loop.
|
||||
"""
|
||||
"""Synchronous version of crawl. Handles event loop."""
|
||||
self = cls(**kwargs)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError: # no loop running → safe to create one
|
||||
except RuntimeError:
|
||||
return asyncio.run(self.crawl())
|
||||
|
||||
# An event loop exists – schedule coroutine
|
||||
return loop.run_until_complete(self.crawl())
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def _login(self, page: Page) -> None:
|
||||
"""Fill the SSO form and extract cookies + sesskey."""
|
||||
|
||||
# Select organisation / IdP
|
||||
await page.click("#wayf_submit_button")
|
||||
|
||||
# Wait for the credential form
|
||||
await page.wait_for_selector("form[method='post']", state="visible")
|
||||
|
||||
# Credentials
|
||||
await page.fill("input[id='username']", self.username)
|
||||
await page.fill("input[id='password']", self.password)
|
||||
await page.click("button[class='aai_login_button']")
|
||||
|
||||
# Wait for redirect to /my/ page (dashboard), this means the login is complete
|
||||
await page.wait_for_url(PRIVATE_URLS.dashboard)
|
||||
await page.wait_for_selector("body", state="attached")
|
||||
|
||||
# Navigate to personal course overview
|
||||
await page.goto(PRIVATE_URLS.user_courses)
|
||||
await page.wait_for_selector("body", state="attached")
|
||||
|
||||
# Collect session cookies
|
||||
self.cookies = await page.context.cookies()
|
||||
|
||||
# Extract sesskey from injected Moodle config
|
||||
try:
|
||||
self.sesskey = await page.evaluate(
|
||||
"() => window.M && M.cfg && M.cfg.sesskey"
|
||||
@ -119,13 +85,9 @@ class CookieCrawler:
|
||||
logging.debug("sesskey: %s", self.sesskey)
|
||||
logging.debug("cookies: %s", self.cookies)
|
||||
|
||||
# Dev convenience
|
||||
if not self.headless:
|
||||
await page.wait_for_timeout(5000)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# cookie conversion #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _to_cookiejar(self, raw: List[Cookie]) -> Cookies:
|
||||
jar = Cookies()
|
||||
for c in raw:
|
||||
|
@ -39,18 +39,12 @@ from librarian_scraper.models.crawl_data import (
|
||||
CrawlTerm,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# module-level shared items for static task #
|
||||
# --------------------------------------------------------------------------- #
|
||||
_COOKIE_JAR: httpx.Cookies | None = None
|
||||
_DELAY: float = 0.0
|
||||
|
||||
CACHE_FILE = get_cache_root() / "librarian_no_access_cache.json"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# utility #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def looks_like_enrol(resp: httpx.Response) -> bool:
|
||||
txt = resp.text.lower()
|
||||
return (
|
||||
@ -60,21 +54,14 @@ def looks_like_enrol(resp: httpx.Response) -> bool:
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# main worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
input_model = CrawlProgram
|
||||
output_model = CrawlData
|
||||
|
||||
# toggles (env overrides)
|
||||
RELAXED: bool
|
||||
USER_SPECIFIC: bool
|
||||
CLEAR_CACHE: bool
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# flow entry-point #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def __run__(self, program: CrawlProgram) -> CrawlData:
|
||||
global _COOKIE_JAR, _DELAY
|
||||
lg = get_run_logger()
|
||||
@ -93,7 +80,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
batch,
|
||||
)
|
||||
|
||||
# --------------------------- login
|
||||
cookies, _ = await CookieCrawler().crawl()
|
||||
_COOKIE_JAR = cookies
|
||||
self._client = httpx.Client(cookies=cookies, follow_redirects=True)
|
||||
@ -102,14 +88,11 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
lg.error("Guest session detected – aborting crawl.")
|
||||
raise RuntimeError("Login failed")
|
||||
|
||||
# --------------------------- cache
|
||||
no_access: set[str] = set() if self.CLEAR_CACHE else self._load_cache()
|
||||
|
||||
# --------------------------- scrape terms (first two for dev)
|
||||
terms = self._crawl_terms(program.id)[:2]
|
||||
lg.info("Terms discovered: %d", len(terms))
|
||||
|
||||
# --------------------------- scrape courses
|
||||
for term in terms:
|
||||
courses = self._crawl_courses(term.id)
|
||||
lg.info("[%s] raw courses: %d", term.name, len(courses))
|
||||
@ -137,7 +120,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
)
|
||||
lg.info("[%s] kept: %d", term.name, len(term.courses))
|
||||
|
||||
# --------------------------- persist cache
|
||||
self._save_cache(no_access)
|
||||
|
||||
return CrawlData(
|
||||
@ -148,9 +130,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# static task inside class #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
@task(
|
||||
name="crawl_course",
|
||||
@ -198,9 +177,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
|
||||
return course_id, href.split("=")[-1]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _logged_in(self) -> bool:
|
||||
html = self._get_html(PUBLIC_URLS.index)
|
||||
return not parsel.Selector(text=html).css("div.usermenu span.login a")
|
||||
@ -246,9 +222,6 @@ class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
get_run_logger().warning("GET %s failed (%s)", url, exc)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# cache helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def _load_cache() -> set[str]:
|
||||
try:
|
||||
|
@ -22,11 +22,8 @@ class IndexCrawler:
|
||||
self.debug = debug
|
||||
self.client = httpx.Client(cookies=cookies, follow_redirects=True)
|
||||
self.max_workers = max_workers
|
||||
|
||||
# When True the cached “no-access” set is ignored for this run
|
||||
self._ignore_cache: bool = False
|
||||
|
||||
# Load persisted cache of course-IDs the user cannot access
|
||||
if NO_ACCESS_CACHE_FILE.exists():
|
||||
try:
|
||||
self._no_access_cache: set[str] = set(json.loads(NO_ACCESS_CACHE_FILE.read_text()))
|
||||
@ -54,7 +51,7 @@ class IndexCrawler:
|
||||
|
||||
def crawl_index(self, userSpecific: bool = True, *, use_cache: bool = True) -> MoodleIndex:
|
||||
"""
|
||||
Build and return a `MoodleIndex`.
|
||||
Builds and returns a `MoodleIndex`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -65,21 +62,17 @@ class IndexCrawler:
|
||||
afresh. Newly discovered “no-access” courses are still written back to the
|
||||
cache at the end of the crawl.
|
||||
"""
|
||||
# Set runtime flag for has_user_access()
|
||||
self._ignore_cache = not use_cache
|
||||
|
||||
semesters = []
|
||||
# Get all courses for each semester and the courseid and name for each course.
|
||||
semesters = self.crawl_semesters()
|
||||
# Crawl only the latest two semesters to reduce load (remove once caching is implemented)
|
||||
for semester in semesters[:2]:
|
||||
courses = self.crawl_courses(semester)
|
||||
|
||||
# Crawl courses in parallel to speed things up
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as pool:
|
||||
list(pool.map(self.crawl_course, courses))
|
||||
|
||||
# Filter courses once all have been processed
|
||||
for course in courses:
|
||||
if userSpecific:
|
||||
if course.content_ressource_id:
|
||||
@ -87,8 +80,6 @@ class IndexCrawler:
|
||||
else:
|
||||
semester.courses.append(course)
|
||||
|
||||
# Only add semesters that have at least one course
|
||||
# Filter out semesters that ended up with no courses after crawling
|
||||
semesters: list[Semester] = [
|
||||
semester for semester in semesters if semester.courses
|
||||
]
|
||||
@ -100,21 +91,13 @@ class IndexCrawler:
|
||||
semesters=semesters,
|
||||
),
|
||||
)
|
||||
# Persist any newly discovered no-access courses
|
||||
self._save_no_access_cache()
|
||||
|
||||
# Restore default behaviour for subsequent calls
|
||||
self._ignore_cache = False
|
||||
|
||||
return created_index
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# High-level crawling helpers
|
||||
# --------------------------------------------------------------------- #
|
||||
def crawl_semesters(self) -> list[Semester]:
|
||||
"""
|
||||
Crawl the semesters from the Moodle index page.
|
||||
"""
|
||||
"""Crawls semester data."""
|
||||
url = URLs.get_degree_program_url(self.degree_program.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
@ -126,9 +109,7 @@ class IndexCrawler:
|
||||
return []
|
||||
|
||||
def crawl_courses(self, semester: Semester) -> list[Course]:
|
||||
"""
|
||||
Crawl the courses from the Moodle index page.
|
||||
"""
|
||||
"""Crawls course data for a semester."""
|
||||
url = URLs.get_semester_url(semester_id=semester.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
@ -140,10 +121,7 @@ class IndexCrawler:
|
||||
return []
|
||||
|
||||
def crawl_course(self, course: Course) -> None:
|
||||
"""
|
||||
Crawl a single Moodle course page.
|
||||
"""
|
||||
|
||||
"""Crawls details for a single course."""
|
||||
hasAccess = self.has_user_access(course)
|
||||
|
||||
if not hasAccess:
|
||||
@ -154,13 +132,8 @@ class IndexCrawler:
|
||||
course.content_ressource_id = self.crawl_content_ressource_id(course)
|
||||
course.files = self.crawl_course_files(course)
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Networking utilities
|
||||
# --------------------------------------------------------------------- #
|
||||
def get_with_retries(self, url: str, retries: int = 3, delay: int = 1) -> httpx.Response:
|
||||
"""
|
||||
Simple GET with retries and exponential back-off.
|
||||
"""
|
||||
"""Simple GET with retries and exponential back-off."""
|
||||
for attempt in range(1, retries + 1):
|
||||
try:
|
||||
response = self.client.get(url)
|
||||
@ -183,17 +156,11 @@ class IndexCrawler:
|
||||
f.write(response.text)
|
||||
logging.info(f"Saved HTML to {filename}")
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Extractors
|
||||
# --------------------------------------------------------------------- #
|
||||
def extract_semesters(self, html: str) -> list[Semester]:
|
||||
"""Extracts semester names and IDs from HTML."""
|
||||
selector = parsel.Selector(text=html)
|
||||
|
||||
logging.info("Extracting semesters from the HTML content.")
|
||||
|
||||
semesters: list[Semester] = []
|
||||
|
||||
# Each semester sits in a collapsed container
|
||||
semester_containers = selector.css("div.category.notloaded.with_children.collapsed")
|
||||
|
||||
for container in semester_containers:
|
||||
@ -207,7 +174,6 @@ class IndexCrawler:
|
||||
)
|
||||
semester_id = anchor.attrib.get("href", "").split("=")[-1]
|
||||
|
||||
# Only keep semesters labeled FS or HS
|
||||
if "FS" not in semester_name and "HS" not in semester_name:
|
||||
continue
|
||||
|
||||
@ -250,12 +216,9 @@ class IndexCrawler:
|
||||
)
|
||||
course_id = anchor.attrib.get("href", "").split("=")[-1]
|
||||
|
||||
# Remove trailing semester tag and code patterns
|
||||
course_name = re.sub(r"\s*(FS|HS)\d{2}\s*", "", course_name)
|
||||
course_name = re.sub(r"\s*\(.*?\)\s*", "", course_name).strip()
|
||||
|
||||
# Try to locate a hero/overview image that belongs to this course box
|
||||
# Traverse up to the containing course box, then look for <div class="courseimage"><img ...>
|
||||
course_container = header.xpath('./ancestor::*[contains(@class,"coursebox")][1]')
|
||||
hero_src = (
|
||||
course_container.css("div.courseimage img::attr(src)").get("")
|
||||
@ -275,10 +238,7 @@ class IndexCrawler:
|
||||
return courses
|
||||
|
||||
def has_user_access(self, course: Course) -> bool:
|
||||
"""
|
||||
Return True only if the authenticated user can access the course (result cached).
|
||||
(i.e. the response is HTTP 200 **and** is not a redirected login/enrol page).
|
||||
"""
|
||||
"""Checks if user can access course (caches negative results)."""
|
||||
if not self._ignore_cache and course.id in self._no_access_cache:
|
||||
return False
|
||||
|
||||
@ -289,21 +249,19 @@ class IndexCrawler:
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
# Detect Moodle redirection to a login or enrolment page
|
||||
final_url = str(res.url).lower()
|
||||
if "login" in final_url or "enrol" in final_url:
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
# Some enrolment pages still return 200; look for HTML markers
|
||||
if "#page-enrol" in res.text or "you need to enrol" in res.text.lower():
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
# If we got here the user has access; otherwise cache the deny
|
||||
return True
|
||||
|
||||
def crawl_content_ressource_id(self, course: Course) -> str:
|
||||
"""Extracts content resource ID for a course."""
|
||||
course_id = course.id
|
||||
url = URLs.get_course_url(course_id)
|
||||
res = self.get_with_retries(url)
|
||||
@ -311,12 +269,10 @@ class IndexCrawler:
|
||||
|
||||
try:
|
||||
logging.info("Searching for 'Download course content' link.")
|
||||
# Use parsel CSS selector to find the anchor tag with the specific data attribute
|
||||
download_link_selector = psl.css('a[data-downloadcourse="1"]')
|
||||
if not download_link_selector:
|
||||
raise ValueError("Download link not found.")
|
||||
|
||||
# Extract the href attribute from the first matching element
|
||||
href = download_link_selector[0].attrib.get("href")
|
||||
if not href:
|
||||
raise ValueError("Href attribute not found on the download link.")
|
||||
@ -334,9 +290,7 @@ class IndexCrawler:
|
||||
return ''
|
||||
|
||||
def crawl_course_files(self, course: Course) -> list[FileEntry]:
|
||||
"""
|
||||
Crawl the course files from the Moodle course page.
|
||||
"""
|
||||
"""Crawls file entries for a course."""
|
||||
url = URLs.get_course_url(course.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
@ -347,10 +301,8 @@ class IndexCrawler:
|
||||
|
||||
return []
|
||||
|
||||
# ----------------------------------------------------------------- #
|
||||
# Cache persistence helpers
|
||||
# ----------------------------------------------------------------- #
|
||||
def _save_no_access_cache(self) -> None:
|
||||
"""Saves course IDs the user cannot access to a cache file."""
|
||||
try:
|
||||
NO_ACCESS_CACHE_FILE.write_text(json.dumps(sorted(self._no_access_cache)))
|
||||
except Exception as exc:
|
||||
|
@ -1,9 +1,5 @@
|
||||
# TODO: Move to librarian-core
|
||||
"""
|
||||
All URLs used in the crawler.
|
||||
Functions marked as PUBLIC can be accessed without authentication.
|
||||
Functions marked as PRIVATE require authentication.
|
||||
"""
|
||||
"""Moodle URLs. PUBLIC/PRIVATE indicates auth requirement."""
|
||||
class URLs:
|
||||
base_url = "https://moodle.fhgr.ch"
|
||||
|
||||
@ -12,7 +8,6 @@ class URLs:
|
||||
"""PUBLIC"""
|
||||
return cls.base_url
|
||||
|
||||
# ------------------------- Moodle URLs -------------------------
|
||||
@classmethod
|
||||
def get_login_url(cls):
|
||||
"""PUBLIC"""
|
||||
|
@ -35,9 +35,6 @@ from librarian_scraper.models.download_data import (
)


# --------------------------------------------------------------------------- #
# helper decorator #
# --------------------------------------------------------------------------- #
def task_(**kw):
kw.setdefault("log_prints", True)
kw.setdefault("retries", 2)
@ -45,9 +42,6 @@ def task_(**kw):
return task(**kw)


# --------------------------------------------------------------------------- #
# shared state for static task #
# --------------------------------------------------------------------------- #
_COOKIE_JAR: httpx.Cookies | None = None
_SESSKEY: str = ""
_LIMIT: int = 2
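`task_` simply pre-seeds keyword defaults before handing off to Prefect's `task` decorator. A short sketch of how it reads in full and how it is applied (the `retry_delay_seconds` default is an assumption, since the hunk elides a line):

```python
from prefect import task


def task_(**kw):
    # Project-wide defaults, then delegate to Prefect's decorator.
    kw.setdefault("log_prints", True)
    kw.setdefault("retries", 2)
    kw.setdefault("retry_delay_seconds", 5)  # assumed; not visible in the hunk
    return task(**kw)


@task_()
def download_course(context_id: str) -> None:
    print(f"downloading {context_id}")
```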
@ -57,27 +51,22 @@ _DELAY: float = 0.0
class Downloader(Worker[CrawlData, DownloadData]):
DOWNLOAD_URL = "https://moodle.fhgr.ch/course/downloadcontent.php"

# tuning
CONCURRENCY = 8
RELAXED = True # False → faster
RELAXED = True

input_model = CrawlData
output_model = DownloadData

# ------------------------------------------------------------------ #
async def __run__(self, crawl: CrawlData) -> DownloadData:
global _COOKIE_JAR, _SESSKEY, _LIMIT, _DELAY
lg = get_run_logger()

# ------------ login
cookies, sesskey = await CookieCrawler().crawl()
_COOKIE_JAR, _SESSKEY = cookies, sesskey

# ------------ tuning
_LIMIT = 1 if self.RELAXED else max(1, min(self.CONCURRENCY, 8))
_DELAY = CRAWLER["DELAY_SLOW"] if self.RELAXED else CRAWLER["DELAY_FAST"]

# ------------ working dir
work_root = Path(get_temp_path()) / f"dl_{int(time.time())}"
work_root.mkdir(parents=True, exist_ok=True)

@ -85,7 +74,6 @@ class Downloader(Worker[CrawlData, DownloadData]):
futures = []
term_dirs: List[Tuple[str, Path]] = []

# schedule downloads
for term in crawl.degree_program.terms:
term_dir = work_root / term.name
term_dir.mkdir(parents=True, exist_ok=True)
@ -101,18 +89,14 @@ class Downloader(Worker[CrawlData, DownloadData]):
self._download_task.submit(course.content_ressource_id, dest)
)

wait(futures) # block for all downloads
wait(futures)

# stage term directories
for name, dir_path in term_dirs:
self.stage(dir_path, new_name=name, sanitize=False, move=True)

lg.info("Downloader finished – staged %d term folders", len(term_dirs))
return result

# ------------------------------------------------------------------ #
# static task #
# ------------------------------------------------------------------ #
@staticmethod
@task_()
def _download_task(context_id: str, dest: Path) -> None:
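The body of `_download_task` is cut off above. A hedged, standalone sketch of what such a task typically does, namely streaming the course-content export to disk with the session cookies and sesskey (the form field names are assumptions; only the endpoint URL appears in the diff):

```python
import httpx
from pathlib import Path

DOWNLOAD_URL = "https://moodle.fhgr.ch/course/downloadcontent.php"


def download_course_content(context_id: str, dest: Path,
                            cookies: httpx.Cookies, sesskey: str) -> None:
    """Sketch: request the course-content export and stream the archive to dest."""
    # Field names are assumptions - the real payload is not part of this diff.
    data = {"contextid": context_id, "sesskey": sesskey, "download": 1}
    with httpx.stream("POST", DOWNLOAD_URL, data=data, cookies=cookies,
                      follow_redirects=True, timeout=120) as res:
        res.raise_for_status()
        with dest.open("wb") as fh:
            for chunk in res.iter_bytes():
                fh.write(chunk)
```

Streaming via `iter_bytes()` keeps memory flat even for large course archives.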
@ -135,9 +135,6 @@ MoodleIndex: {
"""


# ---------------------------------------------------------------------------
# Base Model
# ---------------------------------------------------------------------------
class CrawlData(BaseModel):
degree_program: CrawlProgram = Field(
default_factory=lambda: CrawlProgram(id="", name="")
@ -147,18 +144,12 @@ class CrawlData(BaseModel):
)


# ---------------------------------------------------------------------------
# Degree Program
# ---------------------------------------------------------------------------
class CrawlProgram(BaseModel):
id: str = Field("1157", description="Unique identifier for the degree program.")
name: str = Field("Computational and Data Science", description="Name of the degree program.")
terms: list[CrawlTerm] = Field(default_factory=list)


# ---------------------------------------------------------------------------
# Term
# ---------------------------------------------------------------------------
_TERM_RE = re.compile(r"^(HS|FS)\d{2}$") # HS24 / FS25 …


@ -168,9 +159,6 @@ class CrawlTerm(BaseModel):
courses: list[CrawlCourse] = Field(default_factory=list)


# ---------------------------------------------------------------------------
# Course
# ---------------------------------------------------------------------------
class CrawlCourse(BaseModel):
id: str
name: str
@ -179,9 +167,6 @@ class CrawlCourse(BaseModel):
files: list[CrawlFile] = Field(default_factory=list)


# ---------------------------------------------------------------------------
# Files
# ---------------------------------------------------------------------------
class CrawlFile(BaseModel):
id: str
res_id: str
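The `CrawlTerm` body is elided above, so here is a small sketch of how the `_TERM_RE` pattern is typically enforced on the term name, assuming Pydantic v2 (the validator itself is an assumption; only the regex appears in the diff):

```python
import re
from pydantic import BaseModel, Field, field_validator

_TERM_RE = re.compile(r"^(HS|FS)\d{2}$")  # HS24 / FS25 …


class CrawlTerm(BaseModel):
    name: str
    courses: list = Field(default_factory=list)

    @field_validator("name")
    @classmethod
    def _check_name(cls, v: str) -> str:
        if not _TERM_RE.match(v):
            raise ValueError(f"term name must look like HS24/FS25, got {v!r}")
        return v


CrawlTerm(name="HS24")            # ok
# CrawlTerm(name="Autumn 2024")   # would raise a ValidationError
```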
@ -62,9 +62,6 @@ def _create_hnsw_index(
except Exception:
logger.exception("Failed to run create_or_reindex_hnsw")

# --------------------------------------------------------------------------- #
# single file #
# --------------------------------------------------------------------------- #
def embed_single_file(
*,
course_id: str,
@ -104,7 +101,6 @@ def embed_single_file(
wf.process()
return chunk_path

# --------------------------------------------------------------------------- #
async def run_embedder(
course: ChunkCourse,
concat_path: Union[str, Path],
@ -39,7 +39,6 @@ class EmbeddingWorkflow:

# No need to store db_schema/db_function here if inserter handles it

# ---------------- helpers ----------------
def _load_chunk(self) -> Optional[str]:
try:
text = self.chunk_path.read_text(encoding="utf-8").strip()
@ -85,7 +84,7 @@ class EmbeddingWorkflow:
return False

logger.debug(f"Successfully processed and inserted {self.chunk_path}")
return True # Indicate success
return True


# Keep __all__ if needed
@ -22,7 +22,6 @@ from librarian_core.workers.base import Worker

from librarian_vspace.vecview.vecview import get_tsne_json

# ------------------------------------------------------------------ #
def _safe_get_logger(name: str):
try:
return get_run_logger()
@ -30,9 +29,7 @@ def _safe_get_logger(name: str):
return logging.getLogger(name)


# ------------------------------------------------------------------ #
# Pydantic payloads
# ------------------------------------------------------------------ #
class TsneExportInput(BaseModel):
course_id: int
limit: Optional[int] = None
@ -48,7 +45,6 @@ class TsneExportOutput(BaseModel):
json_path: Path


# ------------------------------------------------------------------ #
class TsneExportWorker(Worker[TsneExportInput, TsneExportOutput]):
"""Runs the t‑SNE export inside a Prefect worker.""" # noqa: D401
@ -20,9 +20,6 @@ logger = logging.getLogger(__name__)
DEFAULT_N_CLUSTERS = 8


# --------------------------------------------------------------------- #
# Internal helpers (kept minimal – no extra bells & whistles)
# --------------------------------------------------------------------- #
def _run_kmeans(df: pd.DataFrame, *, embedding_column: str, k: int = DEFAULT_N_CLUSTERS) -> pd.DataFrame:
"""Adds a 'cluster' column using K‑means (string labels)."""
if df.empty or embedding_column not in df.columns:
@ -61,9 +58,6 @@ def _add_hover(df: pd.DataFrame) -> pd.DataFrame:
return df


# --------------------------------------------------------------------- #
# Public helpers
# --------------------------------------------------------------------- #
def get_tsne_dataframe(
db_schema: str,
db_function: str,
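`_run_kmeans` is shown only down to its guard clause. A self-contained sketch of the labelling step it describes, fitting K-means on the embedding column and storing string labels in a `cluster` column (assumes each row stores its embedding as a list or array; this is a sketch, not the module's exact body):

```python
import numpy as np
import pandas as pd
from sklearn.cluster import KMeans

DEFAULT_N_CLUSTERS = 8


def run_kmeans_sketch(df: pd.DataFrame, *, embedding_column: str,
                      k: int = DEFAULT_N_CLUSTERS) -> pd.DataFrame:
    """Add a string 'cluster' column; return df unchanged if it cannot."""
    if df.empty or embedding_column not in df.columns:
        return df
    vectors = np.vstack(df[embedding_column].to_numpy())
    k = min(k, len(df))  # cannot have more clusters than rows
    labels = KMeans(n_clusters=k, n_init=10, random_state=42).fit_predict(vectors)
    out = df.copy()
    out["cluster"] = [str(label) for label in labels]
    return out
```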
@ -58,9 +58,6 @@ import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

# ---------------------------------------------------------------------------
# Map Vectorbase credential names → Supabase names expected by loader code
# ---------------------------------------------------------------------------
_ALIAS_ENV_MAP = {
"VECTORBASE_URL": "SUPABASE_URL",
"VECTORBASE_API_KEY": "SUPABASE_KEY",
@ -86,11 +83,8 @@ except ImportError as e:
raise ImportError(f"Could not import VectorQueryLoader: {e}") from e


# ---------------------------------------------------------------------------
# Logging setup (used by both script and callable function)
# ---------------------------------------------------------------------------
# This basicConfig runs when the module is imported.
# Callers might want to configure logging before importing.
# If logging is already configured, basicConfig does nothing.
logging.basicConfig(
level=logging.INFO,
@ -99,17 +93,11 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__) # Use __name__ for module-specific logger
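`_ALIAS_ENV_MAP` only declares the name mapping; the application step is not in the hunk. A sketch of how such aliases are usually copied into the environment before the loader reads its configuration (the helper name is hypothetical):

```python
import os

_ALIAS_ENV_MAP = {
    "VECTORBASE_URL": "SUPABASE_URL",
    "VECTORBASE_API_KEY": "SUPABASE_KEY",
}


def apply_env_aliases() -> None:
    """Copy VECTORBASE_* values to the SUPABASE_* names without overwriting."""
    for src, dst in _ALIAS_ENV_MAP.items():
        value = os.environ.get(src)
        if value and not os.environ.get(dst):
            os.environ[dst] = value
```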
# ---------------------------------------------------------------------------
# Helper – JSON dump for centroid in YAML front‑matter
# ---------------------------------------------------------------------------
def centroid_to_json(vec: np.ndarray) -> str:
"""Converts a numpy vector to a JSON string suitable for YAML frontmatter."""
return json.dumps([float(x) for x in vec], ensure_ascii=False)


# ---------------------------------------------------------------------------
# Main clustering and export logic as a callable function
# ---------------------------------------------------------------------------
def run_cluster_export_job(
course_id: Optional[int] = None, # Added course_id parameter
output_dir: Union[str, Path] = "./cluster_md", # Output directory parameter
@ -147,9 +135,7 @@ def run_cluster_export_job(
output_path.mkdir(parents=True, exist_ok=True)
logger.info("Writing Markdown files to %s", output_path)

# ---------------------------------------------------------------------------
# Fetch embeddings - Now using VectorQueryLoader with filtering
# ---------------------------------------------------------------------------
try:
# Use parameters for loader config
# --- FIX: Instantiate VectorQueryLoader ---
@ -249,9 +235,7 @@ def run_cluster_export_job(
# -------------------------------------------------------------


# ---------------------------------------------------------------------------
# Prepare training sample and determine effective k
# ---------------------------------------------------------------------------
# Use the parameter train_sample_size
train_vecs = embeddings[:train_sample_size]

@ -302,9 +286,7 @@ def run_cluster_export_job(
raise RuntimeError(f"K-means clustering failed: {e}") from e


# ---------------------------------------------------------------------------
# Assign every vector to its nearest centroid (full table)
# ---------------------------------------------------------------------------
logger.info("Assigning vectors to centroids...")
try:
# Use the determined embedding column for assignment as well
@ -315,9 +297,7 @@ def run_cluster_export_job(
logger.exception("Failed to assign vectors to centroids.")
raise RuntimeError(f"Failed to assign vectors to centroids: {e}") from e

# ---------------------------------------------------------------------------
# Write one Markdown file per cluster
# ---------------------------------------------------------------------------
files_written_count = 0

logger.info("Writing cluster Markdown files to %s", output_path)
@ -385,9 +365,7 @@ def run_cluster_export_job(
return output_path


# ---------------------------------------------------------------------------
# Script entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
# Configuration via environment for script
script_output_dir = Path(os.environ.get("OUTPUT_DIR", "./cluster_md")).expanduser()
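For the "assign every vector to its nearest centroid" step, the call suggested by the `pairwise_distances_argmin_min` import (but truncated in the hunk) looks roughly like this; a sketch with synthetic data, not the job's real inputs:

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min


def assign_to_centroids(embeddings: np.ndarray, kmeans: KMeans):
    """Return (cluster index, distance to centroid) for every row in embeddings."""
    labels, distances = pairwise_distances_argmin_min(embeddings, kmeans.cluster_centers_)
    return labels, distances


# Example: 1000 vectors of dim 384, 8 clusters trained on a small sample
rng = np.random.default_rng(0)
vectors = rng.normal(size=(1000, 384))
km = KMeans(n_clusters=8, n_init=10, random_state=0).fit(vectors[:200])
labels, dists = assign_to_centroids(vectors, km)
```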
@ -37,15 +37,9 @@ from librarian_vspace.models.query_model import (
logger = logging.getLogger(__name__)


# --------------------------------------------------------------------- #
# Main helper
# --------------------------------------------------------------------- #
class VectorQuery(BaseVectorOperator):
"""High‑level helper for vector searches via Supabase RPC."""

# -----------------------------------------------------------------
# Public – modern API
# -----------------------------------------------------------------
def search(self, request: VectorSearchRequest) -> VectorSearchResponse:
"""Perform a similarity search and return structured results."""

@ -100,9 +94,6 @@ class VectorQuery(BaseVectorOperator):
logger.exception("RPC 'vector_search' failed: %s", exc)
return VectorSearchResponse(total=0, results=[])

# -----------------------------------------------------------------
# Public – legacy compatibility
# -----------------------------------------------------------------
def get_chucklets_by_vector(
self,
*,
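The body of `search` is elided, but the docstring and the log message point to a Supabase RPC named `vector_search`. A hedged sketch of such a call with the supabase-py client (the parameter names are assumptions; only the RPC name appears in the diff):

```python
from supabase import create_client


def vector_search_sketch(url: str, key: str, query_embedding: list[float],
                         match_count: int = 10) -> list[dict]:
    """Call a 'vector_search' RPC and return the raw row dicts."""
    client = create_client(url, key)
    # Parameter names are assumptions - the real RPC signature is not in the diff.
    response = client.rpc(
        "vector_search",
        {"query_embedding": query_embedding, "match_count": match_count},
    ).execute()
    return response.data or []
```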
@ -38,9 +38,6 @@ import sys
from pathlib import Path
from typing import Any, Dict, Literal, Optional

# --------------------------------------------------------------------------- #
# Optional dependencies #
# --------------------------------------------------------------------------- #
try:
import psutil # type: ignore
except ModuleNotFoundError: # pragma: no cover
@ -51,10 +48,6 @@ try:
except ModuleNotFoundError: # pragma: no cover
torch = None # type: ignore

# --------------------------------------------------------------------------- #
# Hardware discovery helpers #
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=None)
def logical_cores() -> int:
@ -124,11 +117,6 @@ def cgroup_cpu_limit() -> Optional[int]:
return None


# --------------------------------------------------------------------------- #
# GPU helpers (CUDA only for now) #
# --------------------------------------------------------------------------- #


@functools.lru_cache(maxsize=None)
def gpu_info() -> Optional[Dict[str, Any]]:
"""Return basic info for the first CUDA device via *torch* (or ``None``)."""
@ -147,11 +135,6 @@ def gpu_info() -> Optional[Dict[str, Any]]:
}


# --------------------------------------------------------------------------- #
# Recommendation logic #
# --------------------------------------------------------------------------- #


def recommended_workers(
*,
kind: Literal["cpu", "io", "gpu"] = "cpu",
@ -194,11 +177,6 @@ def recommended_workers(
return max(1, base)


# --------------------------------------------------------------------------- #
# Convenience snapshot of system info #
# --------------------------------------------------------------------------- #


def system_snapshot() -> Dict[str, Any]:
"""Return a JSON‑serialisable snapshot of parallelism‑related facts."""
return {
@ -211,11 +189,6 @@ def system_snapshot() -> Dict[str, Any]:
}


# --------------------------------------------------------------------------- #
# CLI #
# --------------------------------------------------------------------------- #


def _cli() -> None: # pragma: no cover
parser = argparse.ArgumentParser(
prog="parallelism_advisor", description="Rule‑of‑thumb worker estimator"
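As a usage note for the advisor module, the public helpers can be called directly; a small sketch assuming the keyword-only signature shown above and a hypothetical module name:

```python
import json

# Hypothetical import path - the module's package name is not shown in the diff.
from parallelism_advisor import recommended_workers, system_snapshot

workers_io = recommended_workers(kind="io")    # I/O-bound pools may oversubscribe cores
workers_cpu = recommended_workers(kind="cpu")  # CPU-bound pools stay near the core count

print("suggested io workers:", workers_io)
print("suggested cpu workers:", workers_cpu)
print(json.dumps(system_snapshot(), indent=2, default=str))
```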
@ -24,13 +24,11 @@ class BaseVectorOperator:
self.table: Optional[str] = None
self._resolve_ids()

# ---------------- public helpers ----------------
def table_fqn(self) -> str:
if not self.table:
raise RuntimeError("VectorOperator not initialised – no table")
return f"{self.schema}.{self.table}"

# ---------------- internals ----------------
def _resolve_ids(self) -> None:
self.model_id = self._rpc_get_model_id(self.model)
if self.model_id is None: