Initialize Monorepo

commit f80792d739

.gitignore (vendored, new file, 364 lines)
@ -0,0 +1,364 @@
|
||||
# Created by https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,web,pycharm+all
|
||||
# Edit at https://www.toptal.com/developers/gitignore?templates=macos,windows,linux,python,web,pycharm+all
|
||||
|
||||
### Linux ###
|
||||
*~
|
||||
|
||||
# temporary files which can be created if a process still has a handle open of a deleted file
|
||||
.fuse_hidden*
|
||||
|
||||
# KDE directory preferences
|
||||
.directory
|
||||
|
||||
# Linux trash folder which might appear on any partition or disk
|
||||
.Trash-*
|
||||
|
||||
# .nfs files are created when an open file is removed but is still being accessed
|
||||
.nfs*
|
||||
|
||||
### macOS ###
|
||||
# General
|
||||
.DS_Store
|
||||
.AppleDouble
|
||||
.LSOverride
|
||||
|
||||
# Icon must end with two \r
|
||||
Icon
|
||||
|
||||
|
||||
# Thumbnails
|
||||
._*
|
||||
|
||||
# Files that might appear in the root of a volume
|
||||
.DocumentRevisions-V100
|
||||
.fseventsd
|
||||
.Spotlight-V100
|
||||
.TemporaryItems
|
||||
.Trashes
|
||||
.VolumeIcon.icns
|
||||
.com.apple.timemachine.donotpresent
|
||||
|
||||
# Directories potentially created on remote AFP share
|
||||
.AppleDB
|
||||
.AppleDesktop
|
||||
Network Trash Folder
|
||||
Temporary Items
|
||||
.apdisk
|
||||
|
||||
### macOS Patch ###
|
||||
# iCloud generated files
|
||||
*.icloud
|
||||
|
||||
### PyCharm+all ###
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**/workspace.xml
|
||||
.idea/**/tasks.xml
|
||||
.idea/**/usage.statistics.xml
|
||||
.idea/**/dictionaries
|
||||
.idea/**/shelf
|
||||
|
||||
# AWS User-specific
|
||||
.idea/**/aws.xml
|
||||
|
||||
# Generated files
|
||||
.idea/**/contentModel.xml
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# Gradle and Maven with auto-import
|
||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||
# since they will be recreated, and may cause churn. Uncomment if using
|
||||
# auto-import.
|
||||
# .idea/artifacts
|
||||
# .idea/compiler.xml
|
||||
# .idea/jarRepositories.xml
|
||||
# .idea/modules.xml
|
||||
# .idea/*.iml
|
||||
# .idea/modules
|
||||
# *.iml
|
||||
# *.ipr
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
# IntelliJ
|
||||
out/
|
||||
|
||||
# mpeltonen/sbt-idea plugin
|
||||
.idea_modules/
|
||||
|
||||
# JIRA plugin
|
||||
atlassian-ide-plugin.xml
|
||||
|
||||
# Cursive Clojure plugin
|
||||
.idea/replstate.xml
|
||||
|
||||
# SonarLint plugin
|
||||
.idea/sonarlint/
|
||||
|
||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||
com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
|
||||
### PyCharm+all Patch ###
|
||||
# Ignore everything but code style settings and run configurations
|
||||
# that are supposed to be shared within teams.
|
||||
|
||||
.idea/*
|
||||
|
||||
!.idea/codeStyles
|
||||
!.idea/runConfigurations
|
||||
|
||||
### Python ###
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
### Python Patch ###
|
||||
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
|
||||
poetry.toml
|
||||
|
||||
# ruff
|
||||
.ruff_cache/
|
||||
|
||||
# LSP config files
|
||||
pyrightconfig.json
|
||||
|
||||
### Web ###
|
||||
*.asp
|
||||
*.cer
|
||||
*.csr
|
||||
*.css
|
||||
*.htm
|
||||
*.html
|
||||
*.js
|
||||
*.jsp
|
||||
*.php
|
||||
*.rss
|
||||
*.wasm
|
||||
*.wat
|
||||
*.xhtml
|
||||
|
||||
### Windows ###
|
||||
# Windows thumbnail cache files
|
||||
Thumbs.db
|
||||
Thumbs.db:encryptable
|
||||
ehthumbs.db
|
||||
ehthumbs_vista.db
|
||||
|
||||
# Dump file
|
||||
*.stackdump
|
||||
|
||||
# Folder config file
|
||||
[Dd]esktop.ini
|
||||
|
||||
# Recycle Bin used on file shares
|
||||
$RECYCLE.BIN/
|
||||
|
||||
# Windows Installer files
|
||||
*.cab
|
||||
*.msi
|
||||
*.msix
|
||||
*.msm
|
||||
*.msp
|
||||
|
||||
# Windows shortcuts
|
||||
*.lnk
|
||||
|
||||
# End of https://www.toptal.com/developers/gitignore/api/macos,windows,linux,python,web,pycharm+all
|
||||
# local env files
|
||||
**/.env*.local
|
||||
**/.env
|
||||
!**/.env.example
|
||||
|
||||
# vercel
|
||||
.vercel
|
||||
|
||||
# typescript
|
||||
*.tsbuildinfo
|
||||
next-env.d.ts
|
librarian/librarian-core/README.md (new file, 13 lines)
@@ -0,0 +1,13 @@
# Usage

In your `pyproject.toml`, add the following:

```toml
dependencies = [
    "librarian-core",
    "...other dependencies"
]

[tool.uv.sources]
librarian-core = { git = "https://github.com/DotNaos/librarian-core", rev = "dev" }
```
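Once the source override resolves, importing the package should work; a brief, hypothetical smoke test (the worker name and run id below are placeholders, and the import paths come from the modules added later in this commit):

```python
# Hypothetical smoke test that the dependency is importable (placeholder names).
from librarian_core.storage.worker_store import WorkerStore
from librarian_core.utils import path_utils

print(path_utils.get_data_root())  # XDG-aware data root
store = WorkerStore.new(worker_name="demo", flow_id="run-0001")
print(store.entry_dir)             # per-run scratch directory
```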
librarian/librarian-core/pyproject.toml (new file, 38 lines)
@@ -0,0 +1,38 @@
[project]
name = "librarian-core"
version = "0.1.6"
readme = "README.md"
description = "Shared datamodel & utils for the Librarian project"
requires-python = ">=3.10"
authors = [
    { name = "DotNaos", email = "schuetzoliver00@gmail.com" }
]
dependencies = [
    "pandas>=2.2.3",
    "platformdirs>=4.3.7",
    "pydantic-settings>=2.9.1",
    "supabase",
    "tabulate>=0.9.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0",   # Testing framework
    "pytest-cov",    # Coverage reporting
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# src/ layout
[tool.hatch.build.targets.wheel]
packages = ["src/librarian_core"]

[tool.pytest.ini_options]
pythonpath = ["src"]
testpaths = ["tests"]
addopts = "--cov=librarian_core --cov-report=term-missing"

[tool.coverage.run]
source = ["librarian_core"]
librarian/librarian-core/src/librarian_core/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
import pkgutil
import importlib

__all__ = []

# Iterate over all modules in this package
for finder, module_name, is_pkg in pkgutil.iter_modules(__path__):
    # import the sub-module
    module = importlib.import_module(f"{__name__}.{module_name}")

    # decide which names to re-export:
    # use module.__all__ if it exists, otherwise every non-private attribute
    public_names = getattr(
        module, "__all__", [n for n in dir(module) if not n.startswith("_")]
    )

    # bring each name into the package namespace
    for name in public_names:
        globals()[name] = getattr(module, name)
        __all__.append(name)  # type: ignore
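The loop above makes every public name of the sub-packages reachable from the package root; a small sketch of the consumer-side effect (which names actually appear depends on the sub-modules installed):

```python
# Hypothetical view of the dynamic re-export.
import librarian_core

print(librarian_core.__all__)                             # aggregated public names
store_cls = getattr(librarian_core, "WorkerStore", None)  # re-exported from librarian_core.storage
```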
librarian/librarian-core/src/librarian_core/storage/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from librarian_core.storage.worker_store import WorkerStore

__all__ = [
    "WorkerStore",
]
librarian/librarian-core/src/librarian_core/storage/worker_store.py (new file, 243 lines)
@@ -0,0 +1,243 @@
|
||||
"""
|
||||
librarian_core.storage.worker_store
|
||||
===================================
|
||||
|
||||
Persistent directory layout
|
||||
---------------------------
|
||||
<data_root>/flows/<worker>/<run_id>/
|
||||
meta.json # worker_name, state, timestamps …
|
||||
result.json # pydantic-serialised return model
|
||||
data/ # files staged by the worker
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from librarian_core.utils import path_utils
|
||||
|
||||
|
||||
class WorkerStore:
|
||||
"""Never exposed to worker code – all access is via helper methods."""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# constructors #
|
||||
# ------------------------------------------------------------------ #
|
||||
@classmethod
|
||||
def new(cls, *, worker_name: str, flow_id: str) -> "WorkerStore":
|
||||
run_dir = path_utils.get_run_dir(worker_name, flow_id, create=True)
|
||||
store = cls(run_dir, worker_name, flow_id)
|
||||
store._write_meta(state="RUNNING")
|
||||
return store
|
||||
|
||||
@classmethod
|
||||
def open(cls, run_id: str) -> "WorkerStore":
|
||||
"""
|
||||
Locate `<flows>/<worker>/<run_id>` regardless of worker name.
|
||||
"""
|
||||
flows_dir = path_utils.get_flows_dir()
|
||||
for worker_dir in flows_dir.iterdir():
|
||||
candidate = worker_dir / run_id
|
||||
if candidate.exists():
|
||||
meta_path = candidate / "meta.json"
|
||||
if not meta_path.is_file():
|
||||
continue
|
||||
meta = json.loads(meta_path.read_text())
|
||||
return cls(candidate, meta["worker_name"], run_id)
|
||||
raise FileNotFoundError(run_id)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# life-cycle #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __init__(self, run_dir: Path, worker_name: str, flow_id: str):
|
||||
self._run_dir = run_dir
|
||||
self._worker_name = worker_name
|
||||
self._flow_id = flow_id
|
||||
|
||||
cache_root = path_utils.get_cache_root()
|
||||
self._work_dir = Path(
|
||||
tempfile.mkdtemp(prefix=f"{self._flow_id}-", dir=cache_root)
|
||||
)
|
||||
|
||||
self._entry_dir = self._work_dir / "entry"
|
||||
self._exit_dir = self._work_dir / "exit"
|
||||
self._entry_dir.mkdir(parents=True, exist_ok=True)
|
||||
self._exit_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# entry / exit handling #
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def entry_dir(self) -> Path:
|
||||
return self._entry_dir
|
||||
|
||||
def prime_with_input(self, src: Optional[Path]) -> None:
|
||||
if src and src.exists():
|
||||
shutil.copytree(src, self._entry_dir, dirs_exist_ok=True)
|
||||
|
||||
def stage(
|
||||
self,
|
||||
src: Path | str,
|
||||
*,
|
||||
new_name: str | None = None,
|
||||
sanitize: bool = True,
|
||||
move: bool = False,
|
||||
) -> Path:
|
||||
src_path = Path(src).expanduser().resolve()
|
||||
if not src_path.exists():
|
||||
raise FileNotFoundError(src_path)
|
||||
|
||||
name = new_name or src_path.name
|
||||
if sanitize:
|
||||
name = path_utils._sanitize(name)
|
||||
|
||||
dst = self._exit_dir / name
|
||||
if dst.exists():
|
||||
if dst.is_file():
|
||||
dst.unlink()
|
||||
else:
|
||||
shutil.rmtree(dst)
|
||||
|
||||
if move:
|
||||
src_path.rename(dst)
|
||||
else:
|
||||
if src_path.is_dir():
|
||||
shutil.copytree(src_path, dst)
|
||||
else:
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(src_path, dst)
|
||||
return dst
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# result persistence #
|
||||
# ------------------------------------------------------------------ #
|
||||
def save_model(
|
||||
self,
|
||||
model: BaseModel,
|
||||
*,
|
||||
filename: str = "result.json",
|
||||
**json_kwargs: Any,
|
||||
) -> Path:
|
||||
json_kwargs.setdefault("indent", 2)
|
||||
target = self._run_dir / filename
|
||||
target.write_text(model.model_dump_json(**json_kwargs))
|
||||
return target
|
||||
|
||||
def persist_exit(self) -> Path:
|
||||
"""
|
||||
Move the *exit* directory to the persistent *data* slot and mark
|
||||
the run completed.
|
||||
"""
|
||||
data_dir = self.data_dir
|
||||
if data_dir.exists():
|
||||
shutil.rmtree(data_dir)
|
||||
self._exit_dir.rename(data_dir)
|
||||
self._write_meta(state="COMPLETED")
|
||||
return data_dir
|
||||
|
||||
def cleanup(self) -> None:
|
||||
shutil.rmtree(self._work_dir, ignore_errors=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# public helpers (API needs these) #
|
||||
# ------------------------------------------------------------------ #
|
||||
@property
|
||||
def data_dir(self) -> Path:
|
||||
return self._run_dir / "data"
|
||||
|
||||
@property
|
||||
def meta_path(self) -> Path:
|
||||
return self._run_dir / "meta.json"
|
||||
|
||||
@property
|
||||
def metadata(self) -> dict[str, Any]:
|
||||
return json.loads(self.meta_path.read_text())
|
||||
|
||||
def load_model(self, *, as_dict: bool = False) -> dict | BaseModel | None:
|
||||
res_file = self._run_dir / "result.json"
|
||||
if not res_file.is_file():
|
||||
return None
|
||||
data = json.loads(res_file.read_text())
|
||||
if as_dict:
|
||||
return data
|
||||
# try to reconstruct a Pydantic model if possible
|
||||
try:
|
||||
OutputModel: Type[BaseModel] | None = self._guess_output_model()
|
||||
if OutputModel:
|
||||
return TypeAdapter(OutputModel).validate_python(data)
|
||||
except Exception:
|
||||
pass
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
# TODO: Should return a FlowArtifact, but a circular import currently prevents it
|
||||
def load_latest(worker_name: str) -> dict[str, Any] | None:
|
||||
flows_dir = path_utils.get_flows_dir()
|
||||
worker_dir = flows_dir / worker_name
|
||||
if not worker_dir.exists():
|
||||
return None
|
||||
|
||||
runs: list[tuple[datetime, Path]] = []
|
||||
for run_id in worker_dir.iterdir():
|
||||
if not run_id.is_dir():
|
||||
continue
|
||||
meta_path = run_id / "meta.json"
|
||||
if not meta_path.is_file():
|
||||
continue
|
||||
meta = json.loads(meta_path.read_text())
|
||||
if meta["state"] == "COMPLETED":
|
||||
runs.append((datetime.fromisoformat(meta["timestamp"]), run_id))
|
||||
|
||||
if not runs:
|
||||
return None
|
||||
sorted_runs = sorted(runs, key=lambda x: x[0])
|
||||
|
||||
latest_run_dir = sorted_runs[-1][1]
|
||||
|
||||
# Load the model
|
||||
return { # That is a FlowArtifact
|
||||
"run_id": latest_run_dir.name,
|
||||
"dir": latest_run_dir / "data",
|
||||
"data": WorkerStore.open(latest_run_dir.name).load_model(as_dict=True), # type: ignore
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internals #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _write_meta(self, *, state: str) -> None:
|
||||
meta = {
|
||||
"worker_name": self._worker_name,
|
||||
"run_id": self._flow_id,
|
||||
"state": state,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
self.meta_path.write_text(json.dumps(meta, indent=2))
|
||||
|
||||
def _guess_output_model(self) -> Optional[Type[BaseModel]]:
|
||||
"""
|
||||
Best-effort import of `<worker_name>.output_model`.
|
||||
"""
|
||||
try:
|
||||
from importlib import import_module
|
||||
|
||||
# workers are registered with dotted names in plugin_loader
|
||||
mod = import_module(self._worker_name)
|
||||
return getattr(mod, "output_model", None)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# clean-up #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __del__(self) -> None:
|
||||
try:
|
||||
shutil.rmtree(self._work_dir, ignore_errors=True)
|
||||
except Exception:
|
||||
pass
|
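To tie the pieces of `WorkerStore` together, a compact life-cycle sketch; the worker name, flow id, result model, and staged path are placeholders, and only methods defined above are used:

```python
# Hypothetical WorkerStore life-cycle (all literals are placeholders).
from pathlib import Path
from pydantic import BaseModel
from librarian_core.storage.worker_store import WorkerStore


class DemoResult(BaseModel):
    message: str


store = WorkerStore.new(worker_name="demo_worker", flow_id="run-0001")
store.stage(Path("/tmp/some-output.txt"))     # copy a produced file into the exit dir (path must exist)
store.save_model(DemoResult(message="done"))  # writes result.json into the run dir
data_dir = store.persist_exit()               # moves exit/ -> <run_dir>/data and marks COMPLETED
store.cleanup()                               # removes the temporary work dir
print(store.metadata["state"], data_dir)
```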
librarian/librarian-core/src/librarian_core/supabase/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .client import get_client, SupabaseGateway

__all__ = ["get_client", "SupabaseGateway"]
librarian/librarian-core/src/librarian_core/supabase/client.py (new file, 60 lines)
@@ -0,0 +1,60 @@
from __future__ import annotations
from typing import Any, Dict
from pydantic import BaseModel
from supabase import create_client, Client
import os, logging

log = logging.getLogger(__name__)


class _Cfg(BaseModel):
    url: str
    key: str
    db_schema: str = "library"


def _load_cfg() -> _Cfg:
    return _Cfg(
        url=os.getenv("SUPABASE_URL", ""),
        key=os.getenv("SUPABASE_API_KEY", ""),
    )


_client: Client | None = None
_cfg: _Cfg | None = None


def get_client() -> Client:
    global _client, _cfg
    if _client:
        return _client
    _cfg = _load_cfg()
    if not _cfg.url or not _cfg.key:
        raise RuntimeError("SUPABASE_URL or SUPABASE_API_KEY missing")
    _client = create_client(_cfg.url, _cfg.key)
    return _client


class SupabaseGateway:
    """
    Thin wrapper around Client with `schema()` pre-selected
    and a helper `_rpc()` that raises RuntimeError on error.
    """

    def __init__(self) -> None:
        self.client = get_client()
        self.schema = _cfg.db_schema if _cfg else "library"

    # ---------- internal ----------
    def _rpc(self, fn: str, payload: Dict[str, Any] | None = None):
        resp = (
            self.client.schema(self.schema)
            .rpc(fn, payload or {})
            .execute()
            .model_dump()
        )
        if resp.get("error"):
            log.error("%s error: %s", fn, resp["error"])
            raise RuntimeError(resp["error"])
        log.debug("%s OK", fn)
        return resp.get("data")
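A minimal, hedged sketch of configuring and calling the gateway; the URL, key, and RPC payload are placeholders (the RPC name mirrors `rpc.py` below), and in practice the credentials would come from `load_env()`:

```python
# Hypothetical configuration and RPC call (placeholder credentials).
import os

os.environ.setdefault("SUPABASE_URL", "https://example.supabase.co")
os.environ.setdefault("SUPABASE_API_KEY", "placeholder-key")

from librarian_core.supabase.client import SupabaseGateway

gw = SupabaseGateway()        # picks up the env-based config above
gw._rpc(
    "upsert_degree_program",  # RPC name taken from rpc.py
    {"p_program_id": "prog-1", "p_program_name": "Demo Program"},
)
```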
librarian/librarian-core/src/librarian_core/supabase/rpc.py (new file, 61 lines)
@@ -0,0 +1,61 @@
from __future__ import annotations

from typing import List

from librarian_scraper.models import CrawlCourse, CrawlTerm, MoodleIndex

from librarian_core.supabase.client import SupabaseGateway

gw = SupabaseGateway()  # singleton gateway


# -------- public API --------
def upload_index(index: MoodleIndex) -> None:
    dp = index.degree_program
    _upsert_degree_program(dp.id, dp.name)
    for term in dp.terms:
        _upsert_term(term)
        _upsert_courses(term.courses, term_id=term.id, prog_id=dp.id)


def upload_modules(modules_index) -> None:
    # TODO – same pattern
    ...


# -------- helpers --------
def _upsert_degree_program(dp_id: str, name: str):
    gw._rpc(
        "upsert_degree_program",
        {
            "p_program_id": dp_id,
            "p_program_name": name,
        },
    )


def _upsert_term(term: CrawlTerm):
    # TODO: Change to term, when supabase is updated
    gw._rpc(
        "upsert_semester",
        {
            "p_semester_id": term.id,
            "p_semester_name": term.name,
        },
    )


def _upsert_courses(courses: List[CrawlCourse], *, term_id: str, prog_id: str):
    # TODO: Change to term, when supabase is updated
    for c in courses:
        gw._rpc(
            "upsert_course",
            {
                "p_course_id": c.id,
                "p_course_name": c.name,
                "p_semester_id": term_id,
                "p_program_id": prog_id,
                "p_hero_image": c.hero_image,
                "p_content_ressource_id": c.content_ressource_id,
            },
        )
@@ -0,0 +1,11 @@
from .chunk_data import (
    ChunkCourse,
    ChunkFile,
    ChunkData,
)

__all__ = [
    "ChunkData",
    "ChunkCourse",
    "ChunkFile",
]
@@ -0,0 +1,19 @@
from typing import List

from pydantic import BaseModel, Field

# TODO: Move to librarian-chunker


class ChunkFile(BaseModel):
    name: str = Field(..., description="Name of the file")
    id: str = Field(..., description="ID of the file")


class ChunkCourse(BaseModel):
    id: str = Field(..., description="ID of the course")
    name: str = Field(..., description="Name of the course")
    files: List[ChunkFile] = Field(..., description="List of files in the course")


class ChunkData(BaseModel):
    courses: List[ChunkCourse] = Field(..., description="List of courses")
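For orientation, a tiny construction example for these models; the ids and names are invented, and the import path is a guess since the diff header above does not show the file's location:

```python
# Hypothetical instance of the chunk models (placeholder values; adjust the import to the real module path).
from librarian_core.models.chunk_data import ChunkCourse, ChunkData, ChunkFile  # assumed path

data = ChunkData(
    courses=[
        ChunkCourse(
            id="course-1",
            name="Algorithms",
            files=[ChunkFile(id="file-1", name="lecture01.pdf")],
        )
    ]
)
print(data.model_dump_json(indent=2))
```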
librarian/librarian-core/src/librarian_core/utils/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from librarian_core.utils.path_utils import (
    copy_to_temp_dir,
    get_cache_root,
    get_config_root,
    get_data_root,
    get_flow_name_from_id,
    get_flows_dir,
    get_run_dir,
    get_temp_path,
    get_workers_dir,
)
from librarian_core.utils.secrets_loader import load_env

__all__ = [
    "load_env",
    "get_temp_path",
    "get_run_dir",
    "get_flow_name_from_id",
    "copy_to_temp_dir",
    "get_cache_root",
    "get_data_root",
    "get_config_root",
    "get_flows_dir",
    "get_workers_dir",
]
librarian/librarian-core/src/librarian_core/utils/path_utils.py (new file, 196 lines)
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
librarian_core/utils/path_utils.py
|
||||
==================================
|
||||
|
||||
Unified helpers for every path the Atlas-Librarian project uses.
|
||||
|
||||
Key features
|
||||
------------
|
||||
* XDG- and ENV-aware roots for **data**, **config**, and **cache**.
|
||||
* Dedicated sub-trees for *flows* (per-worker run directories) and
|
||||
*workers* (registrations, static assets, …).
|
||||
* Convenience helpers:
|
||||
- `get_run_dir(worker, run_id)`
|
||||
- `get_flow_name_from_id(run_id)` ← Prefect lookup (lazy import)
|
||||
- `get_temp_path()` / `copy_to_temp_dir()`
|
||||
* **Single source of truth** – change the root once, everything follows.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from platformdirs import (
|
||||
user_cache_dir,
|
||||
user_config_dir,
|
||||
user_data_dir,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Root directories (honours $LIBRARIAN_*_DIR, falls back to XDG) #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
_APP_NAME = "atlas-librarian"
|
||||
|
||||
_DATA_ROOT = Path(
|
||||
os.getenv("LIBRARIAN_DATA_DIR", user_data_dir(_APP_NAME))
|
||||
).expanduser()
|
||||
_CONFIG_ROOT = Path(
|
||||
os.getenv("LIBRARIAN_CONFIG_DIR", user_config_dir(_APP_NAME))
|
||||
).expanduser()
|
||||
_CACHE_ROOT = Path(
|
||||
os.getenv("LIBRARIAN_CACHE_DIR", user_cache_dir(_APP_NAME))
|
||||
).expanduser()
|
||||
|
||||
# Project-specific sub-trees
|
||||
_FLOWS_DIR = _DATA_ROOT / "flows" # <data>/flows/<worker>/<run_id>/
|
||||
_WORKERS_DIR = _DATA_ROOT / "workers" # static registration cache, etc.
|
||||
|
||||
# Ensure that the basic tree always exists
|
||||
for p in (_DATA_ROOT, _CONFIG_ROOT, _CACHE_ROOT, _FLOWS_DIR, _WORKERS_DIR):
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Public helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
# -- roots --
|
||||
|
||||
|
||||
def get_data_root() -> Path:
|
||||
return _DATA_ROOT
|
||||
|
||||
|
||||
def get_config_root() -> Path:
|
||||
return _CONFIG_ROOT
|
||||
|
||||
|
||||
def get_cache_root() -> Path:
|
||||
return _CACHE_ROOT
|
||||
|
||||
|
||||
def get_flows_dir() -> Path:
|
||||
return _FLOWS_DIR
|
||||
|
||||
|
||||
def get_workers_dir() -> Path:
|
||||
return _WORKERS_DIR
|
||||
|
||||
|
||||
# -- flow-run directories ---------------------------------------------------- #
|
||||
|
||||
|
||||
def get_run_dir(worker_name: str, run_id: str, *, create: bool = True) -> Path:
|
||||
"""
|
||||
Absolute path for one specific Prefect flow-run.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> get_run_dir("downloader", "1234abcd")
|
||||
~/.local/share/atlas-librarian/flows/downloader/1234abcd
|
||||
"""
|
||||
safe_worker = _sanitize(worker_name)
|
||||
path = _FLOWS_DIR / safe_worker / run_id
|
||||
if create:
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
||||
def get_flow_name_from_id(run_id: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a Prefect *run-id* → *flow name*.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The flow (worker) name or *None* if the ID cannot be found.
|
||||
"""
|
||||
try:
|
||||
from prefect.client.orchestration import get_client
|
||||
except ImportError: # Prefect not installed in caller env
|
||||
return None
|
||||
|
||||
try:
|
||||
import anyio
|
||||
|
||||
async def _lookup() -> Optional[str]:
|
||||
async with get_client() as client:
|
||||
fr = await client.read_flow_run(uuid.UUID(run_id))
|
||||
return fr.flow_name # type: ignore[attr-defined]
|
||||
|
||||
return anyio.run(_lookup)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# -- temporary workspace helpers -------------------------------------------- #
|
||||
|
||||
|
||||
def get_temp_path(prefix: str = "atlas") -> Path:
|
||||
"""
|
||||
Create a *unique* temporary directory inside the user cache.
|
||||
|
||||
The directory is **not** deleted automatically – callers decide.
|
||||
"""
|
||||
ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
rand = uuid.uuid4().hex[:8]
|
||||
tmp_root = _CACHE_ROOT / "tmp"
|
||||
tmp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
path = Path(
|
||||
tempfile.mkdtemp(
|
||||
dir=tmp_root,
|
||||
prefix=f"{prefix}-{ts}-{rand}-",
|
||||
)
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def copy_to_temp_dir(src: Path | str, *, prefix: str = "atlas") -> Path:
|
||||
"""
|
||||
Recursively copy *src* into a fresh temporary directory.
|
||||
|
||||
Returns the destination path.
|
||||
"""
|
||||
src_path = Path(src).expanduser().resolve()
|
||||
if not src_path.exists():
|
||||
raise FileNotFoundError(src_path)
|
||||
|
||||
dst = get_temp_path(prefix=prefix)
|
||||
shutil.copytree(src_path, dst, dirs_exist_ok=True)
|
||||
return dst
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# internal helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
|
||||
def _sanitize(name: str) -> str:
|
||||
"""Replace path-hostile characters – keeps things safe across OSes."""
|
||||
return "".join(c if c.isalnum() or c in "-._" else "_" for c in name)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# exports #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
__all__ = [
|
||||
# roots
|
||||
"get_data_root",
|
||||
"get_config_root",
|
||||
"get_cache_root",
|
||||
"get_flows_dir",
|
||||
"get_workers_dir",
|
||||
# flow-run helpers
|
||||
"get_run_dir",
|
||||
"get_flow_name_from_id",
|
||||
# temporary space
|
||||
"get_temp_path",
|
||||
"copy_to_temp_dir",
|
||||
]
|
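A short sketch of the helpers exported here; the worker name and run id mirror the docstring example, and `LIBRARIAN_DATA_DIR` must be set before import to take effect:

```python
# Hypothetical use of the path helpers (placeholder names).
from librarian_core.utils import path_utils

run_dir = path_utils.get_run_dir("downloader", "1234abcd")  # <data>/flows/downloader/1234abcd
tmp = path_utils.get_temp_path(prefix="demo")               # unique dir under <cache>/tmp; caller cleans up
print(run_dir, tmp)
```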
librarian/librarian-core/src/librarian_core/utils/secrets_loader.py (new file, 25 lines)
@@ -0,0 +1,25 @@
"""
Secrets live in a classic .env **outside** the JSON settings file.
Load order:

1. ENV LIBRARIAN_CREDENTIALS_PATH (override)
2. ~/.config/atlas-librarian/credentials.env (XDG path)
"""

from pathlib import Path
import os
import logging
import dotenv
from librarian_core.utils.path_utils import get_config_root

log = logging.getLogger(__name__)


def load_env() -> None:
    path = Path(os.getenv("LIBRARIAN_CREDENTIALS_PATH", get_config_root() / "credentials.env"))

    if path.exists():
        dotenv.load_dotenv(path)
        log.debug("Secrets loaded from %s", path)
    else:
        log.debug("No credentials.env found (looked in %s)", path)
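How a caller might point the loader at a non-default credentials file; the path below is a placeholder:

```python
# Hypothetical override of the credentials location (placeholder path).
import os

os.environ["LIBRARIAN_CREDENTIALS_PATH"] = "/secure/atlas-librarian/credentials.env"

from librarian_core.utils.secrets_loader import load_env

load_env()  # only logs at debug level if the file does not exist
```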
librarian/librarian-core/src/librarian_core/workers/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from librarian_core.workers.base import Worker

__all__ = ["Worker"]
librarian/librarian-core/src/librarian_core/workers/base.py (new file, 192 lines)
@@ -0,0 +1,192 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import ClassVar, Generic, TypeVar
|
||||
|
||||
import pandas as pd
|
||||
import anyio
|
||||
from prefect import flow, get_run_logger
|
||||
from prefect.runtime import flow_run
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from prefect.artifacts import acreate_markdown_artifact
|
||||
|
||||
from librarian_core.storage.worker_store import WorkerStore
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# type parameters #
|
||||
# --------------------------------------------------------------------------- #
|
||||
InT = TypeVar("InT", bound=BaseModel)
|
||||
OutT = TypeVar("OutT", bound=BaseModel)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# envelope returned by every worker flow #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class FlowArtifact(BaseModel, Generic[OutT]):
|
||||
run_id: str | None = None
|
||||
dir: Path | None = None
|
||||
data: OutT | None = None
|
||||
|
||||
@classmethod
|
||||
def new(cls, run_id: str | None = None, dir: Path | None = None, data: OutT | None = None) -> FlowArtifact:
|
||||
if not data:
|
||||
raise ValueError("data is required")
|
||||
# Intermediate Worker
|
||||
if run_id and dir:
|
||||
return FlowArtifact(run_id=run_id, dir=dir, data=data)
|
||||
|
||||
# Initial Worker
|
||||
else:
|
||||
return FlowArtifact(data=data)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# metaclass: adds a Prefect flow + envelope to each Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class _WorkerMeta(type):
|
||||
def __new__(mcls, name, bases, ns, **kw):
|
||||
cls = super().__new__(mcls, name, bases, dict(ns))
|
||||
|
||||
if name == "Worker" and cls.__module__ == __name__:
|
||||
return cls # abstract base
|
||||
|
||||
if not (hasattr(cls, "input_model") and hasattr(cls, "output_model")):
|
||||
raise TypeError(f"{name}: declare 'input_model' / 'output_model'.")
|
||||
if "__run__" not in cls.__dict__:
|
||||
raise TypeError(f"{name}: implement async '__run__(payload)'.")
|
||||
|
||||
cls.worker_name = name # type: ignore
|
||||
cls._create_input_artifact()
|
||||
cls._prefect_flow = mcls._build_prefect_flow(cls) # type: ignore
|
||||
return cls
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
@staticmethod
|
||||
def _build_prefect_flow(cls_ref):
|
||||
"""Create the Prefect flow and return it."""
|
||||
InArt = cls_ref.input_artifact # noqa: F841
|
||||
OutModel: type[BaseModel] = cls_ref.output_model # noqa: F841
|
||||
worker_name: str = cls_ref.worker_name
|
||||
|
||||
async def _core(in_art: FlowArtifact[InT]): # type: ignore[name-defined]
|
||||
logger = get_run_logger()
|
||||
run_id = flow_run.get_id() or uuid.uuid4().hex
|
||||
logger.info("%s started (run_id=%s)", worker_name, run_id)
|
||||
|
||||
store = WorkerStore.new(worker_name=worker_name, flow_id=run_id)
|
||||
|
||||
if in_art.dir and in_art.dir.exists() and in_art.dir != Path("."):
|
||||
store.prime_with_input(in_art.dir)
|
||||
|
||||
inst = cls_ref()
|
||||
inst._inject_store(store)
|
||||
# run worker ------------------------------------------------
|
||||
run_res = inst.__run__(in_art.data)
|
||||
# allow sync or async implementations
|
||||
if inspect.iscoroutine(run_res):
|
||||
result = await run_res
|
||||
else:
|
||||
result = run_res
|
||||
|
||||
store.save_model(result)
|
||||
store.persist_exit()
|
||||
store.cleanup()
|
||||
logger.info("%s finished", worker_name)
|
||||
|
||||
artifact = FlowArtifact(run_id=run_id, dir=store.data_dir, data=result)
|
||||
|
||||
md_table = await inst._to_markdown(result)
|
||||
await acreate_markdown_artifact(
|
||||
key=f"{worker_name.lower()}-artifact",
|
||||
markdown=md_table,
|
||||
description=f"{worker_name} output"
|
||||
)
|
||||
|
||||
# save the markdown artifact in the flow directory
|
||||
md_file = store._run_dir / "artifact.md"
|
||||
md_file.write_text(md_table)
|
||||
|
||||
return artifact
|
||||
|
||||
return flow(name=worker_name, log_prints=True)(_core)
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
def _create_input_artifact(cls):
|
||||
"""Create & attach a pydantic model ‹InputArtifact› = {dir?, data}."""
|
||||
DirField = (Path | None, None)
|
||||
DataField = (cls.input_model, ...) # type: ignore # required
|
||||
art_name = f"{cls.__name__}InputArtifact"
|
||||
|
||||
artifact = create_model(art_name, dir=DirField, data=DataField) # type: ignore[arg-type]
|
||||
artifact.__doc__ = f"Artifact for {cls.__name__} input"
|
||||
cls.input_artifact = artifact # type: ignore[attr-defined]
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# public Worker base #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Worker(Generic[InT, OutT], metaclass=_WorkerMeta):
|
||||
"""
|
||||
Derive from this class, set *input_model* / *output_model*, and implement
|
||||
an **async** ``__run__(payload: input_model)``.
|
||||
"""
|
||||
|
||||
input_model: ClassVar[type[BaseModel]]
|
||||
output_model: ClassVar[type[BaseModel]]
|
||||
input_artifact: ClassVar[type[BaseModel]] # injected by metaclass
|
||||
worker_name: ClassVar[str]
|
||||
_prefect_flow: ClassVar[Callable[[FlowArtifact[InT]], Awaitable[FlowArtifact[OutT]]]]
|
||||
|
||||
# injected at runtime
|
||||
entry: Path
|
||||
_store: WorkerStore
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal wiring #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _inject_store(self, store: WorkerStore) -> None:
|
||||
self._store = store
|
||||
self.entry = store.entry_dir
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# developer helper #
|
||||
# ------------------------------------------------------------------ #
|
||||
def stage(
|
||||
self,
|
||||
src: Path | str,
|
||||
*,
|
||||
new_name: str | None = None,
|
||||
sanitize: bool = True,
|
||||
move: bool = False,
|
||||
) -> Path:
|
||||
return self._store.stage(src, new_name=new_name, sanitize=sanitize, move=move)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# convenience wrappers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@classmethod
|
||||
def flow(cls):
|
||||
"""Return the auto-generated Prefect flow."""
|
||||
return cls._prefect_flow
|
||||
|
||||
# submit variants --------------------------------------------------- #
|
||||
@classmethod
|
||||
def submit(cls, payload: FlowArtifact[InT]) -> FlowArtifact[OutT]:
|
||||
async def _runner():
|
||||
art = await cls._prefect_flow(payload) # type: ignore[arg-type]
|
||||
return art
|
||||
|
||||
return anyio.run(_runner)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# abstract #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def __run__(self, payload: InT) -> OutT: ...
|
||||
|
||||
|
||||
# Should be overridden by the worker
|
||||
async def _to_markdown(self, data: OutT) -> str:
|
||||
md_table = pd.DataFrame([data.dict()]).to_markdown(index=False)
|
||||
return md_table
|
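To make the `Worker` contract concrete, a hedged sketch of a minimal subclass; the models and the echo logic are invented for illustration, and running it requires a working Prefect environment:

```python
# Hypothetical minimal worker built on the base class above (placeholder models and logic).
from pydantic import BaseModel
from librarian_core.workers.base import FlowArtifact, Worker


class EchoIn(BaseModel):
    text: str


class EchoOut(BaseModel):
    text: str
    length: int


class EchoWorker(Worker[EchoIn, EchoOut]):
    input_model = EchoIn
    output_model = EchoOut

    async def __run__(self, payload: EchoIn) -> EchoOut:
        # files produced here could be registered via self.stage(...)
        return EchoOut(text=payload.text, length=len(payload.text))


# The metaclass wraps __run__ in a Prefect flow; submit() runs it to completion.
artifact = EchoWorker.submit(FlowArtifact.new(data=EchoIn(text="hello")))
print(artifact.data)
```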
librarian/librarian-core/uv.lock (generated, new file, 1197 lines)
File diff suppressed because it is too large.
librarian/plugins/librarian-chunker/README.md (new file, 21 lines)
@@ -0,0 +1,21 @@
# Chunker

Extract text, chunk it, and save images from a PDF.

`chunks` is a `List[str]` of ~800-token strings (100-token overlap).
Outputs (text files and images) are written under `extracted_content/<pdf_basename>/`.

## Usage

```python
from chunker import Chunker

chunker = Chunker("path/to/file.pdf")
chunks = chunker.run()
```

Setup:

```
pip install -r requirements.txt
python -m spacy download xx_ent_wiki_sm
```
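Because the class is also registered as a `librarian.workers` entry point and derives from `librarian_core.workers.base.Worker`, it can alternatively be driven through the flow interface; a hedged sketch (the empty `ExtractData` payload is only there to show the call shape and assumes an empty term list is accepted):

```python
from librarian_chunker import Chunker
from librarian_core.workers.base import FlowArtifact
from librarian_extractor.models.extract_data import ExtractData

payload = FlowArtifact.new(data=ExtractData(terms=[]))  # placeholder payload
result = Chunker.submit(payload)                        # returns a FlowArtifact[ChunkData]
print(result.data)
```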
librarian/plugins/librarian-chunker/pyproject.toml (new file, 40 lines)
@@ -0,0 +1,40 @@
[project]
name = "librarian-chunker"
version = "0.1.0"
description = "Chunker for Librarian"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "pdfplumber",
    "pymupdf",
    "tiktoken",
    "spacy",
    "sentence-transformers",
    "pydantic",
    "prefect",
    "librarian-core",
    "python-pptx",
    "python-docx",
]

[build-system]
requires = ["hatchling>=1.21"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/librarian_chunker"]

[tool.hatch.metadata]
allow-direct-references = true

[tool.uv.sources]
#librarian-core = { git = "https://github.com/DotNaos/librarian-core", rev = "dev" }

[project.entry-points."librarian.workers"]
chunker = "librarian_chunker.chunker:Chunker"

# ───────── optional: dev / test extras ─────────
[project.optional-dependencies]
dev = ["ruff", "pytest", "mypy"]
librarian/plugins/librarian-chunker/src/librarian_chunker/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .chunker import Chunker

__all__ = ["Chunker"]
librarian/plugins/librarian-chunker/src/librarian_chunker/chunker.py (new file, 217 lines)
@@ -0,0 +1,217 @@
|
||||
import os
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pdfplumber
|
||||
import pymupdf
|
||||
import spacy
|
||||
import tiktoken
|
||||
from librarian_core.utils.path_utils import get_temp_path
|
||||
from librarian_core.workers.base import Worker
|
||||
from librarian_extractor.models.extract_data import ExtractData, ExtractedFile
|
||||
from prefect import get_run_logger, task
|
||||
from prefect.cache_policies import NO_CACHE
|
||||
from prefect.futures import wait
|
||||
|
||||
from librarian_chunker.models.chunk_data import (
|
||||
Chunk,
|
||||
ChunkData,
|
||||
ChunkedCourse,
|
||||
ChunkedTerm,
|
||||
)
|
||||
|
||||
MAX_TOKENS = 800
|
||||
OVERLAP_TOKENS = 100
|
||||
|
||||
|
||||
class Chunker(Worker[ExtractData, ChunkData]):
|
||||
input_model = ExtractData
|
||||
output_model = ChunkData
|
||||
|
||||
async def __run__(self, payload: ExtractData) -> ChunkData: # noqa: D401
|
||||
lg = get_run_logger()
|
||||
lg.info("Chunker started")
|
||||
|
||||
working_dir = get_temp_path()
|
||||
|
||||
# load NLP and tokenizer
|
||||
Chunker.nlp = spacy.load("xx_ent_wiki_sm")
|
||||
Chunker.nlp.add_pipe("sentencizer")
|
||||
Chunker.enc = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# chunk parameters
|
||||
Chunker.max_tokens = MAX_TOKENS
|
||||
Chunker.overlap_tokens = OVERLAP_TOKENS
|
||||
|
||||
result = ChunkData(terms=[])
|
||||
|
||||
# Loading files
|
||||
for term in payload.terms:
|
||||
chunked_term = ChunkedTerm(id=term.id, name=term.name)
|
||||
in_term_dir = self.entry / term.name
|
||||
|
||||
out_term_dir = working_dir / term.name
|
||||
out_term_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for course in term.courses:
|
||||
chunked_course = ChunkedCourse(
|
||||
id=course.id, name=course.name, chunks=[]
|
||||
)
|
||||
in_course_dir = in_term_dir / course.name
|
||||
|
||||
out_course_dir = out_term_dir / course.name
|
||||
out_course_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
futs = []
|
||||
# All chunks are just in the course dir, so no new dir
|
||||
for chap in course.chapters:
|
||||
chapter_path = in_course_dir / chap.name
|
||||
|
||||
for f in chap.content_files:
|
||||
futs.append(
|
||||
self._chunk_file.submit(f, chapter_path, out_course_dir)
|
||||
)
|
||||
wait(futs)
|
||||
for fut in futs:
|
||||
chunks, images = fut.result()
|
||||
chunked_course.chunks.extend(chunks)
|
||||
chunked_course.images.extend(images)
|
||||
|
||||
chunked_term.courses.append(chunked_course)
|
||||
|
||||
# Add the chunked term to the result
|
||||
result.terms.append(chunked_term)
|
||||
self.stage(out_term_dir)
|
||||
|
||||
return result
|
||||
@staticmethod
|
||||
@task(log_prints=True)
|
||||
def _chunk_file(
|
||||
f: ExtractedFile, chapter_path: Path, out_course_dir: Path
|
||||
) -> tuple[list[Chunk], list[str]]:
|
||||
lg = get_run_logger()
|
||||
lg.info(f"Chunking file {f.name}")
|
||||
lg.info(f"Chapter path: {chapter_path}")
|
||||
lg.info(f"Out course dir: {out_course_dir}")
|
||||
|
||||
# Extract the Text
|
||||
file_text = Chunker._extract_text(chapter_path / f.name)
|
||||
|
||||
# Chunk the Text
|
||||
chunks = Chunker._chunk_text(file_text, f.name, out_course_dir)
|
||||
|
||||
images_dir = out_course_dir / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Extract the Images
|
||||
images = Chunker._extract_images(chapter_path / f.name, images_dir)
|
||||
|
||||
return chunks, images
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_text(file_path: Path) -> str:
|
||||
if not file_path.suffix == ".pdf":
|
||||
return ""
|
||||
|
||||
extracted_text = ""
|
||||
|
||||
with pdfplumber.open(file_path) as pdf:
|
||||
for i in range(len(pdf.pages)):
|
||||
current_page = pdf.pages[i]
|
||||
text = current_page.extract_text() or ""
|
||||
extracted_text += text
|
||||
|
||||
return extracted_text
|
||||
|
||||
@staticmethod
|
||||
def _chunk_text(text: str, f_name: str, out_course_dir: Path) -> list[Chunk]:
|
||||
lg = get_run_logger()
|
||||
lg.info(f"Chunking text for file {f_name}")
|
||||
# split text into sentences and get tokens
|
||||
nlp_doc = Chunker.nlp(text)
|
||||
sentences = [sent.text.strip() for sent in nlp_doc.sents]
|
||||
sentence_token_counts = [len(Chunker.enc.encode(s)) for s in sentences]
|
||||
lg.info(f"Extracted {len(sentences)} sentences with token counts: {sentence_token_counts}")
|
||||
|
||||
# Buffers
|
||||
chunks: list[Chunk] = []
|
||||
current_chunk = []
|
||||
current_token_total = 0
|
||||
|
||||
chunk_id = 0
|
||||
|
||||
for s, tc in zip(sentences, sentence_token_counts): # Pair sentences and tokens
|
||||
if tc + current_token_total <= MAX_TOKENS: # Check Token limit
|
||||
# Add sentences to chunk
|
||||
current_chunk.append(s)
|
||||
current_token_total += tc
|
||||
else:
|
||||
# Flush Chunk
|
||||
chunk_text = "\n\n".join(current_chunk)
|
||||
|
||||
chunk_name = f"{f_name}_{chunk_id}"
|
||||
with open(
|
||||
out_course_dir / f"{chunk_name}.md", "w", encoding="utf-8"
|
||||
) as f:
|
||||
f.write(chunk_text)
|
||||
chunk_id += 1
|
||||
|
||||
chunks.append(
|
||||
Chunk(
|
||||
id=f"{f_name}_{chunk_id}",
|
||||
name=f"{f_name}_{chunk_id}.md",
|
||||
tokens=len(Chunker.enc.encode(chunk_text)),
|
||||
)
|
||||
)
|
||||
|
||||
# Get Overlap from Chunk
|
||||
token_ids = Chunker.enc.encode(chunk_text)
|
||||
overlap_ids = token_ids[-OVERLAP_TOKENS :]
|
||||
overlap_text = Chunker.enc.decode(overlap_ids)
|
||||
overlap_doc = Chunker.nlp(overlap_text)
|
||||
overlap_sents = [sent.text for sent in overlap_doc.sents]
|
||||
|
||||
# Start new Chunk
|
||||
current_chunk = overlap_sents + [s]
|
||||
current_token_total = sum(
|
||||
len(Chunker.enc.encode(s)) for s in current_chunk
|
||||
)
|
||||
|
||||
if current_chunk:
|
||||
chunk_text = "\n\n".join(current_chunk)
|
||||
chunk_name = f"{f_name}_{chunk_id}"
|
||||
with open(out_course_dir / f"{chunk_name}.md", "w", encoding="utf-8") as f:
|
||||
f.write(chunk_text)
|
||||
chunks.append(
|
||||
Chunk(
|
||||
id=f"{f_name}_{chunk_id}",
|
||||
name=f"{f_name}_{chunk_id}",
|
||||
tokens=len(Chunker.enc.encode(chunk_text)),
|
||||
)
|
||||
)
|
||||
lg.info(f"Created {len(chunks)} chunks for file {f_name}")
|
||||
return chunks
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(file: Path, img_folder: Path) -> list[str]:
|
||||
images_list = []
|
||||
if not file.suffix == ".pdf":
|
||||
return []
|
||||
|
||||
with pymupdf.open(file) as doc:
|
||||
for i in range(len(doc)):
|
||||
images = doc.get_page_images(i)
|
||||
|
||||
for img in images:
|
||||
img_xref = img[0]
|
||||
image = doc.extract_image(img_xref)
|
||||
img_content = image["image"]
|
||||
img_ext = image["ext"]
|
||||
img_name = f"img_page{i + 1}_{img_xref}.{img_ext}"
|
||||
img_file_path = img_folder / img_name
|
||||
|
||||
with open(img_file_path, "wb") as img_file:
|
||||
img_file.write(img_content)
|
||||
images_list.append(img_name)
|
||||
return images_list
|
librarian/plugins/librarian-chunker/src/librarian_chunker/models/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .chunk_data import Chunk, ChunkedCourse, ChunkedTerm, ChunkData

__all__ = ["Chunk", "ChunkedCourse", "ChunkedTerm", "ChunkData"]
librarian/plugins/librarian-chunker/src/librarian_chunker/models/chunk_data.py (new file, 29 lines)
@@ -0,0 +1,29 @@
from typing import List

from pydantic import BaseModel, Field

# --------------------------------------------------------------------------- #
#                               Output models                                  #
# --------------------------------------------------------------------------- #


class Chunk(BaseModel):
    id: str
    name: str
    tokens: int


class ChunkedCourse(BaseModel):
    id: str
    name: str
    chunks: List[Chunk] = Field(default_factory=list)
    images: List[str] = Field(default_factory=list)


class ChunkedTerm(BaseModel):
    id: str
    name: str
    courses: List[ChunkedCourse] = Field(default_factory=list)


class ChunkData(BaseModel):
    terms: List[ChunkedTerm]
librarian/plugins/librarian-chunker/uv.lock (generated, new file, 3544 lines)
File diff suppressed because it is too large.
librarian/plugins/librarian-extractor/README.md (new file, empty)

librarian/plugins/librarian-extractor/pyproject.toml (new file, 40 lines)
@@ -0,0 +1,40 @@
[project]
name = "librarian-extractor"
version = "0.1.0"
description = "Librarian extractor plugin"
readme = "README.md"
authors = [
    { name = "DotNaos", email = "schuetzoliver00@gmail.com" },
]
requires-python = ">=3.10"
dependencies = [
    "librarian-core",
    "importlib_metadata; python_version<'3.10'",
    "ollama>=0.4.8",
    "parsel>=1.10.0",
    "prefect>=3.4.1",
    "openai>=1.78.1",
]

#[tool.uv.sources]
#librarian-core = { git = "https://github.com/DotNaos/librarian-core", rev = "main" }

[build-system]
requires = ["hatchling>=1.21"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/librarian_extractor/"]

[tool.hatch.metadata]
allow-direct-references = true

# ───────── optional: dev / test extras ─────────
[project.optional-dependencies]
dev = ["ruff", "pytest", "mypy"]

[project.entry-points."librarian.workers"]
extractor = "librarian_extractor.extractor:Extractor"
ai_sanitizer = "librarian_extractor.ai_sanitizer:AISanitizer"
librarian/plugins/librarian-extractor/src/librarian_extractor/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from librarian_extractor.ai_sanitizer.ai_sanitizer import AISanitizer
from librarian_extractor.extractor.extractor import Extractor

__all__ = ["Extractor", "AISanitizer"]
librarian/plugins/librarian-extractor/src/librarian_extractor/ai_sanitizer/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from librarian_extractor.ai_sanitizer.ai_sanitizer import AISanitizer

__all__ = ["AISanitizer"]
librarian/plugins/librarian-extractor/src/librarian_extractor/ai_sanitizer/ai_sanitizer.py (new file, 215 lines)
@@ -0,0 +1,215 @@
|
||||
"""
|
||||
AI-powered sanitizer
|
||||
====================
|
||||
• in : ExtractData (tree from Extractor)
|
||||
• out : ExtractData (same graph but with prettier names)
|
||||
|
||||
Changes vs. previous revision
|
||||
-----------------------------
|
||||
✓ Media files are resolved from the course-level `media/` folder
✓ Missing sources produce a warning instead of raising
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
from prefect import get_run_logger, task
|
||||
from prefect.futures import PrefectFuture, wait
|
||||
from pydantic import ValidationError
|
||||
|
||||
from librarian_core.workers.base import Worker
|
||||
from librarian_extractor.prompts import PROMPT_COURSE
|
||||
from librarian_extractor.models.extract_data import (
|
||||
ExtractData,
|
||||
ExtractedCourse,
|
||||
ExtractedFile,
|
||||
ExtractedTerm,
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _clean_json(txt: str) -> str:
|
||||
txt = txt.strip()
|
||||
if txt.startswith("```"):
|
||||
txt = txt.lstrip("`")
|
||||
if "\n" in txt:
|
||||
txt = txt.split("\n", 1)[1]
|
||||
if txt.rstrip().endswith("```"):
|
||||
txt = txt.rstrip()[:-3]
|
||||
return txt.strip()
|
||||
|
||||
|
||||
def _safe_json_load(txt: str) -> dict:
|
||||
return json.loads(_clean_json(txt))
|
||||
|
||||
|
||||
def _merge_with_original(src: ExtractedCourse, patch: dict, lg) -> ExtractedCourse:
|
||||
"""Return *patch* merged with *src* so every id is preserved."""
|
||||
try:
|
||||
tgt = ExtractedCourse.model_validate(patch)
|
||||
except ValidationError as err:
|
||||
lg.warning("LLM payload invalid – keeping original (%s)", err)
|
||||
return src
|
||||
|
||||
if not tgt.id:
|
||||
tgt.id = src.id
|
||||
|
||||
for ch_src, ch_tgt in zip(src.chapters, tgt.chapters):
|
||||
if not ch_tgt.name:
|
||||
ch_tgt.name = ch_src.name
|
||||
for f_src, f_tgt in zip(ch_src.content_files, ch_tgt.content_files):
|
||||
if not f_tgt.id:
|
||||
f_tgt.id = f_src.id
|
||||
for f_src, f_tgt in zip(ch_src.media_files, ch_tgt.media_files):
|
||||
if not f_tgt.id:
|
||||
f_tgt.id = f_src.id
|
||||
return tgt
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# OpenAI call – Prefect task #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@task(
|
||||
name="sanitize_course_json",
|
||||
retries=2,
|
||||
retry_delay_seconds=5,
|
||||
log_prints=True,
|
||||
)
|
||||
def sanitize_course_json(course_json: str, model: str, temperature: float) -> dict:
|
||||
rsp = openai.chat.completions.create(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
messages=[
|
||||
{"role": "system", "content": PROMPT_COURSE},
|
||||
{"role": "user", "content": course_json},
|
||||
],
|
||||
)
|
||||
usage = rsp.usage
|
||||
get_run_logger().info(
|
||||
"LLM tokens – prompt: %s, completion: %s",
|
||||
usage.prompt_tokens,
|
||||
usage.completion_tokens,
|
||||
)
|
||||
return _safe_json_load(rsp.choices[0].message.content or "{}")
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class AISanitizer(Worker[ExtractData, ExtractData]):
|
||||
input_model = ExtractData
|
||||
output_model = ExtractData
|
||||
|
||||
def __init__(self, model_name: str | None = None, temperature: float = 0.0):
|
||||
super().__init__()
|
||||
self.model_name = model_name or os.getenv("OPENAI_MODEL", "gpt-4o-mini")
|
||||
self.temperature = temperature
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
def __run__(self, data: ExtractData) -> ExtractData:
|
||||
lg = get_run_logger()
|
||||
|
||||
futures: List[PrefectFuture] = []
|
||||
originals: List[ExtractedCourse] = []
|
||||
|
||||
# 1) submit all courses to the LLM
|
||||
for term in data.terms:
|
||||
for course in term.courses:
|
||||
futures.append(
|
||||
sanitize_course_json.submit(
|
||||
json.dumps(course.model_dump(), ensure_ascii=False),
|
||||
self.model_name,
|
||||
self.temperature,
|
||||
)
|
||||
)
|
||||
originals.append(course)
|
||||
|
||||
wait(futures)
|
||||
|
||||
# 2) build new graph with merged results
|
||||
terms_out: List[ExtractedTerm] = []
|
||||
idx = 0
|
||||
for term in data.terms:
|
||||
new_courses: List[ExtractedCourse] = []
|
||||
for _ in term.courses:
|
||||
clean_dict = futures[idx].result()
|
||||
merged = _merge_with_original(originals[idx], clean_dict, lg)
|
||||
new_courses.append(merged)
|
||||
idx += 1
|
||||
terms_out.append(
|
||||
ExtractedTerm(id=term.id, name=term.name, courses=new_courses)
|
||||
)
|
||||
|
||||
renamed = ExtractData(terms=terms_out)
|
||||
|
||||
# 3) stage files with their new names
|
||||
self._export_with_new_names(data, renamed, lg)
|
||||
|
||||
return renamed
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# staging helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _stage_or_warn(self, src: Path, dst: Path, lg):
|
||||
"""Copy *src* → *dst* (via self.stage). Warn if src missing."""
|
||||
if not src.exists():
|
||||
lg.warning("Source missing – skipped %s", src)
|
||||
return
|
||||
self.stage(src, new_name=str(dst), sanitize=False)
|
||||
lg.debug("Stage %s → %s", src.name, dst)
|
||||
|
||||
def _export_with_new_names(
|
||||
self,
|
||||
original: ExtractData,
|
||||
renamed: ExtractData,
|
||||
lg,
|
||||
):
|
||||
entry = Path(self.entry)
|
||||
|
||||
for term_old, term_new in zip(original.terms, renamed.terms):
|
||||
for course_old, course_new in zip(term_old.courses, term_new.courses):
|
||||
# ---------- content files (per chapter) -----------------
|
||||
for chap_old, chap_new in zip(course_old.chapters, course_new.chapters):
|
||||
n = min(len(chap_old.content_files), len(chap_new.content_files))
|
||||
for i in range(n):
|
||||
fo = chap_old.content_files[i]
|
||||
fn = chap_new.content_files[i]
|
||||
src = (
|
||||
entry
|
||||
/ term_old.name
|
||||
/ course_old.name
|
||||
/ chap_old.name
|
||||
/ fo.name
|
||||
)
|
||||
dst = (
|
||||
Path(term_new.name)
|
||||
/ course_new.name
|
||||
/ chap_new.name
|
||||
/ fn.name
|
||||
)
|
||||
self._stage_or_warn(src, dst, lg)
|
||||
|
||||
# ---------- media files (course-level “media”) ----------
|
||||
src_media_dir = (
|
||||
entry / term_old.name / course_old.name / "media"
|
||||
)  # course-level media directory
|
||||
dst_media_dir = Path(term_new.name) / course_new.name / "media"
|
||||
if not src_media_dir.is_dir():
|
||||
continue
|
||||
|
||||
# build a flat list of (old, new) media filenames
|
||||
media_pairs: List[tuple[ExtractedFile, ExtractedFile]] = []
|
||||
for ch_o, ch_n in zip(course_old.chapters, course_new.chapters):
|
||||
media_pairs.extend(zip(ch_o.media_files, ch_n.media_files))
|
||||
|
||||
for fo, fn in media_pairs:
|
||||
src = src_media_dir / fo.name
|
||||
dst = dst_media_dir / fn.name
|
||||
self._stage_or_warn(src, dst, lg)
|
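For reference, a minimal sketch of what the JSON-cleaning helpers above do to a typical fenced LLM reply. The import path librarian_extractor.ai_sanitizer is an assumption – the sanitizer module's file name is not visible in this diff.

# Hypothetical import path – the sanitizer module's real name is not shown here.
from librarian_extractor.ai_sanitizer import _clean_json, _safe_json_load

raw_reply = """```json
{"id": "42", "name": "Mathematik_1", "chapters": []}
```"""

# _clean_json strips the Markdown fence, _safe_json_load parses the remainder.
assert _clean_json(raw_reply).startswith("{")
course_patch = _safe_json_load(raw_reply)
print(course_patch["name"])  # -> Mathematik_1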
@ -0,0 +1,66 @@
|
||||
"""
|
||||
Shared lists and prompts
|
||||
"""
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# file selection – keep only real documents we can show / convert #
|
||||
# -------------------------------------------------------------------- #
|
||||
CONTENT_FILE_EXTENSIONS = [
|
||||
"*.pdf",
|
||||
"*.doc",
|
||||
"*.docx",
|
||||
"*.ppt",
|
||||
"*.pptx",
|
||||
"*.txt",
|
||||
"*.rtf",
|
||||
]
|
||||
|
||||
MEDIA_FILE_EXTENSIONS = [
|
||||
"*.jpg",
|
||||
"*.jpeg",
|
||||
"*.png",
|
||||
"*.gif",
|
||||
"*.svg",
|
||||
"*.mp4",
|
||||
"*.mov",
|
||||
"*.mp3",
|
||||
]
|
||||
|
||||
# -------------------------------------------------------------------- #
|
||||
# naming rules #
|
||||
# -------------------------------------------------------------------- #
|
||||
SANITIZE_REGEX = {
|
||||
"base": [r"\s*\(\d+\)$"],
|
||||
"course": [
|
||||
r"^\d+\.\s*",
|
||||
r"\s*\([^)]*\)",
|
||||
r"\s*(?:FS|HS)\d{2}$",
|
||||
],
|
||||
"chapter": [
|
||||
r"^\d+\.?\s*",
|
||||
r"\s*SW_\d+\s*(?:___)?\s*KW_\d+\s*",
|
||||
r"\bKapitel[_\s]*\d+\b",
|
||||
],
|
||||
"file": [
|
||||
r",", # ← new : drop commas
|
||||
r",?\s*inkl\.?\s*",
|
||||
r"\(File\)",
|
||||
r"```json",
|
||||
],
|
||||
}
|
||||
|
||||
BLACKLIST_REGEX = {
|
||||
"chapter": [r"^allgemeine informationen$"],
|
||||
"ressource_types": [
|
||||
"(Forum)",
|
||||
"(URL)",
|
||||
"(External tool)",
|
||||
"(Text and media area)",
|
||||
],
|
||||
}
|
||||
|
||||
RESSOURCE_TYPES = BLACKLIST_REGEX["ressource_types"]
|
||||
BASE_BLACKLIST_REGEX = SANITIZE_REGEX["base"]
|
||||
|
||||
MAX_FILENAME_LENGTH = 100
|
||||
|
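As a quick illustration of the "course" patterns above, a minimal sketch that applies them directly with re.sub; the course title is made up.

import re

course_patterns = [
    r"^\d+\.\s*",           # leading enumeration, e.g. "3. "
    r"\s*\([^)]*\)",        # parenthesised additions, e.g. "(Gruppe A)"
    r"\s*(?:FS|HS)\d{2}$",  # trailing semester code, e.g. "FS24"
]

name = "3. Software Engineering (Gruppe A) FS24"
for rx in course_patterns:
    name = re.sub(rx, "", name, flags=re.IGNORECASE)
print(name)  # -> "Software Engineering"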
@ -0,0 +1,3 @@
|
||||
from librarian_extractor.extractor.extractor import Extractor
|
||||
|
||||
__all__ = ["Extractor"]
|
@ -0,0 +1,301 @@
|
||||
"""
|
||||
Extractor Worker – resilient version
|
||||
------------------------------------
|
||||
* Finds the real payload even when the link goes to
|
||||
File_…/index.html first.
|
||||
* No `iterdir` on non-directories.
|
||||
* Keeps all earlier features: id parsing, allowed-suffix filter,
|
||||
media-folder, sanitising.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import re
|
||||
import shutil
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import lxml.html
|
||||
import parsel
|
||||
from librarian_core.utils.path_utils import get_temp_path
|
||||
from librarian_core.workers.base import Worker
|
||||
from librarian_scraper.models.download_data import DownloadData
|
||||
from prefect import get_run_logger, task
|
||||
from prefect.futures import wait
|
||||
|
||||
from librarian_extractor.constants import (
|
||||
CONTENT_FILE_EXTENSIONS,
|
||||
MEDIA_FILE_EXTENSIONS,
|
||||
)
|
||||
from librarian_extractor.models.extract_data import (
|
||||
ExtractData,
|
||||
ExtractedChapter,
|
||||
ExtractedCourse,
|
||||
ExtractedFile,
|
||||
ExtractedTerm,
|
||||
)
|
||||
from librarian_extractor.sanitizers import (
|
||||
annotate_chapter_name,
|
||||
is_chapter_allowed,
|
||||
sanitize_chapter_name,
|
||||
sanitize_course_name,
|
||||
sanitize_file_name,
|
||||
)
|
||||
|
||||
CONTENT_EXTS = {Path(p).suffix.lower() for p in CONTENT_FILE_EXTENSIONS}
|
||||
MEDIA_EXTS = {Path(p).suffix.lower() for p in MEDIA_FILE_EXTENSIONS}
|
||||
ALL_EXTS = CONTENT_EXTS | MEDIA_EXTS
|
||||
|
||||
_id_rx = re.compile(r"\.(\d{4,})[./]") # 1172180 from “..._.1172180/index.html”
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helpers #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def _hash_id(fname: str) -> str:
|
||||
return hashlib.sha1(fname.encode()).hexdigest()[:10]
|
||||
|
||||
|
||||
def _html_stub_target(html_file: Path) -> Path | None:
|
||||
"""Parse a Moodle *index.html* stub and return the first file link."""
|
||||
try:
|
||||
tree = lxml.html.parse(html_file) # type: ignore[arg-type]
|
||||
hrefs = tree.xpath("//ul/li/a/@href")
|
||||
for h in hrefs:
|
||||
h = h.split("?")[0].split("#")[0]
|
||||
p = html_file.parent / h
|
||||
if p.exists():
|
||||
return p
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _best_payload(node: Path) -> Path | None: # noqa: C901
|
||||
"""
|
||||
Return the real document given *node* which may be:
|
||||
• the actual file → return it
|
||||
• File_xxx/dir → search inside /content or dir itself
|
||||
• File_xxx/index.html stub → parse to find linked file
|
||||
"""
|
||||
# 1) immediate hit
|
||||
if node.is_file() and node.suffix.lower() in ALL_EXTS:
|
||||
return node
|
||||
|
||||
# 2) if html stub try to parse inner link
|
||||
if node.is_file() and node.suffix.lower() in {".html", ".htm"}:
|
||||
hinted = _html_stub_target(node)
|
||||
if hinted:
|
||||
return _best_payload(hinted) # recurse
|
||||
|
||||
# 3) directories to search
|
||||
roots: list[Path] = []
|
||||
if node.is_dir():
|
||||
roots.append(node)
|
||||
elif node.is_file():
|
||||
roots.append(node.parent)
|
||||
|
||||
for r in list(roots):
|
||||
if r.is_dir() and (r / "content").is_dir():
|
||||
roots.insert(0, r / "content") # prefer content folder
|
||||
|
||||
for r in roots:
|
||||
if not r.is_dir():
|
||||
continue
|
||||
files = [p for p in r.iterdir() if p.is_file() and p.suffix.lower() in ALL_EXTS]
|
||||
if len(files) == 1:
|
||||
return files[0]
|
||||
return None
|
||||
|
||||
|
||||
def _file_id_from_href(href: str) -> str:
|
||||
m = _id_rx.search(href)
|
||||
return m.group(1) if m else ""
|
||||
|
||||
|
||||
def task_(**kw):
|
||||
kw.setdefault("log_prints", True)
|
||||
return task(**kw)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# Worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Extractor(Worker[DownloadData, ExtractData]):
|
||||
input_model = DownloadData
|
||||
output_model = ExtractData
|
||||
|
||||
async def __run__(self, downloads: DownloadData) -> ExtractData:
|
||||
lg = get_run_logger()
|
||||
work_root = Path(get_temp_path()) / "extract"
|
||||
work_root.mkdir(parents=True, exist_ok=True)
|
||||
self.out_dir = work_root
|
||||
|
||||
result = ExtractData()
|
||||
futs = []
|
||||
entry_dir = self.entry
|
||||
|
||||
for t in downloads.terms:
|
||||
(work_root / t.name).mkdir(exist_ok=True)
|
||||
result.terms.append(ExtractedTerm(id=t.id, name=t.name))
|
||||
for c in t.courses:
|
||||
futs.append(
|
||||
self._extract_course.submit(t.name, c.id, work_root, entry_dir)
|
||||
)
|
||||
|
||||
done, _ = wait(futs)
|
||||
for fut in done:
|
||||
term, meta = fut.result()
|
||||
if meta:
|
||||
next(t for t in result.terms if t.name == term).courses.append(meta)
|
||||
|
||||
for term in result.terms:
|
||||
self.stage(
|
||||
work_root / term.name, new_name=term.name, sanitize=False, move=True
|
||||
)
|
||||
lg.info("Extractor finished – %d terms", len(result.terms))
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
@task_()
|
||||
def _extract_course( # noqa: C901
|
||||
term: str, cid: str, out_root: Path, entry_dir: Path
|
||||
) -> Tuple[str, ExtractedCourse | None]:
|
||||
lg = get_run_logger()
|
||||
z = entry_dir / term / f"{cid}.zip"
|
||||
if not z.is_file():
|
||||
lg.warning("ZIP missing %s", z)
|
||||
return term, None
|
||||
|
||||
tmp = Path(get_temp_path()) / f"u{cid}"
|
||||
tmp.mkdir(exist_ok=True)
|
||||
try:
|
||||
with zipfile.ZipFile(z) as zf:
|
||||
zf.extractall(tmp)
|
||||
|
||||
html, root = Extractor._index_html(tmp)
|
||||
if not html:
|
||||
lg.warning("index.html missing for %s", cid)
|
||||
return term, None
|
||||
|
||||
cname = Extractor._course_name(html) or cid
|
||||
c_meta = ExtractedCourse(id=cid, name=cname)
|
||||
media_dir = out_root / term / cname / "media"
|
||||
|
||||
structure = Extractor._outline(html)
|
||||
if not structure:
|
||||
Extractor._copy_all(
|
||||
root, out_root / term / cname, c_meta, media_dir, lg # type: ignore
|
||||
)
|
||||
return term, c_meta
|
||||
|
||||
chap_no = 0
|
||||
for title, links in structure:
|
||||
if not is_chapter_allowed(title):
|
||||
continue
|
||||
chap_no += 1
|
||||
chap_name = annotate_chapter_name(sanitize_chapter_name(title), chap_no)
|
||||
chap_dir = out_root / term / cname / chap_name
|
||||
chap_dir.mkdir(parents=True, exist_ok=True)
|
||||
chap_meta = ExtractedChapter(name=chap_name)
|
||||
|
||||
for text, href in links:
|
||||
target = _best_payload(root / href.lstrip("./"))
|
||||
if not target:
|
||||
lg.debug("payload not found %s", href)
|
||||
continue
|
||||
|
||||
base = sanitize_file_name(text)
|
||||
if not Path(base).suffix:
|
||||
base += target.suffix # ensure extension
|
||||
|
||||
dst = (
|
||||
media_dir / base
|
||||
if target.suffix.lower() in MEDIA_EXTS
|
||||
else chap_dir / base
|
||||
)
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(target, dst)
|
||||
|
||||
fid = _file_id_from_href(href) or _hash_id(dst.name)
|
||||
meta_obj = ExtractedFile(id=fid, name=dst.name)
|
||||
(
|
||||
chap_meta.media_files
|
||||
if dst.is_relative_to(media_dir)
|
||||
else chap_meta.content_files
|
||||
).append(meta_obj)
|
||||
|
||||
if chap_meta.content_files or chap_meta.media_files:
|
||||
c_meta.chapters.append(chap_meta)
|
||||
|
||||
if c_meta.chapters:
|
||||
lg.info("Extracted %s (%d chap.)", cname, len(c_meta.chapters))
|
||||
return term, c_meta
|
||||
return term, None
|
||||
finally:
|
||||
shutil.rmtree(tmp, ignore_errors=True)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def _copy_all(
|
||||
root: Path, dst_root: Path, c_meta: ExtractedCourse, media_dir: Path, lg
|
||||
):
|
||||
chap = ExtractedChapter(name="Everything")
|
||||
dst_root.mkdir(parents=True, exist_ok=True)
|
||||
for fp in root.rglob("*"):
|
||||
if fp.is_file() and fp.suffix.lower() in ALL_EXTS:
|
||||
dst = (
|
||||
media_dir if fp.suffix.lower() in MEDIA_EXTS else dst_root
|
||||
) / fp.name
|
||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy2(fp, dst)
|
||||
chap.content_files.append(
|
||||
ExtractedFile(id=_hash_id(fp.name), name=dst.name)
|
||||
)
|
||||
if chap.content_files:
|
||||
c_meta.chapters.append(chap)
|
||||
lg.info("Fallback copy %d files", len(chap.content_files))
|
||||
|
||||
@staticmethod
|
||||
def _index_html(root: Path) -> Tuple[str, Path | None]:
|
||||
for idx in root.rglob("index.html"):
|
||||
try:
|
||||
return idx.read_text("utf-8", errors="ignore"), idx.parent
|
||||
except Exception:
|
||||
continue
|
||||
return "", None
|
||||
|
||||
@staticmethod
|
||||
def _course_name(html: str) -> str:
|
||||
sel = parsel.Selector(text=html)
|
||||
return sanitize_course_name(sel.css("h1 a::text").get(default="").strip())
|
||||
|
||||
@staticmethod
|
||||
def _outline(html: str):
|
||||
t = lxml.html.fromstring(html)
|
||||
res = []
|
||||
for h3 in t.xpath("//h3"):
|
||||
title = h3.text_content().strip()
|
||||
ul = next((s for s in h3.itersiblings() if s.tag == "ul"), None)
|
||||
if ul is None:
|
||||
continue
|
||||
links = []
|
||||
for a in ul.findall(".//a"):
|
||||
if "(File)" in (a.text_content() or ""):
|
||||
sel = parsel.Selector(
|
||||
text=lxml.html.tostring(a, encoding="unicode") # type: ignore
|
||||
)
|
||||
links.append(
|
||||
(
|
||||
sel.css("::text").get().strip(), # type: ignore
|
||||
sel.css("::attr(href)").get().strip(), # type: ignore
|
||||
)
|
||||
)
|
||||
if links:
|
||||
res.append((title, links))
|
||||
return res
|
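A short sketch of the two id helpers defined at the top of this module, imported via librarian_extractor.extractor.extractor – the path used by the package __init__ above; the example href is made up but follows the format noted next to _id_rx.

from librarian_extractor.extractor.extractor import _file_id_from_href, _hash_id

# Moodle export links carry the numeric resource id right before "/index.html".
print(_file_id_from_href("Skript_.1172180/index.html"))  # -> "1172180"

# When no id can be parsed from the href, a 10-character SHA-1 prefix of the
# file name is used as a stable fallback id.
print(_hash_id("Intro.pdf"))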
@ -0,0 +1,30 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ExtractedFile(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the file, relative to ExtractedChapter.name
|
||||
|
||||
|
||||
class ExtractedChapter(BaseModel):
|
||||
name: str # Name of the chapter directory, relative to ExtractedCourse.name
|
||||
content_files: List[ExtractedFile] = Field(default_factory=list)
|
||||
media_files: List[ExtractedFile] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractedCourse(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the course directory, relative to ExtractedTerm.name
|
||||
chapters: List[ExtractedChapter] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractedTerm(BaseModel):
|
||||
id: str
|
||||
name: str # Name of the term directory, relative to ExtractMeta.dir
|
||||
courses: List[ExtractedCourse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ExtractData(BaseModel):
|
||||
terms: List[ExtractedTerm] = Field(default_factory=list)
|
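To make the shape of this graph concrete, a minimal sketch that builds one term → course → chapter → file by hand and serialises it; the ids and names are made up.

from librarian_extractor.models.extract_data import (
    ExtractData,
    ExtractedChapter,
    ExtractedCourse,
    ExtractedFile,
    ExtractedTerm,
)

data = ExtractData(
    terms=[
        ExtractedTerm(
            id="7",
            name="FS24",
            courses=[
                ExtractedCourse(
                    id="1234",
                    name="Software_Engineering",
                    chapters=[
                        ExtractedChapter(
                            name="1_SW_01",
                            content_files=[ExtractedFile(id="9876", name="Intro.pdf")],
                        )
                    ],
                )
            ],
        )
    ]
)
print(data.model_dump_json(indent=2))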
@ -0,0 +1,29 @@
|
||||
# -------------------------------------------------------------------- #
|
||||
# LLM prompts #
|
||||
# -------------------------------------------------------------------- #
|
||||
|
||||
PROMPT_COURSE = """
|
||||
General naming rules
|
||||
====================
|
||||
* Use underscores instead of spaces.
|
||||
* Keep meaningful numbers / IDs.
|
||||
* Remove date information unless it is absolutely necessary.
|
||||
* If a date must stay, normalize it ("Februar" → "02").
|
||||
* Remove redundant semester / university codes (e.g. FS24, HS, FHGR, CDS).
|
||||
* Remove redundancy in general (DRY - Don't Repeat Yourself).
|
||||
* Trim superfluous parts like duplicate week information ("1_SW_01_KW_08" → "SW_01").
|
||||
* Only keep one enumerator at a time, so "1_SW_01" → "SW_01".
|
||||
* Preserve file extensions!
|
||||
* Avoid repeated dots and illegal filesystem characters (colon, slash, …).
|
||||
|
||||
The most important rule is to keep everything as consistent as possible.
|
||||
|
||||
Important – DO NOT:
|
||||
* change the JSON structure,
|
||||
* change or reorder any `id`,
|
||||
* add any keys.
|
||||
|
||||
Return **only** the modified JSON for the course you receive.
|
||||
|
||||
Everything should be in English after sanitization.
|
||||
""".strip()
|
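An illustrative, made-up before/after of the transformation this prompt asks for – the structure and every id stay untouched, only the human-readable names are normalised (not actual model output):

# Made-up example payloads; not actual model output.
before = {
    "id": "1234",
    "name": "3. Software Engineering (Gruppe A) FS24",
    "chapters": [
        {"name": "1_SW_01_KW_08 Einfuehrung", "content_files": [], "media_files": []}
    ],
}
after = {
    "id": "1234",
    "name": "Software_Engineering",
    "chapters": [
        {"name": "SW_01_Introduction", "content_files": [], "media_files": []}
    ],
}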
@ -0,0 +1,71 @@
|
||||
"""
|
||||
Name-sanitising helpers
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from librarian_extractor.constants import (
|
||||
BASE_BLACKLIST_REGEX,
|
||||
BLACKLIST_REGEX,
|
||||
MAX_FILENAME_LENGTH,
|
||||
RESSOURCE_TYPES,
|
||||
SANITIZE_REGEX,
|
||||
)
|
||||
|
||||
_INVALID_FS_CHARS = re.compile(r'[\\/:*?"<>|]')
|
||||
_WS = re.compile(r"\s+")
|
||||
_DUP_DOTS = re.compile(r"\.\.+")
|
||||
_TRAILING_NUM = re.compile(r"_\(\d+\)$")
|
||||
|
||||
|
||||
def _sanitize_name(name: str, extra_patterns: list[str]) -> str:
|
||||
original = name
|
||||
for rt in RESSOURCE_TYPES:
|
||||
name = name.replace(rt, "")
|
||||
for rx in BASE_BLACKLIST_REGEX + extra_patterns:
|
||||
name = re.sub(rx, "", name, flags=re.IGNORECASE)
|
||||
name = _INVALID_FS_CHARS.sub("_", name)
|
||||
name = _DUP_DOTS.sub(".", name)
|
||||
name = _WS.sub(" ", name).replace(" ", "_")
|
||||
name = re.sub(r"_+", "_", name).strip("_")
|
||||
base, dot, ext = name.rpartition(".")
|
||||
if dot:
|
||||
base = _TRAILING_NUM.sub("", base)
|
||||
dup = re.compile(rf"(?i)[._]{re.escape(ext)}$")
|
||||
base = dup.sub("", base)
|
||||
name = f"{base}.{ext}" if base else f".{ext}"
|
||||
else:
|
||||
name = _TRAILING_NUM.sub("", name)
|
||||
name = name.strip("_.")
|
||||
if len(name) > MAX_FILENAME_LENGTH:
|
||||
if dot and len(ext) < 10:
|
||||
avail = MAX_FILENAME_LENGTH - len(ext) - 1
|
||||
name = f"{base[:avail]}.{ext}"
|
||||
else:
|
||||
name = name[:MAX_FILENAME_LENGTH].rstrip("_")
|
||||
if not name or name == ".":
|
||||
name = re.sub(_INVALID_FS_CHARS, "_", original)[:MAX_FILENAME_LENGTH] or "file"
|
||||
return name
|
||||
|
||||
|
||||
def sanitize_course_name(name: str) -> str:
|
||||
return _sanitize_name(name, SANITIZE_REGEX["course"])
|
||||
|
||||
|
||||
def sanitize_chapter_name(name: str) -> str:
|
||||
return _sanitize_name(name, SANITIZE_REGEX["chapter"])
|
||||
|
||||
|
||||
def sanitize_file_name(name: str) -> str:
|
||||
return _sanitize_name(name, SANITIZE_REGEX["file"])
|
||||
|
||||
|
||||
def annotate_chapter_name(name: str, idx: Optional[int] = None) -> str:
|
||||
return f"{idx}_{name}" if idx is not None else name
|
||||
|
||||
|
||||
def is_chapter_allowed(name: str) -> bool:
|
||||
return not any(re.fullmatch(rx, name.strip(), flags=re.IGNORECASE) for rx in BLACKLIST_REGEX["chapter"])
|
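A small usage sketch of the helpers above; the exact outputs depend on the regex tables in constants.py, so they are only indicated in the comments.

from librarian_extractor.sanitizers import (
    annotate_chapter_name,
    sanitize_chapter_name,
    sanitize_file_name,
)

# Chapter: leading enumeration and SW/KW week codes are stripped, then the
# chapter gets a clean running index as prefix.
chapter = sanitize_chapter_name("1. SW_01 ___ KW_08 Einfuehrung")
print(annotate_chapter_name(chapter, 1))

# File: commas, "inkl." fragments and the "(File)" marker are removed while
# the extension is preserved.
print(sanitize_file_name("Skript, inkl. Loesungen (File).pdf"))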
215
librarian/plugins/librarian-extractor/uv.lock
generated
Normal file
@ -0,0 +1,215 @@
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.10"
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/35/2495c4ac46b980e4ca1f6ad6db102322ef3ad2410b79fdde159a4b0f3b92/exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc", size = 28883 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "librarian-core"
|
||||
version = "0.1.0"
|
||||
source = { git = "https://github.com/DotNaos/librarian-core?rev=main#a564a04ad1019cb196af1ee11d654b77839a469b" }
|
||||
|
||||
[[package]]
|
||||
name = "librarian-scraper"
|
||||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "librarian-core" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "mypy" },
|
||||
{ name = "pytest" },
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "importlib-metadata", marker = "python_full_version < '3.10'" },
|
||||
{ name = "librarian-core", git = "https://github.com/DotNaos/librarian-core?rev=main" },
|
||||
{ name = "mypy", marker = "extra == 'dev'" },
|
||||
{ name = "pytest", marker = "extra == 'dev'" },
|
||||
{ name = "ruff", marker = "extra == 'dev'" },
|
||||
]
|
||||
provides-extras = ["dev"]
|
||||
|
||||
[[package]]
|
||||
name = "mypy"
|
||||
version = "1.15.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mypy-extensions" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ce/43/d5e49a86afa64bd3839ea0d5b9c7103487007d728e1293f52525d6d5486a/mypy-1.15.0.tar.gz", hash = "sha256:404534629d51d3efea5c800ee7c42b72a6554d6c400e6a79eafe15d11341fd43", size = 3239717 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/68/f8/65a7ce8d0e09b6329ad0c8d40330d100ea343bd4dd04c4f8ae26462d0a17/mypy-1.15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:979e4e1a006511dacf628e36fadfecbcc0160a8af6ca7dad2f5025529e082c13", size = 10738433 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/95/9c0ecb8eacfe048583706249439ff52105b3f552ea9c4024166c03224270/mypy-1.15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c4bb0e1bd29f7d34efcccd71cf733580191e9a264a2202b0239da95984c5b559", size = 9861472 },
|
||||
{ url = "https://files.pythonhosted.org/packages/84/09/9ec95e982e282e20c0d5407bc65031dfd0f0f8ecc66b69538296e06fcbee/mypy-1.15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be68172e9fd9ad8fb876c6389f16d1c1b5f100ffa779f77b1fb2176fcc9ab95b", size = 11611424 },
|
||||
{ url = "https://files.pythonhosted.org/packages/78/13/f7d14e55865036a1e6a0a69580c240f43bc1f37407fe9235c0d4ef25ffb0/mypy-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c7be1e46525adfa0d97681432ee9fcd61a3964c2446795714699a998d193f1a3", size = 12365450 },
|
||||
{ url = "https://files.pythonhosted.org/packages/48/e1/301a73852d40c241e915ac6d7bcd7fedd47d519246db2d7b86b9d7e7a0cb/mypy-1.15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:2e2c2e6d3593f6451b18588848e66260ff62ccca522dd231cd4dd59b0160668b", size = 12551765 },
|
||||
{ url = "https://files.pythonhosted.org/packages/77/ba/c37bc323ae5fe7f3f15a28e06ab012cd0b7552886118943e90b15af31195/mypy-1.15.0-cp310-cp310-win_amd64.whl", hash = "sha256:6983aae8b2f653e098edb77f893f7b6aca69f6cffb19b2cc7443f23cce5f4828", size = 9274701 },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/bc/f6339726c627bd7ca1ce0fa56c9ae2d0144604a319e0e339bdadafbbb599/mypy-1.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2922d42e16d6de288022e5ca321cd0618b238cfc5570e0263e5ba0a77dbef56f", size = 10662338 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/90/8dcf506ca1a09b0d17555cc00cd69aee402c203911410136cd716559efe7/mypy-1.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2ee2d57e01a7c35de00f4634ba1bbf015185b219e4dc5909e281016df43f5ee5", size = 9787540 },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/05/a10f9479681e5da09ef2f9426f650d7b550d4bafbef683b69aad1ba87457/mypy-1.15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:973500e0774b85d9689715feeffcc980193086551110fd678ebe1f4342fb7c5e", size = 11538051 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e9/9a/1f7d18b30edd57441a6411fcbc0c6869448d1a4bacbaee60656ac0fc29c8/mypy-1.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5a95fb17c13e29d2d5195869262f8125dfdb5c134dc8d9a9d0aecf7525b10c2c", size = 12286751 },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/af/19ff499b6f1dafcaf56f9881f7a965ac2f474f69f6f618b5175b044299f5/mypy-1.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:1905f494bfd7d85a23a88c5d97840888a7bd516545fc5aaedff0267e0bb54e2f", size = 12421783 },
|
||||
{ url = "https://files.pythonhosted.org/packages/96/39/11b57431a1f686c1aed54bf794870efe0f6aeca11aca281a0bd87a5ad42c/mypy-1.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:c9817fa23833ff189db061e6d2eff49b2f3b6ed9856b4a0a73046e41932d744f", size = 9265618 },
|
||||
{ url = "https://files.pythonhosted.org/packages/98/3a/03c74331c5eb8bd025734e04c9840532226775c47a2c39b56a0c8d4f128d/mypy-1.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aea39e0583d05124836ea645f412e88a5c7d0fd77a6d694b60d9b6b2d9f184fd", size = 10793981 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f0/1a/41759b18f2cfd568848a37c89030aeb03534411eef981df621d8fad08a1d/mypy-1.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2f2147ab812b75e5b5499b01ade1f4a81489a147c01585cda36019102538615f", size = 9749175 },
|
||||
{ url = "https://files.pythonhosted.org/packages/12/7e/873481abf1ef112c582db832740f4c11b2bfa510e829d6da29b0ab8c3f9c/mypy-1.15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce436f4c6d218a070048ed6a44c0bbb10cd2cc5e272b29e7845f6a2f57ee4464", size = 11455675 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/d0/92ae4cde706923a2d3f2d6c39629134063ff64b9dedca9c1388363da072d/mypy-1.15.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8023ff13985661b50a5928fc7a5ca15f3d1affb41e5f0a9952cb68ef090b31ee", size = 12410020 },
|
||||
{ url = "https://files.pythonhosted.org/packages/46/8b/df49974b337cce35f828ba6fda228152d6db45fed4c86ba56ffe442434fd/mypy-1.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1124a18bc11a6a62887e3e137f37f53fbae476dc36c185d549d4f837a2a6a14e", size = 12498582 },
|
||||
{ url = "https://files.pythonhosted.org/packages/13/50/da5203fcf6c53044a0b699939f31075c45ae8a4cadf538a9069b165c1050/mypy-1.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:171a9ca9a40cd1843abeca0e405bc1940cd9b305eaeea2dda769ba096932bb22", size = 9366614 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/9b/fd2e05d6ffff24d912f150b87db9e364fa8282045c875654ce7e32fffa66/mypy-1.15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:93faf3fdb04768d44bf28693293f3904bbb555d076b781ad2530214ee53e3445", size = 10788592 },
|
||||
{ url = "https://files.pythonhosted.org/packages/74/37/b246d711c28a03ead1fd906bbc7106659aed7c089d55fe40dd58db812628/mypy-1.15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:811aeccadfb730024c5d3e326b2fbe9249bb7413553f15499a4050f7c30e801d", size = 9753611 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a6/ac/395808a92e10cfdac8003c3de9a2ab6dc7cde6c0d2a4df3df1b815ffd067/mypy-1.15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:98b7b9b9aedb65fe628c62a6dc57f6d5088ef2dfca37903a7d9ee374d03acca5", size = 11438443 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d2/8b/801aa06445d2de3895f59e476f38f3f8d610ef5d6908245f07d002676cbf/mypy-1.15.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c43a7682e24b4f576d93072216bf56eeff70d9140241f9edec0c104d0c515036", size = 12402541 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/67/5a4268782eb77344cc613a4cf23540928e41f018a9a1ec4c6882baf20ab8/mypy-1.15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:baefc32840a9f00babd83251560e0ae1573e2f9d1b067719479bfb0e987c6357", size = 12494348 },
|
||||
{ url = "https://files.pythonhosted.org/packages/83/3e/57bb447f7bbbfaabf1712d96f9df142624a386d98fb026a761532526057e/mypy-1.15.0-cp313-cp313-win_amd64.whl", hash = "sha256:b9378e2c00146c44793c98b8d5a61039a048e31f429fb0eb546d93f4b000bedf", size = 9373648 },
|
||||
{ url = "https://files.pythonhosted.org/packages/09/4e/a7d65c7322c510de2c409ff3828b03354a7c43f5a8ed458a7a131b41c7b9/mypy-1.15.0-py3-none-any.whl", hash = "sha256:5469affef548bd1895d86d3bf10ce2b44e33d86923c29e4d675b3e323437ea3e", size = 2221777 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mypy-extensions"
|
||||
version = "1.1.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/a1/d4/1fc4078c65507b51b96ca8f8c3ba19e6a61c8253c72794544580a7b6c24d/packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f", size = 165727 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484", size = 66469 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.3.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "colorama", marker = "sys_platform == 'win32'" },
|
||||
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
|
||||
{ name = "iniconfig" },
|
||||
{ name = "packaging" },
|
||||
{ name = "pluggy" },
|
||||
{ name = "tomli", marker = "python_full_version < '3.11'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ruff"
|
||||
version = "0.11.7"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5b/89/6f9c9674818ac2e9cc2f2b35b704b7768656e6b7c139064fc7ba8fbc99f1/ruff-0.11.7.tar.gz", hash = "sha256:655089ad3224070736dc32844fde783454f8558e71f501cb207485fe4eee23d4", size = 4054861 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/ec/21927cb906c5614b786d1621dba405e3d44f6e473872e6df5d1a6bca0455/ruff-0.11.7-py3-none-linux_armv6l.whl", hash = "sha256:d29e909d9a8d02f928d72ab7837b5cbc450a5bdf578ab9ebee3263d0a525091c", size = 10245403 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e2/af/fec85b6c2c725bcb062a354dd7cbc1eed53c33ff3aa665165871c9c16ddf/ruff-0.11.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:dd1fb86b168ae349fb01dd497d83537b2c5541fe0626e70c786427dd8363aaee", size = 11007166 },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/9a/2d0d260a58e81f388800343a45898fd8df73c608b8261c370058b675319a/ruff-0.11.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:d3d7d2e140a6fbbc09033bce65bd7ea29d6a0adeb90b8430262fbacd58c38ada", size = 10378076 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c2/c4/9b09b45051404d2e7dd6d9dbcbabaa5ab0093f9febcae664876a77b9ad53/ruff-0.11.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4809df77de390a1c2077d9b7945d82f44b95d19ceccf0c287c56e4dc9b91ca64", size = 10557138 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5e/5e/f62a1b6669870a591ed7db771c332fabb30f83c967f376b05e7c91bccd14/ruff-0.11.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f3a0c2e169e6b545f8e2dba185eabbd9db4f08880032e75aa0e285a6d3f48201", size = 10095726 },
|
||||
{ url = "https://files.pythonhosted.org/packages/45/59/a7aa8e716f4cbe07c3500a391e58c52caf665bb242bf8be42c62adef649c/ruff-0.11.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:49b888200a320dd96a68e86736cf531d6afba03e4f6cf098401406a257fcf3d6", size = 11672265 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/e3/101a8b707481f37aca5f0fcc3e42932fa38b51add87bfbd8e41ab14adb24/ruff-0.11.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:2b19cdb9cf7dae00d5ee2e7c013540cdc3b31c4f281f1dacb5a799d610e90db4", size = 12331418 },
|
||||
{ url = "https://files.pythonhosted.org/packages/dd/71/037f76cbe712f5cbc7b852e4916cd3cf32301a30351818d32ab71580d1c0/ruff-0.11.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:64e0ee994c9e326b43539d133a36a455dbaab477bc84fe7bfbd528abe2f05c1e", size = 11794506 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ca/de/e450b6bab1fc60ef263ef8fcda077fb4977601184877dce1c59109356084/ruff-0.11.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bad82052311479a5865f52c76ecee5d468a58ba44fb23ee15079f17dd4c8fd63", size = 13939084 },
|
||||
{ url = "https://files.pythonhosted.org/packages/0e/2c/1e364cc92970075d7d04c69c928430b23e43a433f044474f57e425cbed37/ruff-0.11.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7940665e74e7b65d427b82bffc1e46710ec7f30d58b4b2d5016e3f0321436502", size = 11450441 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/7d/1b048eb460517ff9accd78bca0fa6ae61df2b276010538e586f834f5e402/ruff-0.11.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:169027e31c52c0e36c44ae9a9c7db35e505fee0b39f8d9fca7274a6305295a92", size = 10441060 },
|
||||
{ url = "https://files.pythonhosted.org/packages/3a/57/8dc6ccfd8380e5ca3d13ff7591e8ba46a3b330323515a4996b991b10bd5d/ruff-0.11.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:305b93f9798aee582e91e34437810439acb28b5fc1fee6b8205c78c806845a94", size = 10058689 },
|
||||
{ url = "https://files.pythonhosted.org/packages/23/bf/20487561ed72654147817885559ba2aa705272d8b5dee7654d3ef2dbf912/ruff-0.11.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a681db041ef55550c371f9cd52a3cf17a0da4c75d6bd691092dfc38170ebc4b6", size = 11073703 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/27/04f2db95f4ef73dccedd0c21daf9991cc3b7f29901a4362057b132075aa4/ruff-0.11.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:07f1496ad00a4a139f4de220b0c97da6d4c85e0e4aa9b2624167b7d4d44fd6b6", size = 11532822 },
|
||||
{ url = "https://files.pythonhosted.org/packages/e1/72/43b123e4db52144c8add336581de52185097545981ff6e9e58a21861c250/ruff-0.11.7-py3-none-win32.whl", hash = "sha256:f25dfb853ad217e6e5f1924ae8a5b3f6709051a13e9dad18690de6c8ff299e26", size = 10362436 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c5/a0/3e58cd76fdee53d5c8ce7a56d84540833f924ccdf2c7d657cb009e604d82/ruff-0.11.7-py3-none-win_amd64.whl", hash = "sha256:0a931d85959ceb77e92aea4bbedfded0a31534ce191252721128f77e5ae1f98a", size = 11566676 },
|
||||
{ url = "https://files.pythonhosted.org/packages/68/ca/69d7c7752bce162d1516e5592b1cc6b6668e9328c0d270609ddbeeadd7cf/ruff-0.11.7-py3-none-win_arm64.whl", hash = "sha256:778c1e5d6f9e91034142dfd06110534ca13220bfaad5c3735f6cb844654f6177", size = 10677936 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/18/87/302344fed471e44a87289cf4967697d07e532f2421fdaf868a303cbae4ff/tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff", size = 17175 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/43/ca/75707e6efa2b37c77dadb324ae7d9571cb424e61ea73fad7c56c2d14527f/tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249", size = 131077 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/16/51ae563a8615d472fdbffc43a3f3d46588c264ac4f024f63f01283becfbb/tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6", size = 123429 },
|
||||
{ url = "https://files.pythonhosted.org/packages/f1/dd/4f6cd1e7b160041db83c694abc78e100473c15d54620083dbd5aae7b990e/tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a", size = 226067 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a9/6b/c54ede5dc70d648cc6361eaf429304b02f2871a345bbdd51e993d6cdf550/tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee", size = 236030 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1f/47/999514fa49cfaf7a92c805a86c3c43f4215621855d151b61c602abb38091/tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e", size = 240898 },
|
||||
{ url = "https://files.pythonhosted.org/packages/73/41/0a01279a7ae09ee1573b423318e7934674ce06eb33f50936655071d81a24/tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4", size = 229894 },
|
||||
{ url = "https://files.pythonhosted.org/packages/55/18/5d8bc5b0a0362311ce4d18830a5d28943667599a60d20118074ea1b01bb7/tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106", size = 245319 },
|
||||
{ url = "https://files.pythonhosted.org/packages/92/a3/7ade0576d17f3cdf5ff44d61390d4b3febb8a9fc2b480c75c47ea048c646/tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8", size = 238273 },
|
||||
{ url = "https://files.pythonhosted.org/packages/72/6f/fa64ef058ac1446a1e51110c375339b3ec6be245af9d14c87c4a6412dd32/tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff", size = 98310 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6a/1c/4a2dcde4a51b81be3530565e92eda625d94dafb46dbeb15069df4caffc34/tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b", size = 108309 },
|
||||
{ url = "https://files.pythonhosted.org/packages/52/e1/f8af4c2fcde17500422858155aeb0d7e93477a0d59a98e56cbfe75070fd0/tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea", size = 132762 },
|
||||
{ url = "https://files.pythonhosted.org/packages/03/b8/152c68bb84fc00396b83e7bbddd5ec0bd3dd409db4195e2a9b3e398ad2e3/tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8", size = 123453 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/d6/fc9267af9166f79ac528ff7e8c55c8181ded34eb4b0e93daa767b8841573/tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192", size = 233486 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5c/51/51c3f2884d7bab89af25f678447ea7d297b53b5a3b5730a7cb2ef6069f07/tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222", size = 242349 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ab/df/bfa89627d13a5cc22402e441e8a931ef2108403db390ff3345c05253935e/tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77", size = 252159 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/6e/fa2b916dced65763a5168c6ccb91066f7639bdc88b48adda990db10c8c0b/tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6", size = 237243 },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/04/885d3b1f650e1153cbb93a6a9782c58a972b94ea4483ae4ac5cedd5e4a09/tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd", size = 259645 },
|
||||
{ url = "https://files.pythonhosted.org/packages/9c/de/6b432d66e986e501586da298e28ebeefd3edc2c780f3ad73d22566034239/tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e", size = 244584 },
|
||||
{ url = "https://files.pythonhosted.org/packages/1c/9a/47c0449b98e6e7d1be6cbac02f93dd79003234ddc4aaab6ba07a9a7482e2/tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98", size = 98875 },
|
||||
{ url = "https://files.pythonhosted.org/packages/ef/60/9b9638f081c6f1261e2688bd487625cd1e660d0a85bd469e91d8db969734/tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4", size = 109418 },
|
||||
{ url = "https://files.pythonhosted.org/packages/04/90/2ee5f2e0362cb8a0b6499dc44f4d7d48f8fff06d28ba46e6f1eaa61a1388/tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7", size = 132708 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c0/ec/46b4108816de6b385141f082ba99e315501ccd0a2ea23db4a100dd3990ea/tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c", size = 123582 },
|
||||
{ url = "https://files.pythonhosted.org/packages/a0/bd/b470466d0137b37b68d24556c38a0cc819e8febe392d5b199dcd7f578365/tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13", size = 232543 },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/e5/82e80ff3b751373f7cead2815bcbe2d51c895b3c990686741a8e56ec42ab/tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281", size = 241691 },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/7e/2a110bc2713557d6a1bfb06af23dd01e7dde52b6ee7dadc589868f9abfac/tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272", size = 251170 },
|
||||
{ url = "https://files.pythonhosted.org/packages/64/7b/22d713946efe00e0adbcdfd6d1aa119ae03fd0b60ebed51ebb3fa9f5a2e5/tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140", size = 236530 },
|
||||
{ url = "https://files.pythonhosted.org/packages/38/31/3a76f67da4b0cf37b742ca76beaf819dca0ebef26d78fc794a576e08accf/tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2", size = 258666 },
|
||||
{ url = "https://files.pythonhosted.org/packages/07/10/5af1293da642aded87e8a988753945d0cf7e00a9452d3911dd3bb354c9e2/tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744", size = 243954 },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/b9/1ed31d167be802da0fc95020d04cd27b7d7065cc6fbefdd2f9186f60d7bd/tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec", size = 98724 },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/32/b0963458706accd9afcfeb867c0f9175a741bf7b19cd424230714d722198/tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69", size = 109383 },
|
||||
{ url = "https://files.pythonhosted.org/packages/6e/c2/61d3e0f47e2b74ef40a68b9e6ad5984f6241a942f7cd3bbfbdbd03861ea9/tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc", size = 14257 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.13.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/37/23083fcd6e35492953e8d2aaaa68b860eb422b34627b13f2ce3eb6106061/typing_extensions-4.13.2.tar.gz", hash = "sha256:e6c81219bd689f51865d9e372991c540bda33a0379d5573cddb9a3a23f7caaef", size = 106967 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8b/54/b1ae86c0973cc6f0210b53d508ca3641fb6d0c56823f288d108bc7ab3cc8/typing_extensions-4.13.2-py3-none-any.whl", hash = "sha256:a439e7c04b49fec3e5d3e2beaa21755cadbbdc391694e28ccdd36ca4a1408f8c", size = 45806 },
|
||||
]
|
1
librarian/plugins/librarian-scraper/README.md
Normal file
@ -0,0 +1 @@
|
||||
# Librarian Scraper
|
41
librarian/plugins/librarian-scraper/pyproject.toml
Normal file
@ -0,0 +1,41 @@
|
||||
[project]
|
||||
name = "librarian-scraper"
|
||||
version = "0.2.1"
|
||||
description = "FastAPI gateway and runtime pipeline for Librarian"
|
||||
readme = "README.md"
|
||||
authors = [{ name = "DotNaos", email = "schuetzoliver00@gmail.com" }]
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"importlib_metadata; python_version<'3.10'",
|
||||
"playwright>=1.51.0",
|
||||
"dotenv>=0.9.9",
|
||||
"parsel>=1.10.0",
|
||||
"librarian-core",
|
||||
"httpx>=0.28.1",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling>=1.21"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/librarian_scraper"]
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.uv.sources]
|
||||
#librarian-core = { git = "https://github.com/DotNaos/librarian-core", rev = "dev" }
|
||||
|
||||
[project.entry-points."librarian.workers"]
|
||||
crawler = "librarian_scraper.crawler:Crawler"
|
||||
downloader = "librarian_scraper.downloader:Downloader"
|
||||
|
||||
|
||||
# ───────── optional: dev / test extras ─────────
|
||||
[project.optional-dependencies]
|
||||
dev = ["ruff", "pytest", "mypy"]
|
||||
|
||||
[project.scripts]
|
||||
example = "examples.app:app"
|
@ -0,0 +1,12 @@
|
||||
from .crawler import (
|
||||
Crawler,
|
||||
)
|
||||
from .downloader import (
|
||||
Downloader,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Crawler",
|
||||
"Downloader",
|
||||
"Extractor",
|
||||
]
|
@ -0,0 +1,29 @@
|
||||
"""
|
||||
URLs used by the scraper.
|
||||
URLs under PUBLIC_URLS can be accessed without authentication.
URLs under PRIVATE_URLS require an authenticated session.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://moodle.fhgr.ch"
|
||||
|
||||
CRAWLER = {
|
||||
"DELAY_SLOW": 2.0,
|
||||
"DELAY_FAST": 0.5,
|
||||
"BATCH_SLOW": 2,
|
||||
"BATCH_FAST": 8,
|
||||
}
|
||||
|
||||
class PUBLIC_URLS:
|
||||
base_url = BASE_URL
|
||||
login = f"{BASE_URL}/login/index.php"
|
||||
index = f"{BASE_URL}/course/index.php"
|
||||
degree_program = lambda degree_program_id: f"{BASE_URL}/course/index.php?categoryid={degree_program_id}"
|
||||
category = lambda category_id: f"{BASE_URL}/course/index.php?categoryid={category_id}"
|
||||
term = lambda term_id: f"{BASE_URL}/course/index.php?categoryid={term_id}"
|
||||
|
||||
class PRIVATE_URLS:
|
||||
user_courses = f"{BASE_URL}/my/courses.php"
|
||||
dashboard = f"{BASE_URL}/my/"
|
||||
course = lambda course_id: f"{BASE_URL}/course/view.php?id={course_id}"
|
||||
files = lambda context_id: f"{BASE_URL}/course/downloadcontent.php?contextid={context_id}"
|
||||
file = lambda file_id: f"{BASE_URL}/mod/resource/view.php?id={file_id}"
|
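A quick sketch of how these constants are consumed elsewhere in the scraper; the ids are made up.

from librarian_scraper.constants import CRAWLER, PRIVATE_URLS, PUBLIC_URLS

# Category and term listings are public pages.
print(PUBLIC_URLS.term(42))       # .../course/index.php?categoryid=42

# Course pages and content downloads need an authenticated session.
print(PRIVATE_URLS.course(1234))  # .../course/view.php?id=1234

# Throttling presets the Crawler picks depending on SCRAPER_RELAXED.
print(CRAWLER["DELAY_SLOW"], CRAWLER["BATCH_SLOW"])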
@ -0,0 +1,7 @@
|
||||
from librarian_scraper.crawler.cookie_crawler import CookieCrawler
|
||||
from librarian_scraper.crawler.crawler import Crawler
|
||||
|
||||
__all__ = [
|
||||
"CookieCrawler",
|
||||
"Crawler",
|
||||
]
|
@ -0,0 +1,138 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from httpx import Cookies
|
||||
from playwright.async_api import Browser, Cookie, Page, async_playwright
|
||||
|
||||
from librarian_scraper.constants import PRIVATE_URLS, PUBLIC_URLS
|
||||
|
||||
|
||||
class CookieCrawler:
|
||||
"""
|
||||
Retrieve Moodle session cookies + sesskey via Playwright.
|
||||
|
||||
Usage
|
||||
-----
|
||||
>>> crawler = CookieCrawler()
|
||||
>>> cookies, sesskey = await crawler.crawl() # inside async code
|
||||
# or
|
||||
>>> cookies, sesskey = CookieCrawler.crawl_sync() # plain scripts
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# construction #
|
||||
# ------------------------------------------------------------------ #
|
||||
def __init__(self, *, headless: bool = True) -> None:
|
||||
self.headless = headless
|
||||
self.cookies: Optional[List[Cookie]] = None
|
||||
self.sesskey: str = ""
|
||||
|
||||
self.username: str = os.getenv("MOODLE_USERNAME", "")
|
||||
self.password: str = os.getenv("MOODLE_PASSWORD", "")
|
||||
if not self.username or not self.password:
|
||||
raise ValueError(
|
||||
"Set MOODLE_USERNAME and MOODLE_PASSWORD as environment variables."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# public API #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def crawl(self) -> tuple[Cookies, str]:
|
||||
"""
|
||||
Async entry-point – await this inside FastAPI / Prefect etc.
|
||||
"""
|
||||
async with async_playwright() as p:
|
||||
browser: Browser = await p.chromium.launch(headless=self.headless)
|
||||
page = await browser.new_page()
|
||||
await page.goto(PUBLIC_URLS.login)
|
||||
logging.info("Login page loaded: %s", page.url)
|
||||
|
||||
await self._login(page)
|
||||
await browser.close()
|
||||
|
||||
if not self.cookies:
|
||||
raise RuntimeError("Login failed – no cookies retrieved.")
|
||||
|
||||
return self._to_cookiejar(self.cookies), self.sesskey
|
||||
|
||||
@classmethod
|
||||
def crawl_sync(cls, **kwargs) -> tuple[Cookies, str]:
|
||||
"""
|
||||
Synchronous helper for CLI / notebooks.
|
||||
|
||||
Detects whether an event loop is already running. If so, it
|
||||
schedules the coroutine and waits; otherwise it starts a fresh loop.
|
||||
"""
|
||||
self = cls(**kwargs)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError: # no loop running → safe to create one
|
||||
return asyncio.run(self.crawl())
|
||||
|
||||
# An event loop is already running – run_until_complete() would raise
# here, so run the coroutine on a fresh loop in a worker thread instead.
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=1) as pool:
    return pool.submit(asyncio.run, self.crawl()).result()
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# internal helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def _login(self, page: Page) -> None:
|
||||
"""Fill the SSO form and extract cookies + sesskey."""
|
||||
|
||||
# Select organisation / IdP
|
||||
await page.click("#wayf_submit_button")
|
||||
|
||||
# Wait for the credential form
|
||||
await page.wait_for_selector("form[method='post']", state="visible")
|
||||
|
||||
# Credentials
|
||||
await page.fill("input[id='username']", self.username)
|
||||
await page.fill("input[id='password']", self.password)
|
||||
await page.click("button[class='aai_login_button']")
|
||||
|
||||
# Wait for the redirect to the /my/ (dashboard) page, which indicates the login is complete
|
||||
await page.wait_for_url(PRIVATE_URLS.dashboard)
|
||||
await page.wait_for_selector("body", state="attached")
|
||||
|
||||
# Navigate to personal course overview
|
||||
await page.goto(PRIVATE_URLS.user_courses)
|
||||
await page.wait_for_selector("body", state="attached")
|
||||
|
||||
# Collect session cookies
|
||||
self.cookies = await page.context.cookies()
|
||||
|
||||
# Extract sesskey from injected Moodle config
|
||||
try:
|
||||
self.sesskey = await page.evaluate(
|
||||
"() => window.M && M.cfg && M.cfg.sesskey"
|
||||
)
|
||||
except Exception as exc:
|
||||
raise RuntimeError("sesskey not found via JS evaluation") from exc
|
||||
|
||||
if not self.sesskey:
|
||||
raise RuntimeError("sesskey is empty after evaluation.")
|
||||
|
||||
logging.debug("sesskey: %s", self.sesskey)
|
||||
logging.debug("cookies: %s", self.cookies)
|
||||
|
||||
# Dev convenience
|
||||
if not self.headless:
|
||||
await page.wait_for_timeout(5000)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# cookie conversion #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _to_cookiejar(self, raw: List[Cookie]) -> Cookies:
|
||||
jar = Cookies()
|
||||
for c in raw:
|
||||
jar.set(
|
||||
name=c.get("name", ""),
|
||||
value=c.get("value", ""),
|
||||
domain=c.get("domain", "").lstrip("."),
|
||||
path=c.get("path", "/"),
|
||||
)
|
||||
return jar
|
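# Usage sketch (illustrative, not part of the class above): fetch an
# authenticated cookie jar plus sesskey once, then reuse them with httpx.
# Assumes MOODLE_USERNAME / MOODLE_PASSWORD are exported and that the
# constructor accepts a `headless` flag (as suggested by `self.headless`).
#
#     cookies, sesskey = CookieCrawler.crawl_sync(headless=True)
#     client = httpx.Client(cookies=cookies, follow_redirects=True)
#     html = client.get(PRIVATE_URLS.user_courses).text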
@ -0,0 +1,264 @@
|
||||
"""
|
||||
librarian_scraper.crawler.crawler
|
||||
---------------------------------
|
||||
Scrapes Moodle degree programmes into CrawlData.
|
||||
• Hero images
|
||||
• Polite throttling / batching
|
||||
• Term-filter: only the latest two terms (dev)
|
||||
• USER_SPECIFIC flag to keep / drop inaccessible courses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import List, Tuple
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
if sys.platform == "win32":
|
||||
# Switch from Selector to Proactor so asyncio.subprocess works
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
|
||||
|
||||
import httpx
|
||||
import parsel
|
||||
from librarian_core.utils.path_utils import get_cache_root
|
||||
from librarian_core.workers.base import Worker
|
||||
from prefect import get_run_logger, task
|
||||
from prefect.futures import wait
|
||||
|
||||
from librarian_scraper.constants import CRAWLER, PRIVATE_URLS, PUBLIC_URLS
|
||||
from librarian_scraper.crawler.cookie_crawler import CookieCrawler
|
||||
from librarian_scraper.models.crawl_data import (
|
||||
CrawlCourse,
|
||||
CrawlData,
|
||||
CrawlProgram,
|
||||
CrawlTerm,
|
||||
)
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# module-level shared items for static task #
|
||||
# --------------------------------------------------------------------------- #
|
||||
_COOKIE_JAR: httpx.Cookies | None = None
|
||||
_DELAY: float = 0.0
|
||||
|
||||
CACHE_FILE = get_cache_root() / "librarian_no_access_cache.json"
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# utility #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def looks_like_enrol(resp: httpx.Response) -> bool:
|
||||
txt = resp.text.lower()
|
||||
return (
|
||||
"login" in str(resp.url).lower()
|
||||
or "#page-enrol" in txt
|
||||
or "you need to enrol" in txt
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# main worker #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class Crawler(Worker[CrawlProgram, CrawlData]):
|
||||
input_model = CrawlProgram
|
||||
output_model = CrawlData
|
||||
|
||||
# toggles (env overrides)
|
||||
RELAXED: bool
|
||||
USER_SPECIFIC: bool
|
||||
CLEAR_CACHE: bool
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# flow entry-point #
|
||||
# ------------------------------------------------------------------ #
|
||||
async def __run__(self, program: CrawlProgram) -> CrawlData:
|
||||
global _COOKIE_JAR, _DELAY
|
||||
lg = get_run_logger()
|
||||
|
||||
self.RELAXED = os.getenv("SCRAPER_RELAXED", "true").lower() == "true"
|
||||
self.USER_SPECIFIC = os.getenv("SCRAPER_USER_SPECIFIC", "true").lower() == "true"
|
||||
self.CLEAR_CACHE = os.getenv("SCRAPER_CLEAR_CACHE", "false").lower() == "true"
|
||||
|
||||
_DELAY = CRAWLER["DELAY_SLOW"] if self.RELAXED else CRAWLER["DELAY_FAST"]
|
||||
batch = CRAWLER["BATCH_SLOW"] if self.RELAXED else CRAWLER["BATCH_FAST"]
|
||||
lg.info(
|
||||
"Mode=%s user_specific=%s delay=%.1fs batch=%s",
|
||||
"RELAXED" if self.RELAXED else "FAST",
|
||||
self.USER_SPECIFIC,
|
||||
_DELAY,
|
||||
batch,
|
||||
)
|
||||
|
||||
# --------------------------- login
|
||||
cookies, _ = await CookieCrawler().crawl()
|
||||
_COOKIE_JAR = cookies
|
||||
self._client = httpx.Client(cookies=cookies, follow_redirects=True)
|
||||
|
||||
if not self._logged_in():
|
||||
lg.error("Guest session detected – aborting crawl.")
|
||||
raise RuntimeError("Login failed")
|
||||
|
||||
# --------------------------- cache
|
||||
no_access: set[str] = set() if self.CLEAR_CACHE else self._load_cache()
|
||||
|
||||
# --------------------------- scrape terms (first two for dev)
|
||||
terms = self._crawl_terms(program.id)[:2]
|
||||
lg.info("Terms discovered: %d", len(terms))
|
||||
|
||||
# --------------------------- scrape courses
|
||||
for term in terms:
|
||||
courses = self._crawl_courses(term.id)
|
||||
lg.info("[%s] raw courses: %d", term.name, len(courses))
|
||||
|
||||
for i in range(0, len(courses), batch):
|
||||
futs = [
|
||||
self._crawl_course_task.submit(course.id)
|
||||
for course in courses[i : i + batch]
|
||||
]
|
||||
done, _ = wait(futs)
|
||||
|
||||
for fut in done:
|
||||
cid, res_id = fut.result()
|
||||
if res_id:
|
||||
next(
|
||||
c for c in courses if c.id == cid
|
||||
).content_ressource_id = res_id
|
||||
else:
|
||||
no_access.add(cid)
|
||||
|
||||
term.courses = (
|
||||
[c for c in courses if c.content_ressource_id]
|
||||
if self.USER_SPECIFIC
|
||||
else courses
|
||||
)
|
||||
lg.info("[%s] kept: %d", term.name, len(term.courses))
|
||||
|
||||
# --------------------------- persist cache
|
||||
self._save_cache(no_access)
|
||||
|
||||
return CrawlData(
|
||||
degree_program=CrawlProgram(
|
||||
id=program.id,
|
||||
name=program.name,
|
||||
terms=[t for t in terms if t.courses],
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# static task inside class #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
@task(
|
||||
name="crawl_course",
|
||||
retries=2,
|
||||
retry_delay_seconds=5,
|
||||
log_prints=True,
|
||||
cache_expiration=timedelta(days=1),
|
||||
)
|
||||
def _crawl_course_task(course_id: str) -> Tuple[str, str]:
|
||||
"""
|
||||
Returns (course_id, content_resource_id or "").
|
||||
Never raises; logs reasons instead.
|
||||
"""
|
||||
lg = get_run_logger()
|
||||
assert _COOKIE_JAR is not None
|
||||
|
||||
url = PRIVATE_URLS.course(course_id)
|
||||
for attempt in (1, 2):
|
||||
try:
|
||||
r = httpx.get(
|
||||
url, cookies=_COOKIE_JAR, follow_redirects=True, timeout=30
|
||||
)
|
||||
r.raise_for_status()
|
||||
time.sleep(_DELAY)
|
||||
break
|
||||
except Exception as exc:
|
||||
lg.warning("GET %s failed (%s) attempt %d/2", url, exc, attempt)
|
||||
time.sleep(_DELAY)
|
||||
else:
|
||||
lg.warning("Course %s unreachable.", course_id)
|
||||
return course_id, ""
|
||||
|
||||
if looks_like_enrol(r):
|
||||
lg.info("No access to course %s (enrol / login page).", course_id)
|
||||
return course_id, ""
|
||||
|
||||
href = (
|
||||
parsel.Selector(r.text)
|
||||
.css('a[data-downloadcourse="1"]::attr(href)')
|
||||
.get("")
|
||||
)
|
||||
if not href:
|
||||
lg.info("Course %s has no downloadable content.", course_id)
|
||||
return course_id, ""
|
||||
|
||||
return course_id, href.split("=")[-1]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
def _logged_in(self) -> bool:
|
||||
html = self._get_html(PUBLIC_URLS.index)
|
||||
return not parsel.Selector(text=html).css("div.usermenu span.login a")
|
||||
|
||||
def _crawl_terms(self, dp_id: str) -> List[CrawlTerm]:
|
||||
html = self._get_html(PUBLIC_URLS.degree_program(dp_id))
|
||||
sel = parsel.Selector(text=html)
|
||||
out = []
|
||||
for a in sel.css("div.category h3.categoryname a"):
|
||||
name = a.xpath("text()").get("").strip()
|
||||
if re.match(r"^(FS|HS)\d{2}$", name):
|
||||
out.append(
|
||||
CrawlTerm(name=name, id=a.xpath("@href").get("").split("=")[-1])
|
||||
)
|
||||
order = {"FS": 0, "HS": 1}
|
||||
return sorted(
|
||||
out, key=lambda t: (2000 + int(t.name[2:]), order[t.name[:2]]), reverse=True
|
||||
)
|
||||
|
||||
def _crawl_courses(self, term_id: str) -> List[CrawlCourse]:
|
||||
html = self._get_html(PUBLIC_URLS.term(term_id))
|
||||
sel = parsel.Selector(text=html)
|
||||
courses = []
|
||||
for box in sel.css("div.coursebox"):
|
||||
anchor = box.css("h3.coursename a")
|
||||
if not anchor:
|
||||
continue
|
||||
cid = anchor.attrib.get("href", "").split("=")[-1]
|
||||
raw = anchor.xpath("text()").get("").strip()
|
||||
name = re.sub(r"\s*(FS|HS)\d{2}\s*", "", raw)
|
||||
name = re.sub(r"\s*\(.*?\)\s*", "", name).strip()
|
||||
hero = box.css("div.courseimage img::attr(src)").get("") or ""
|
||||
courses.append(CrawlCourse(id=cid, name=name, hero_image=hero))
|
||||
return courses
|
||||
|
||||
def _get_html(self, url: str) -> str:
|
||||
try:
|
||||
r = self._client.get(url, timeout=30)
|
||||
r.raise_for_status()
|
||||
time.sleep(_DELAY)
|
||||
return r.text
|
||||
except Exception as exc:
|
||||
get_run_logger().warning("GET %s failed (%s)", url, exc)
|
||||
return ""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# cache helpers #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
def _load_cache() -> set[str]:
|
||||
try:
|
||||
return set(json.loads(CACHE_FILE.read_text()))
|
||||
except Exception:
|
||||
return set()
|
||||
|
||||
@staticmethod
|
||||
def _save_cache(cache: set[str]) -> None:
|
||||
try:
|
||||
CACHE_FILE.write_text(json.dumps(sorted(cache), indent=2))
|
||||
except Exception as exc:
|
||||
get_run_logger().warning("Could not save cache: %s", exc)
|
@ -0,0 +1,357 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
import parsel
|
||||
from librarian_core.model import Course, DegreeProgram, FileEntry, MoodleIndex, Semester
|
||||
|
||||
from . import URLs
|
||||
|
||||
CACHE_FILENAME = "librarian_no_access_cache.json"
|
||||
NO_ACCESS_CACHE_FILE = Path(tempfile.gettempdir()) / CACHE_FILENAME
|
||||
|
||||
|
||||
class IndexCrawler:
|
||||
def __init__(self, degree_program: DegreeProgram, cookies: httpx.Cookies, debug: bool = False, *, max_workers: int = 8) -> None:
|
||||
self.degree_program = degree_program
|
||||
self.debug = debug
|
||||
self.client = httpx.Client(cookies=cookies, follow_redirects=True)
|
||||
self.max_workers = max_workers
|
||||
|
||||
# When True the cached “no-access” set is ignored for this run
|
||||
self._ignore_cache: bool = False
|
||||
|
||||
# Load persisted cache of course-IDs the user cannot access
|
||||
if NO_ACCESS_CACHE_FILE.exists():
|
||||
try:
|
||||
self._no_access_cache: set[str] = set(json.loads(NO_ACCESS_CACHE_FILE.read_text()))
|
||||
except Exception:
|
||||
logging.warning("Failed to read no-access cache, starting fresh.")
|
||||
self._no_access_cache = set()
|
||||
else:
|
||||
self._no_access_cache = set()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.client.close()
|
||||
|
||||
def __del__(self):
|
||||
# Fallback in case the context manager isn’t used
|
||||
if not self.client.is_closed:
|
||||
self.client.close()
|
||||
|
||||
"""
|
||||
Crawl a single instance of MoodleIndex.
|
||||
This returns a MoodleIndex object populated with data.
|
||||
"""
|
||||
|
||||
def crawl_index(self, userSpecific: bool = True, *, use_cache: bool = True) -> MoodleIndex:
|
||||
"""
|
||||
Build and return a `MoodleIndex`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
userSpecific : bool
|
||||
When True, include only courses that expose a downloadable content resource.
|
||||
use_cache : bool, default True
|
||||
If False, bypass the persisted “no-access” cache so every course is probed
|
||||
afresh. Newly discovered “no-access” courses are still written back to the
|
||||
cache at the end of the crawl.
|
||||
"""
|
||||
# Set runtime flag for has_user_access()
|
||||
self._ignore_cache = not use_cache
|
||||
|
||||
        # Get all semesters, then the course id and name for each course.
        semesters = self.crawl_semesters()
|
||||
# Crawl only the latest two semesters to reduce load (remove once caching is implemented)
|
||||
for semester in semesters[:2]:
|
||||
courses = self.crawl_courses(semester)
|
||||
|
||||
# Crawl courses in parallel to speed things up
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as pool:
|
||||
list(pool.map(self.crawl_course, courses))
|
||||
|
||||
# Filter courses once all have been processed
|
||||
for course in courses:
|
||||
if userSpecific:
|
||||
if course.content_ressource_id:
|
||||
semester.courses.append(course)
|
||||
else:
|
||||
semester.courses.append(course)
|
||||
|
||||
# Only add semesters that have at least one course
|
||||
# Filter out semesters that ended up with no courses after crawling
|
||||
semesters: list[Semester] = [
|
||||
semester for semester in semesters if semester.courses
|
||||
]
|
||||
|
||||
created_index = MoodleIndex(
|
||||
degree_program=DegreeProgram(
|
||||
name=self.degree_program.name,
|
||||
id=self.degree_program.id,
|
||||
semesters=semesters,
|
||||
),
|
||||
)
|
||||
# Persist any newly discovered no-access courses
|
||||
self._save_no_access_cache()
|
||||
|
||||
# Restore default behaviour for subsequent calls
|
||||
self._ignore_cache = False
|
||||
|
||||
return created_index
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# High-level crawling helpers
|
||||
# --------------------------------------------------------------------- #
|
||||
def crawl_semesters(self) -> list[Semester]:
|
||||
"""
|
||||
Crawl the semesters from the Moodle index page.
|
||||
"""
|
||||
url = URLs.get_degree_program_url(self.degree_program.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
if res.status_code == 200:
|
||||
semesters = self.extract_semesters(res.text)
|
||||
logging.debug(f"Found semesters: {semesters}")
|
||||
return semesters
|
||||
|
||||
return []
|
||||
|
||||
def crawl_courses(self, semester: Semester) -> list[Course]:
|
||||
"""
|
||||
Crawl the courses from the Moodle index page.
|
||||
"""
|
||||
url = URLs.get_semester_url(semester_id=semester.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
if res.status_code == 200:
|
||||
courses = self.extract_courses(res.text)
|
||||
logging.debug(f"Found courses: {courses}")
|
||||
return courses
|
||||
|
||||
return []
|
||||
|
||||
def crawl_course(self, course: Course) -> None:
|
||||
"""
|
||||
Crawl a single Moodle course page.
|
||||
"""
|
||||
|
||||
hasAccess = self.has_user_access(course)
|
||||
|
||||
if not hasAccess:
|
||||
return
|
||||
|
||||
        # No-access results are cached by has_user_access(), so repeated
        # requests for inaccessible courses are already avoided.
|
||||
|
||||
course.content_ressource_id = self.crawl_content_ressource_id(course)
|
||||
course.files = self.crawl_course_files(course)
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Networking utilities
|
||||
# --------------------------------------------------------------------- #
|
||||
def get_with_retries(self, url: str, retries: int = 3, delay: int = 1) -> httpx.Response:
|
||||
"""
|
||||
Simple GET with retries and exponential back-off.
|
||||
"""
|
||||
for attempt in range(1, retries + 1):
|
||||
try:
|
||||
response = self.client.get(url)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
logging.warning(f"Request to {url} failed ({e}), attempt {attempt}/{retries}")
|
||||
if attempt < retries:
|
||||
time.sleep(delay * (2 ** (attempt - 1)))
|
||||
raise Exception(f"Failed to GET {url} after {retries} attempts")
|
||||
|
||||
def save_html(self, url: str, response: httpx.Response) -> None:
|
||||
"""
|
||||
Persist raw HTML locally for debugging.
|
||||
"""
|
||||
filename = url.split("/")[-1] + ".html"
|
||||
with open(filename, "w", encoding="utf-8") as f:
|
||||
f.write(response.text)
|
||||
logging.info(f"Saved HTML to {filename}")
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Extractors
|
||||
# --------------------------------------------------------------------- #
|
||||
def extract_semesters(self, html: str) -> list[Semester]:
|
||||
selector = parsel.Selector(text=html)
|
||||
|
||||
logging.info("Extracting semesters from the HTML content.")
|
||||
|
||||
semesters: list[Semester] = []
|
||||
|
||||
# Each semester sits in a collapsed container
|
||||
semester_containers = selector.css("div.category.notloaded.with_children.collapsed")
|
||||
|
||||
for container in semester_containers:
|
||||
anchor = container.css("h3.categoryname.aabtn a")
|
||||
if not anchor:
|
||||
continue
|
||||
|
||||
anchor = anchor[0]
|
||||
semester_name = (
|
||||
anchor.xpath("text()").get("").replace("\n", "").replace("\t", "").strip()
|
||||
)
|
||||
semester_id = anchor.attrib.get("href", "").split("=")[-1]
|
||||
|
||||
# Only keep semesters labeled FS or HS
|
||||
if "FS" not in semester_name and "HS" not in semester_name:
|
||||
continue
|
||||
|
||||
semesters.append(Semester(name=semester_name, id=semester_id))
|
||||
|
||||
semester_order = {
|
||||
"FS": 0, # Frühjahrs‐/Spring Semester
|
||||
"HS": 1, # Herbst‐/Fall Semester
|
||||
}
|
||||
# Sort by year and then by FS before HS
|
||||
sorted_semesters = sorted(
|
||||
semesters,
|
||||
key=lambda s: (
|
||||
2000 + int(s.name[2:]), # parse "25" → int 25, add 2000 → 2025
|
||||
semester_order[s.name[:2]] # map "FS" → 0, "HS" → 1
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
return sorted_semesters
|
||||
|
||||
def extract_courses(self, html: str) -> list[Course]:
|
||||
"""
|
||||
Parse courses and capture optional “hero_image” (overview image) if present.
|
||||
"""
|
||||
selector = parsel.Selector(text=html)
|
||||
|
||||
logging.info("Extracting courses from the HTML content.")
|
||||
|
||||
courses: list[Course] = []
|
||||
|
||||
for header in selector.css("h3.coursename"):
|
||||
anchor = header.css("a")
|
||||
if not anchor:
|
||||
logging.warning("No course anchor found in the course header.")
|
||||
continue
|
||||
|
||||
anchor = anchor[0]
|
||||
course_name = (
|
||||
anchor.xpath("text()").get("").replace("\n", "").replace("\t", "").strip()
|
||||
)
|
||||
course_id = anchor.attrib.get("href", "").split("=")[-1]
|
||||
|
||||
# Remove trailing semester tag and code patterns
|
||||
course_name = re.sub(r"\s*(FS|HS)\d{2}\s*", "", course_name)
|
||||
course_name = re.sub(r"\s*\(.*?\)\s*", "", course_name).strip()
|
||||
|
||||
# Try to locate a hero/overview image that belongs to this course box
|
||||
# Traverse up to the containing course box, then look for <div class="courseimage"><img ...>
|
||||
course_container = header.xpath('./ancestor::*[contains(@class,"coursebox")][1]')
|
||||
hero_src = (
|
||||
course_container.css("div.courseimage img::attr(src)").get("")
|
||||
if course_container else ""
|
||||
)
|
||||
|
||||
courses.append(
|
||||
Course(
|
||||
id=course_id,
|
||||
name=course_name,
|
||||
activity_type="", # TODO: Make optional
|
||||
hero_image=hero_src or ""
|
||||
)
|
||||
)
|
||||
|
||||
logging.info(f"{len(courses)} courses extracted.")
|
||||
return courses
|
||||
|
||||
def has_user_access(self, course: Course) -> bool:
|
||||
"""
|
||||
Return True only if the authenticated user can access the course (result cached).
|
||||
(i.e. the response is HTTP 200 **and** is not a redirected login/enrol page).
|
||||
"""
|
||||
if not self._ignore_cache and course.id in self._no_access_cache:
|
||||
return False
|
||||
|
||||
url = URLs.get_course_url(course.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
if res.status_code != 200:
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
# Detect Moodle redirection to a login or enrolment page
|
||||
final_url = str(res.url).lower()
|
||||
if "login" in final_url or "enrol" in final_url:
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
# Some enrolment pages still return 200; look for HTML markers
|
||||
if "#page-enrol" in res.text or "you need to enrol" in res.text.lower():
|
||||
self._no_access_cache.add(course.id)
|
||||
return False
|
||||
|
||||
        # All checks passed – the user has access to this course.
|
||||
return True
|
||||
|
||||
def crawl_content_ressource_id(self, course: Course) -> str:
|
||||
course_id = course.id
|
||||
url = URLs.get_course_url(course_id)
|
||||
res = self.get_with_retries(url)
|
||||
psl = parsel.Selector(res.text)
|
||||
|
||||
try:
|
||||
logging.info("Searching for 'Download course content' link.")
|
||||
# Use parsel CSS selector to find the anchor tag with the specific data attribute
|
||||
download_link_selector = psl.css('a[data-downloadcourse="1"]')
|
||||
if not download_link_selector:
|
||||
raise ValueError("Download link not found.")
|
||||
|
||||
# Extract the href attribute from the first matching element
|
||||
href = download_link_selector[0].attrib.get("href")
|
||||
if not href:
|
||||
raise ValueError("Href attribute not found on the download link.")
|
||||
|
||||
context_id = href.split("=")[-1]
|
||||
course.content_ressource_id = context_id
|
||||
|
||||
return context_id
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error extracting content resource ID for course '{course.name}': {e}",
|
||||
exc_info=False,
|
||||
)
|
||||
logging.debug("Debugging info: Error accessing course content.", exc_info=True)
|
||||
return ''
|
||||
|
||||
def crawl_course_files(self, course: Course) -> list[FileEntry]:
|
||||
"""
|
||||
Crawl the course files from the Moodle course page.
|
||||
"""
|
||||
url = URLs.get_course_url(course.id)
|
||||
res = self.get_with_retries(url)
|
||||
|
||||
if res.status_code == 200:
|
||||
files = [] # TODO: either implement this or remove, because files are extracted from the .zip file
|
||||
logging.debug(f"Found files: {files}")
|
||||
return files
|
||||
|
||||
return []
|
||||
|
||||
# ----------------------------------------------------------------- #
|
||||
# Cache persistence helpers
|
||||
# ----------------------------------------------------------------- #
|
||||
def _save_no_access_cache(self) -> None:
|
||||
try:
|
||||
NO_ACCESS_CACHE_FILE.write_text(json.dumps(sorted(self._no_access_cache)))
|
||||
except Exception as exc:
|
||||
logging.warning(f"Could not persist no-access cache: {exc}")
|
@ -0,0 +1,59 @@
|
||||
# TODO: Move to librarian-core
|
||||
"""
|
||||
All URLs used in the crawler.
|
||||
Functions marked as PUBLIC can be accessed without authentication.
|
||||
Functions marked as PRIVATE require authentication.
|
||||
"""
|
||||
class URLs:
|
||||
base_url = "https://moodle.fhgr.ch"
|
||||
|
||||
@classmethod
|
||||
def get_base_url(cls):
|
||||
"""PUBLIC"""
|
||||
return cls.base_url
|
||||
|
||||
# ------------------------- Moodle URLs -------------------------
|
||||
@classmethod
|
||||
def get_login_url(cls):
|
||||
"""PUBLIC"""
|
||||
return f"{cls.base_url}/login/index.php"
|
||||
|
||||
@classmethod
|
||||
def get_index_url(cls):
|
||||
"""PUBLIC"""
|
||||
return f"{cls.base_url}/course/index.php"
|
||||
|
||||
@classmethod
|
||||
def get_degree_program_url(cls, degree_program_id):
|
||||
"""PUBLIC"""
|
||||
return f"{cls.base_url}/course/index.php?categoryid={degree_program_id}"
|
||||
|
||||
@classmethod
|
||||
def get_category_url(cls, category_id):
|
||||
"""PUBLIC"""
|
||||
return f"{cls.base_url}/course/index.php?categoryid={category_id}"
|
||||
|
||||
@classmethod
|
||||
def get_semester_url(cls, semester_id):
|
||||
"""PUBLIC"""
|
||||
return f"{cls.base_url}/course/index.php?categoryid={semester_id}"
|
||||
|
||||
@classmethod
|
||||
def get_user_courses_url(cls):
|
||||
"""PRIVATE"""
|
||||
return f"{cls.base_url}/my/courses.php"
|
||||
|
||||
@classmethod
|
||||
def get_course_url(cls, course_id):
|
||||
"""PRIVATE"""
|
||||
return f"{cls.base_url}/course/view.php?id={course_id}"
|
||||
|
||||
@classmethod
|
||||
def get_files_url(cls, context_id):
|
||||
"""PRIVATE"""
|
||||
return f"{cls.base_url}/course/downloadcontent.php?contextid={context_id}"
|
||||
|
||||
@classmethod
|
||||
def get_file_url(cls, file_id):
|
||||
"""PRIVATE"""
|
||||
return f"{cls.base_url}/mod/resource/view.php?id={file_id}"
|
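# Usage sketch: these helpers only build URL strings, e.g.
#
#     URLs.get_course_url("18240")
#     # → "https://moodle.fhgr.ch/course/view.php?id=18240"
#
# Endpoints marked PRIVATE still need an authenticated session (cookies)
# when the URL is actually requested.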
@ -0,0 +1,5 @@
|
||||
from .downloader import *
|
||||
|
||||
__all__ = [
|
||||
"Downloader",
|
||||
]
|
@ -0,0 +1,151 @@
|
||||
"""
|
||||
Downloader Worker
|
||||
=================
|
||||
Input : CrawlData (from the crawler)
|
||||
Output : DownloadData (metadata only; files staged)
|
||||
|
||||
Folder tree after run
|
||||
---------------------
|
||||
export_dir/
|
||||
└─ {TERM_NAME}/
|
||||
├─ {course_id}.zip
|
||||
└─ …
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import httpx
|
||||
from librarian_core.utils.path_utils import get_temp_path
|
||||
from librarian_core.workers.base import Worker
|
||||
from prefect import get_run_logger, task
|
||||
from prefect.futures import wait
|
||||
|
||||
from librarian_scraper.constants import CRAWLER
|
||||
from librarian_scraper.crawler.cookie_crawler import CookieCrawler
|
||||
from librarian_scraper.models.crawl_data import CrawlData
|
||||
from librarian_scraper.models.download_data import (
|
||||
DownloadCourse,
|
||||
DownloadData,
|
||||
DownloadTerm,
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# helper decorator #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def task_(**kw):
|
||||
kw.setdefault("log_prints", True)
|
||||
kw.setdefault("retries", 2)
|
||||
kw.setdefault("retry_delay_seconds", 5)
|
||||
return task(**kw)
|
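# The helper above simply pre-fills Prefect task defaults; e.g. `@task_()`
# behaves like `@task(log_prints=True, retries=2, retry_delay_seconds=5)`
# unless the caller overrides any of those keyword arguments.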
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# shared state for static task #
|
||||
# --------------------------------------------------------------------------- #
|
||||
_COOKIE_JAR: httpx.Cookies | None = None
|
||||
_SESSKEY: str = ""
|
||||
_LIMIT: int = 2
|
||||
_DELAY: float = 0.0
|
||||
|
||||
|
||||
class Downloader(Worker[CrawlData, DownloadData]):
|
||||
DOWNLOAD_URL = "https://moodle.fhgr.ch/course/downloadcontent.php"
|
||||
|
||||
# tuning
|
||||
CONCURRENCY = 8
|
||||
RELAXED = True # False → faster
|
||||
|
||||
input_model = CrawlData
|
||||
output_model = DownloadData
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
async def __run__(self, crawl: CrawlData) -> DownloadData:
|
||||
global _COOKIE_JAR, _SESSKEY, _LIMIT, _DELAY
|
||||
lg = get_run_logger()
|
||||
|
||||
# ------------ login
|
||||
cookies, sesskey = await CookieCrawler().crawl()
|
||||
_COOKIE_JAR, _SESSKEY = cookies, sesskey
|
||||
|
||||
# ------------ tuning
|
||||
_LIMIT = 1 if self.RELAXED else max(1, min(self.CONCURRENCY, 8))
|
||||
_DELAY = CRAWLER["DELAY_SLOW"] if self.RELAXED else CRAWLER["DELAY_FAST"]
|
||||
|
||||
# ------------ working dir
|
||||
work_root = Path(get_temp_path()) / f"dl_{int(time.time())}"
|
||||
work_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = DownloadData()
|
||||
futures = []
|
||||
term_dirs: List[Tuple[str, Path]] = []
|
||||
|
||||
# schedule downloads
|
||||
for term in crawl.degree_program.terms:
|
||||
term_dir = work_root / term.name
|
||||
term_dir.mkdir(parents=True, exist_ok=True)
|
||||
term_dirs.append((term.name, term_dir))
|
||||
|
||||
dl_term = DownloadTerm(id=term.id, name=term.name)
|
||||
result.terms.append(dl_term)
|
||||
|
||||
for course in term.courses:
|
||||
dest = term_dir / f"{course.id}.zip"
|
||||
dl_term.courses.append(DownloadCourse(id=course.id, name=course.name))
|
||||
futures.append(
|
||||
self._download_task.submit(course.content_ressource_id, dest)
|
||||
)
|
||||
|
||||
wait(futures) # block for all downloads
|
||||
|
||||
# stage term directories
|
||||
for name, dir_path in term_dirs:
|
||||
self.stage(dir_path, new_name=name, sanitize=False, move=True)
|
||||
|
||||
lg.info("Downloader finished – staged %d term folders", len(term_dirs))
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# static task #
|
||||
# ------------------------------------------------------------------ #
|
||||
@staticmethod
|
||||
@task_()
|
||||
def _download_task(context_id: str, dest: Path) -> None:
|
||||
lg = get_run_logger()
|
||||
if not context_id:
|
||||
lg.info("Skip (no context id) → %s", dest.name)
|
||||
return
|
||||
|
||||
async def fetch() -> bool:
|
||||
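            # Note: this semaphore is created per task invocation, so it only
            # throttles the single request made by this call; concurrency
            # across submitted tasks is governed by Prefect's task runner.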
sem = asyncio.Semaphore(_LIMIT)
|
||||
|
||||
async with sem:
|
||||
data = {"sesskey": _SESSKEY, "download": 1, "contextid": context_id}
|
||||
async with httpx.AsyncClient(cookies=_COOKIE_JAR) as cli:
|
||||
try:
|
||||
async with cli.stream(
|
||||
"POST", Downloader.DOWNLOAD_URL, data=data, timeout=60
|
||||
) as r:
|
||||
r.raise_for_status()
|
||||
with dest.open("wb") as fh:
|
||||
async for chunk in r.aiter_bytes():
|
||||
fh.write(chunk)
|
||||
lg.info("Downloaded %s", dest)
|
||||
return True
|
||||
except httpx.HTTPStatusError as exc:
|
||||
lg.warning(
|
||||
"HTTP %s for %s", exc.response.status_code, dest.name
|
||||
)
|
||||
except Exception as exc:
|
||||
lg.warning("Error downloading %s (%s)", dest.name, exc)
|
||||
return False
|
||||
|
||||
ok = asyncio.run(fetch())
|
||||
if not ok and dest.exists():
|
||||
dest.unlink(missing_ok=True)
|
||||
time.sleep(_DELAY)
|
@ -0,0 +1,24 @@
|
||||
from librarian_scraper.models.crawl_data import (
|
||||
CrawlCourse,
|
||||
CrawlData,
|
||||
CrawlFile,
|
||||
CrawlProgram,
|
||||
CrawlTerm,
|
||||
)
|
||||
|
||||
from librarian_scraper.models.download_data import (
|
||||
DownloadCourse,
|
||||
DownloadData,
|
||||
DownloadTerm,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CrawlData",
|
||||
"CrawlCourse",
|
||||
"CrawlFile",
|
||||
"CrawlProgram",
|
||||
"CrawlTerm",
|
||||
"DownloadData",
|
||||
"DownloadCourse",
|
||||
"DownloadTerm",
|
||||
]
|
@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
"""
|
||||
Example of a MoodleIndex (JSON):
|
||||
MoodleIndex: {
|
||||
degree_program: {
|
||||
id: '1157',
|
||||
name: 'Computational and Data Science',
|
||||
terms: [
|
||||
{
|
||||
id: '1745',
|
||||
name: 'FS25',
|
||||
courses: [
|
||||
{
|
||||
id: '18863',
|
||||
name: 'Programmierung und Prompt Engineering II',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1159522/course/overviewfiles/PythonBooks.PNG',
|
||||
content_ressource_id: '1159522',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '18240',
|
||||
name: 'Effiziente Algorithmen',
|
||||
activity_type: '',
|
||||
hero_image: '',
|
||||
content_ressource_id: '1125554',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '18237',
|
||||
name: 'Mathematik II',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1125458/course/overviewfiles/Integration_Differential_b.png',
|
||||
content_ressource_id: '1125458',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '18236',
|
||||
name: '2025 FS FHGR CDS Numerische Methoden',
|
||||
activity_type: '',
|
||||
hero_image: '',
|
||||
content_ressource_id: '1125426',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '18228',
|
||||
name: 'Datenbanken und Datenverarbeitung',
|
||||
activity_type: '',
|
||||
hero_image: '',
|
||||
content_ressource_id: '1125170',
|
||||
files: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
id: '1746',
|
||||
name: 'HS24',
|
||||
courses: [
|
||||
{
|
||||
id: '18030',
|
||||
name: 'Bootcamp Wissenschaftliches Arbeiten',
|
||||
activity_type: '',
|
||||
hero_image: '',
|
||||
content_ressource_id: '1090544',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17527',
|
||||
name: 'Einführung in Data Science',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1059194/course/overviewfiles/cds1010.jpg',
|
||||
content_ressource_id: '1059194',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17526',
|
||||
name: 'Einführung in Computational Science',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1059162/course/overviewfiles/cds_intro_sim.jpg',
|
||||
content_ressource_id: '1059162',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17525',
|
||||
name: 'Mathematik I',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1059130/course/overviewfiles/AdobeStock_452512134.png',
|
||||
content_ressource_id: '1059130',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17507',
|
||||
name: 'Programmierung und Prompt Engineering',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1058554/course/overviewfiles/10714013_33861.jpg',
|
||||
content_ressource_id: '1058554',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17505',
|
||||
name: 'Algorithmen und Datenstrukturen',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1058490/course/overviewfiles/Bild1.png',
|
||||
content_ressource_id: '1058490',
|
||||
files: [],
|
||||
},
|
||||
{
|
||||
id: '17503',
|
||||
name: 'Computer Science',
|
||||
activity_type: '',
|
||||
hero_image:
|
||||
'https://moodle.fhgr.ch/pluginfile.php/1058426/course/overviewfiles/Titelbild.jpg',
|
||||
content_ressource_id: '1058426',
|
||||
files: [],
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
},
|
||||
timestamp: '2025-04-27T14:20:11.354825+00:00',
|
||||
};
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Base Model
|
||||
# ---------------------------------------------------------------------------
|
||||
class CrawlData(BaseModel):
|
||||
degree_program: CrawlProgram = Field(
|
||||
default_factory=lambda: CrawlProgram(id="", name="")
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Degree Program
|
||||
# ---------------------------------------------------------------------------
|
||||
class CrawlProgram(BaseModel):
|
||||
id: str = Field("1157", description="Unique identifier for the degree program.")
|
||||
name: str = Field("Computational and Data Science", description="Name of the degree program.")
|
||||
terms: list[CrawlTerm] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Term
|
||||
# ---------------------------------------------------------------------------
|
||||
_TERM_RE = re.compile(r"^(HS|FS)\d{2}$") # HS24 / FS25 …
|
||||
|
||||
|
||||
class CrawlTerm(BaseModel):
|
||||
id: str
|
||||
name: str = Field(..., pattern=_TERM_RE.pattern) # e.g. “HS24”
|
||||
courses: list[CrawlCourse] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Course
|
||||
# ---------------------------------------------------------------------------
|
||||
class CrawlCourse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
hero_image: str = ""
|
||||
content_ressource_id: str = ""
|
||||
files: list[CrawlFile] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Files
|
||||
# ---------------------------------------------------------------------------
|
||||
class CrawlFile(BaseModel):
|
||||
id: str
|
||||
res_id: str
|
||||
name: str
|
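# Construction sketch (mirrors the JSON example at the top of this module;
# all ids are illustrative):
#
#     data = CrawlData(
#         degree_program=CrawlProgram(
#             id="1157",
#             name="Computational and Data Science",
#             terms=[
#                 CrawlTerm(
#                     id="1745",
#                     name="FS25",  # must match ^(HS|FS)\d{2}$
#                     courses=[
#                         CrawlCourse(id="18863", name="Programmierung und Prompt Engineering II"),
#                     ],
#                 )
#             ],
#         )
#     )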
@ -0,0 +1,18 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DownloadCourse(BaseModel):
|
||||
id: str
|
||||
name: str # Stores the name of the zip file inside the term directory
|
||||
|
||||
|
||||
class DownloadTerm(BaseModel):
|
||||
id: str
|
||||
name: str # Stores the name of the term directory inside DownloadMeta.dir
|
||||
courses: List[DownloadCourse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DownloadData(BaseModel):
|
||||
terms: List[DownloadTerm] = Field(default_factory=list)
|
1426
librarian/plugins/librarian-scraper/uv.lock
generated
Normal file
File diff suppressed because it is too large
5
librarian/plugins/librarian-vspace/README.md
Normal file
@ -0,0 +1,5 @@
# UV Update

```shell
uv lock --upgrade
uv sync
```
@ -0,0 +1,47 @@
|
||||
import os
|
||||
|
||||
def chunk_file(input_file, output_dir=None, start_num=1, padding=2):
|
||||
"""
|
||||
Split a file into chunks and save each chunk as a separate file.
|
||||
|
||||
Args:
|
||||
input_file (str): Path to the input file
|
||||
output_dir (str, optional): Directory to save chunk files. Defaults to current directory.
|
||||
start_num (int, optional): Starting number for the chunk files. Defaults to 1.
|
||||
padding (int, optional): Number of digits to pad the incremental numbers. Defaults to 2.
|
||||
"""
|
||||
if output_dir and not os.path.exists(output_dir):
|
||||
os.makedirs(output_dir)
|
||||
|
||||
with open(input_file) as f:
|
||||
content = f.read()
|
||||
chunks = content.split("---")
|
||||
|
||||
chunk_count = start_num
|
||||
for chunk in chunks:
|
||||
        chunk = chunk.strip()  # split("---") already removed the delimiters
|
||||
if not chunk: # Skip empty chunks
|
||||
continue
|
||||
|
||||
# Define output path with padded incremental number
|
||||
file_name = f'chunk_{chunk_count:0{padding}d}.md'
|
||||
if output_dir:
|
||||
outfile_path = os.path.join(output_dir, file_name)
|
||||
else:
|
||||
outfile_path = file_name
|
||||
|
||||
with open(outfile_path, 'w') as outfile:
|
||||
outfile.write(chunk)
|
||||
|
||||
chunk_count += 1
|
||||
|
||||
return chunk_count - start_num # Return the number of chunks written
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
#input_file = "/home/gra/PycharmProjects/librarian_vspace/examples/chunks/knowledge_chunks_detailed.md"
|
||||
input_file = "/home/gra/PycharmProjects/librarian_vspace/examples/chunks/knowledge_chunks_1500.md"
|
||||
# You can specify an output directory or omit it to use the current directory
|
||||
output_dir = "/examples/chunks/chunk_md_x"
|
||||
chunk_file(input_file, output_dir)
|
||||
|
@ -0,0 +1,43 @@
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""examples/demo_run_cluster_export.py
|
||||
|
||||
Launch ClusterExportWorker via FlowArtifact wrapper, mirroring the embedder demo.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from librarian_vspace.vquery.cluster_export_worker import ClusterExportWorker, ClusterExportInput
|
||||
from librarian_core.workers.base import FlowArtifact
|
||||
|
||||
COURSE_ID = 15512 # example id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _load_env(path: Path) -> None:
|
||||
if not path.is_file():
|
||||
return
|
||||
for line in path.read_text().splitlines():
|
||||
if line.strip() and not line.startswith("#") and "=" in line:
|
||||
k, v = [p.strip() for p in line.split("=", 1)]
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
async def _main() -> None:
|
||||
payload = ClusterExportInput(course_id=COURSE_ID)
|
||||
|
||||
worker = ClusterExportWorker()
|
||||
art = FlowArtifact.new(run_id="", dir=Path.cwd(), data=payload)
|
||||
result_artifact = await worker.flow()(art) # FlowArtifact
|
||||
|
||||
output = result_artifact.data # ClusterExportOutput
|
||||
logger.info("✅ Worker finished – output directory: %s", output.output_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
APP_DIR = Path(__file__).resolve().parent
|
||||
_load_env(APP_DIR / ".env")
|
||||
asyncio.run(_main())
|
@ -0,0 +1,86 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Dict
|
||||
import json
|
||||
|
||||
from librarian_vspace.vecembed.embedder_worker import EmbedderWorker, EmbedderInput
|
||||
from librarian_core.workers.base import FlowArtifact
|
||||
from librarian_core.temp_payloads.chunk_data import ChunkData
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Configuration
|
||||
# ------------------------------------------------------------------ #
|
||||
# Folder with the small sample dataset (3 × .md files)
|
||||
DEMO_PATH: Path = Path("/home/gra/PycharmProjects/librarian_vspace/examples/chunks/moodle_chunks/51cd7cf6-e782-4f17-af00-30852cdcd5fc/51cd7cf6-e782-4f17-af00-30852cdcd5fc/data/FS25/Effiziente_Algorithmen").expanduser()
|
||||
#DEMO_PATH: Path = Path("/home/gra/PycharmProjects/librarian_vspace/examples/chunks/chunk_md").expanduser()
|
||||
|
||||
# Where to write the concatenated text file
|
||||
# (one level above the dataset folder keeps things tidy)
|
||||
COURSE_ID_POOL = [16301, 16091, 17505, 18239, 17503, 15512]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
INPUT_MODEL=Path("/home/gra/PycharmProjects/librarian_vspace/examples/chunks/moodle_chunks/51cd7cf6-e782-4f17-af00-30852cdcd5fc/51cd7cf6-e782-4f17-af00-30852cdcd5fc/result.json")
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
def _load_env(path: Path) -> None:
|
||||
"""Load KEY=VALUE pairs from a .env file if present."""
|
||||
if not path.is_file():
|
||||
return
|
||||
for line in path.read_text().splitlines():
|
||||
if line.strip() and not line.startswith("#") and "=" in line:
|
||||
k, v = [p.strip() for p in line.split("=", 1)]
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
|
||||
def discover_chunks(root: Path) -> List[Path]:
|
||||
"""Return all markdown files in the dataset folder."""
|
||||
return sorted(root.glob("*.md"))
|
||||
|
||||
|
||||
def build_course(root: Path) -> Dict[str, Any]:
|
||||
"""Minimal dict that satisfies EmbedderWorker's `chunk_course`."""
|
||||
files = [
|
||||
{"file_name": p.name, "file_id": str(random.getrandbits(24))}
|
||||
for p in discover_chunks(root)
|
||||
]
|
||||
if not files:
|
||||
raise FileNotFoundError(f"No .md files found in {root}")
|
||||
return {
|
||||
"path": str(root),
|
||||
"files": files,
|
||||
#"course_id": str(random.choice(COURSE_ID_POOL)),
|
||||
"course_id": "18240"
|
||||
}
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
async def _main() -> None:
|
||||
course = build_course(DEMO_PATH)
|
||||
concat_path = DEMO_PATH
|
||||
|
||||
with open(INPUT_MODEL, 'r') as file:
|
||||
json_data = json.load(file)
|
||||
|
||||
#payload = EmbedderInput(chunk_course=course, concat_path=concat_path)
|
||||
    # json.load() above already produced a dict, so validate the dict directly
    # (model_validate_json expects a raw JSON string, not a parsed object).
    payload = ChunkData.model_validate(json_data)
|
||||
worker = EmbedderWorker()
|
||||
logger.info("🔨 Launching EmbedderWorker …")
|
||||
art = FlowArtifact.new(run_id="", dir=concat_path, data=payload)
|
||||
result = await worker.flow()(art) # type: ignore[arg-type]
|
||||
|
||||
logger.info("✅ Worker finished: %s", result)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
if __name__ == "__main__":
|
||||
APP_DIR = Path(__file__).resolve().parent
|
||||
_load_env(APP_DIR / ".env")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
asyncio.run(_main())
|
@ -0,0 +1,66 @@
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""examples/demo_run_query.py
|
||||
|
||||
Runs QueryWorker via FlowArtifact wrapper (mirrors cluster export demo).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from librarian_vspace.vquery.query_worker import QueryWorker, QueryInput
|
||||
from librarian_vspace.models.query_model import VectorSearchRequest
|
||||
from librarian_core.workers.base import FlowArtifact
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Config
|
||||
# ------------------------------------------------------------------ #
|
||||
SEARCH_STRING = "integration"
|
||||
COURSE_FILTER_GT = 900 # adjust if needed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _load_env(path: Path) -> None:
|
||||
if not path.is_file():
|
||||
return
|
||||
for line in path.read_text().splitlines():
|
||||
if line.strip() and not line.startswith("#") and "=" in line:
|
||||
k, v = [p.strip() for p in line.split("=", 1)]
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
async def _main() -> None:
|
||||
# Vector search request
|
||||
vs_req = VectorSearchRequest(
|
||||
interface_name=os.getenv("EMBED_INTERFACE", "ollama"),
|
||||
model_name=os.getenv("EMBED_MODEL", "snowflake-arctic-embed2"),
|
||||
search_string=SEARCH_STRING,
|
||||
filters={"file_id": ("gt", COURSE_FILTER_GT)},
|
||||
top_k=10,
|
||||
)
|
||||
|
||||
payload = QueryInput(
|
||||
request=vs_req,
|
||||
db_schema=os.getenv("VECTOR_SCHEMA", "librarian"),
|
||||
rpc_function=os.getenv("VECTOR_FUNCTION", "pdf_chunking"),
|
||||
embed_model=os.getenv("EMBED_MODEL", "snowflake-arctic-embed2"),
|
||||
)
|
||||
|
||||
worker = QueryWorker()
|
||||
art = FlowArtifact.new(run_id="", dir=Path.cwd(), data=payload)
|
||||
result_artifact = await worker.flow()(art) # FlowArtifact
|
||||
|
||||
response = result_artifact.data # VectorSearchResponse
|
||||
logger.info("✅ Worker finished – received %s results", response.total)
|
||||
for idx, ck in enumerate(response.results, 1):
|
||||
logger.info("• %s: %s", idx, ck.chunk[:80] + ("…" if len(ck.chunk or '') > 80 else ""))
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
APP_DIR = Path(__file__).resolve().parent
|
||||
_load_env(APP_DIR / ".env")
|
||||
asyncio.run(_main())
|
@ -0,0 +1,43 @@
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""examples/demo_run_tsne_export.py
|
||||
|
||||
Launch TsneExportWorker via FlowArtifact wrapper.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from librarian_vspace.vecview.tsne_export_worker import TsneExportWorker, TsneExportInput
|
||||
from librarian_core.workers.base import FlowArtifact
|
||||
|
||||
COURSE_ID = 18240 # choose a course with embeddings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _load_env(path: Path) -> None:
|
||||
if not path.is_file():
|
||||
return
|
||||
for line in path.read_text().splitlines():
|
||||
if line.strip() and not line.startswith("#") and "=" in line:
|
||||
k, v = [p.strip() for p in line.split("=", 1)]
|
||||
os.environ.setdefault(k, v)
|
||||
|
||||
async def _main() -> None:
|
||||
payload = TsneExportInput(course_id=COURSE_ID)
|
||||
|
||||
worker = TsneExportWorker()
|
||||
art = FlowArtifact.new(run_id="", dir=Path.cwd(), data=payload)
|
||||
result_artifact = await worker.flow()(art) # FlowArtifact
|
||||
|
||||
output = result_artifact.data # TsneExportOutput
|
||||
logger.info("✅ Worker finished – JSON file: %s", output.json_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")
|
||||
APP_DIR = Path(__file__).resolve().parent
|
||||
_load_env(APP_DIR / ".env")
|
||||
asyncio.run(_main())
|
@ -0,0 +1,4 @@
|
||||
from librarian_vspace.vutils.parallelism_advisor import recommended_workers
|
||||
print(recommended_workers(kind="cpu"))
|
||||
print(recommended_workers(kind="io"))
|
||||
print(recommended_workers(kind="gpu"))
|
141
librarian/plugins/librarian-vspace/examples/run_visualizer.py
Normal file
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Loads vector data using vecmap.loader, reduces dimensions via t-SNE,
|
||||
and launches an interactive 3D visualization using vecmap.visualizer (Dash/Plotly).
|
||||
|
||||
Configuration is primarily driven by environment variables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
import pandas as pd
|
||||
|
||||
# Define application directory relative to this script file
|
||||
APP_DIR = pathlib.Path(__file__).resolve().parent
|
||||
# Define the source directory containing vecmap, vutils, etc.
|
||||
SRC_DIR = APP_DIR.parent / "src"
|
||||
# Define path to .env file relative to APP_DIR
|
||||
DOTENV_PATH = APP_DIR / ".env"
|
||||
|
||||
# --- Explicitly Manage sys.path ---
|
||||
app_dir_str = str(APP_DIR)
|
||||
src_dir_str = str(SRC_DIR)
|
||||
if app_dir_str in sys.path:
    try:
        sys.path.remove(app_dir_str)
    except ValueError:
        pass
if src_dir_str not in sys.path:
    sys.path.insert(0, src_dir_str)
elif sys.path[0] != src_dir_str:
    try:
        sys.path.remove(src_dir_str)
    except ValueError:
        pass
    sys.path.insert(0, src_dir_str)
|
||||
print(f"[DEBUG] sys.path start: {sys.path[:3]}")
|
||||
|
||||
# --- .env Loader ---
|
||||
def _load_env_file(path: pathlib.Path) -> None:
|
||||
print(f"Attempting to load .env file from: {path}")
|
||||
    if not path.is_file():
        print(f".env file not found at {path}, skipping.")
        return
    loaded, skipped = 0, 0
    try:
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, val = line.split("=", 1)
                key, val = key.strip(), val.strip()
                if key not in os.environ:
                    os.environ[key] = val
                    loaded += 1
                else:
                    skipped += 1
        print(f"Loaded {loaded} new vars, skipped {skipped} existing vars from .env")
    except Exception as e:
        print(f"Error reading .env file at {path}: {e}")
|
||||
_load_env_file(DOTENV_PATH)
|
||||
|
||||
# --- Logging Setup ---
|
||||
log_level_str = os.getenv("VECMAP_DEBUG", "false").lower()
|
||||
log_level = logging.DEBUG if log_level_str in ("true", "1") else logging.INFO
|
||||
logging.basicConfig(level=log_level, format='[%(asctime)s] [%(levelname)s] [%(name)s:%(lineno)d] %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
|
||||
if log_level > logging.DEBUG:
|
||||
    for logger_name in ["urllib3", "httpx", "supabase"]:
        logging.getLogger(logger_name).setLevel(logging.WARNING)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- Imports ---
|
||||
try:
|
||||
from librarian_vspace.vecmap.loader import VectorLoader, VectorLoaderError
|
||||
from librarian_vspace.vecmap.visualizer import VectorVisualizer # Removed DEFAULT_N_CLUSTERS import
|
||||
import librarian_vspace.vutils
|
||||
import librarian_vspace.vecembed
|
||||
logger.debug("Successfully imported components.")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import necessary modules: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
# --- Main Logic ---
|
||||
def main() -> None:
|
||||
logger.info("--- Starting VecMap Visualizer ---")
|
||||
|
||||
# --- Configuration ---
|
||||
db_schema = os.getenv("VECTOR_SCHEMA", "librarian")
|
||||
db_function = os.getenv("VECTOR_FUNCTION", "pdf_chunking")
|
||||
model_name = os.getenv("EMBED_MODEL", "snowflake-arctic-embed2")
|
||||
interface_name = os.getenv("EMBED_INTERFACE", "ollama")
|
||||
embedding_column = os.getenv("EMBEDDING_COLUMN", "embedding")
|
||||
    try:
        limit_str = os.getenv("VECMAP_LIMIT")
        data_limit = int(limit_str) if limit_str else None
    except ValueError:
        logger.warning("Invalid VECMAP_LIMIT. Ignoring.")
        data_limit = None
    try:
        tsne_perplexity = float(os.getenv("VECMAP_PERPLEXITY", "30.0"))
    except ValueError:
        logger.warning("Invalid VECMAP_PERPLEXITY. Using 30.0.")
        tsne_perplexity = 30.0
|
||||
|
||||
# n_clusters configuration removed
|
||||
|
||||
dash_host = os.getenv("VECMAP_HOST", "127.0.0.1")
|
||||
    try:
        dash_port = int(os.getenv("VECMAP_PORT", "8050"))
    except ValueError:
        logger.warning("Invalid VECMAP_PORT. Using 8050.")
        dash_port = 8050
|
||||
dash_debug = log_level == logging.DEBUG
|
||||
|
||||
logger.info("Effective Configuration:")
|
||||
logger.info(f" Database: schema={db_schema}, function={db_function}")
|
||||
logger.info(f" Model/Interface: model={model_name}, interface={interface_name}")
|
||||
logger.info(f" Data Params: column={embedding_column}, limit={data_limit}")
|
||||
logger.info(f" Processing: perplexity={tsne_perplexity} (n_clusters is now dynamic)") # Updated log
|
||||
logger.info(f" Server: host={dash_host}, port={dash_port}, debug={dash_debug}")
|
||||
|
||||
# --- 1. Initial Load and Reduce ---
|
||||
initial_df_reduced = pd.DataFrame()
|
||||
try:
|
||||
logger.info("Performing initial data load and processing...")
|
||||
loader = VectorLoader(schema=db_schema, function=db_function, model=model_name, embedding_column=embedding_column)
|
||||
tsne_params = {"perplexity": tsne_perplexity}
|
||||
initial_df_reduced = loader.load_and_reduce(limit=data_limit, tsne_params=tsne_params)
|
||||
        if initial_df_reduced.empty:
            logger.warning("Initial data load resulted in an empty dataset.")
        else:
            logger.info(f"Successfully loaded and reduced {len(initial_df_reduced)} vectors initially.")
    except VectorLoaderError as e:
        logger.error(f"Initial data load failed: {e}", exc_info=dash_debug)
    except Exception as e:
        logger.error(f"Unexpected error during initial data load: {e}", exc_info=dash_debug)
|
||||
|
||||
# --- 2. Initialize and Start Visualization ---
|
||||
try:
|
||||
logger.info("Initializing VectorVisualizer...")
|
||||
visualizer = VectorVisualizer(
|
||||
initial_data=initial_df_reduced,
|
||||
db_schema=db_schema,
|
||||
db_function=db_function,
|
||||
interface_name=interface_name,
|
||||
model_name=model_name,
|
||||
embedding_column=embedding_column,
|
||||
initial_limit=data_limit,
|
||||
initial_perplexity=tsne_perplexity
|
||||
# n_clusters argument removed
|
||||
)
|
||||
logger.info("Launching visualizer...")
|
||||
visualizer.run(host=dash_host, port=dash_port, debug=dash_debug)
|
||||
except TypeError as te:
|
||||
logger.error(f"TypeError during VectorVisualizer initialization: {te}", exc_info=True)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize or run visualizer: {e}", exc_info=dash_debug)
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("--- VecMap Visualizer finished ---")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
52
librarian/plugins/librarian-vspace/pyproject.toml
Normal file
@ -0,0 +1,52 @@
[project]
name = "librarian-vspace"
version = "0.1.0"
description = "Vector-space tooling for the Librarian monorepo: embedding generation, vector search, and t-SNE visualization"
readme = "README.md"
authors = [
    { name = "TheOriginalGraLargeShrimpakaReaper", email = "graber-michael@hotmail.com" }
]
requires-python = ">=3.10"
dependencies = [
    "librarian-core",
    "importlib_metadata; python_version<'3.10'",
    "dotenv>=0.9.9",
    "psycopg2-binary>=2.9.10",
    "python-dotenv>=1.1.0",
    "requests>=2.32.3",
    "supabase>=2.15.0",
    "numpy>=2.2.5",
    "dash>=3.0.4",
    "scikit-learn>=1.6.1",
    "plotly>=6.0.1",
    "pandas>=2.2.3",
    "pathlib>=1.0.1",
    "prefect>=3.4.1",
]

[tool.uv.sources]
librarian-core = { git = "https://github.com/DotNaos/librarian-core", rev = "dev" }

[build-system]
requires = ["hatchling>=1.21"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src/librarian_vspace"]

[tool.hatch.metadata]
allow-direct-references = true

# ───────── optional: dev / test extras ─────────
[project.optional-dependencies]
dev = ["ruff", "pytest", "mypy"]

[project.entry-points."librarian.workers"]
embedder = "librarian_vspace.vecembed:EmbedderWorker"
clusterexporter = "librarian_vspace.vquery:ClusterExportWorker"
tsneexport = "librarian_vspace.vecview:TsneExportWorker"
vectorquerying = "librarian_vspace.vquery:QueryWorker"
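The [project.entry-points."librarian.workers"] table is how a host application can discover these workers without importing the plugin explicitly. Below is a minimal discovery sketch using only the standard-library importlib.metadata; how librarian-core actually consumes this group is not shown in this commit.

from importlib.metadata import entry_points

def discover_librarian_workers() -> dict[str, type]:
    """Map entry-point name -> loaded worker class for the 'librarian.workers' group."""
    workers: dict[str, type] = {}
    for ep in entry_points(group="librarian.workers"):
        workers[ep.name] = ep.load()  # e.g. imports librarian_vspace.vecembed:EmbedderWorker
    return workers

if __name__ == "__main__":
    for name, cls in discover_librarian_workers().items():
        print(f"{name}: {cls.__module__}.{cls.__qualname__}")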
@ -0,0 +1,22 @@
"""Embedding‑related helpers."""
import pkgutil
import importlib

__all__ = []

# Iterate over all modules in this package
for finder, module_name, is_pkg in pkgutil.iter_modules(__path__):
    # import the sub-module
    module = importlib.import_module(f"{__name__}.{module_name}")

    # decide which names to re-export:
    # use module.__all__ if it exists, otherwise every non-private attribute
    public_names = getattr(
        module, "__all__", [n for n in dir(module) if not n.startswith("_")]
    )

    # bring each name into the package namespace
    for name in public_names:
        globals()[name] = getattr(module, name)
        __all__.append(name)  # type: ignore
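For orientation, a hypothetical illustration of what this auto-import loop buys; the package and module names below are invented purely to show the effect.

# Hypothetical layout, assuming this __init__.py sits in a package named "vecview":
#
#   vecview/
#       __init__.py        <- the auto-import loop above
#       tsne_export.py     <- defines TsneExportWorker
#
# After the loop runs, public names from submodules are importable from the package root:
from vecview import TsneExportWorker  # instead of vecview.tsne_export.TsneExportWorker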
@ -0,0 +1,38 @@
"""Pydantic models for vector search requests and responses."""

from __future__ import annotations

from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field


class VectorSearchRequest(BaseModel):
    """Input payload for a vector search."""

    interface_name: str = Field(..., description="Name of the embedding interface")
    model_name: str = Field(..., description="Name of the embedding model")
    search_string: str = Field(..., description="The natural language query to embed and search for")
    filters: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Optional key/value filters applied server‑side",
    )
    top_k: int = Field(10, ge=1, le=100, description="Number of matches to return")
    embedding_column: str = Field(
        "embedding",
        description="Name of the embedding column in the database table",
    )


class Chunklet(BaseModel):
    """Single result row returned by the database RPC."""

    chunk: Optional[str] = None
    file_id: Optional[str | int] = None


class VectorSearchResponse(BaseModel):
    """Output payload wrapping vector‑search results."""

    total: int
    results: List[Chunklet]
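A small usage sketch for these models, assuming Pydantic v2; the import path is assumed, since the file's location inside the package is not visible in this diff.

from pydantic import ValidationError
# from librarian_vspace.vquery.models import VectorSearchRequest, VectorSearchResponse  # path assumed

req = VectorSearchRequest(
    interface_name="ollama",
    model_name="snowflake-arctic-embed2",
    search_string="how does backpropagation work?",
    top_k=5,
)
print(req.model_dump())  # payload ready for whatever endpoint or RPC consumes it

raw = {"total": 2, "results": [{"chunk": "some text", "file_id": 17}, {"chunk": None, "file_id": "abc"}]}
try:
    resp = VectorSearchResponse.model_validate(raw)  # Pydantic v2 API
    print(resp.results[0].file_id)
except ValidationError as exc:
    print(f"Bad payload: {exc}")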
@ -0,0 +1,31 @@
"""Data models for t‑SNE exports.

These models are used by *vecview* and any endpoint that needs to return or
validate t‑SNE projection data.
"""

from __future__ import annotations

from typing import List, Optional
from pydantic import BaseModel


class TSNEPoint(BaseModel):
    """A single point in a 3‑D t‑SNE projection."""

    x: float
    y: float
    z: float
    file_id: str
    chunk: str
    cluster: Optional[str] = None
    hover_text: Optional[str] = None


class TSNEData(BaseModel):
    """Container returned to callers requesting a t‑SNE view."""

    course_id: Optional[int] = None
    total: int
    points: List[TSNEPoint]
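One plausible way to build this payload from the reduced DataFrame produced by the loader later in this commit; the helper itself is not part of the commit, and the column names simply mirror the loader's output.

import pandas as pd

def dataframe_to_tsne_data(df: pd.DataFrame, course_id: int | None = None) -> TSNEData:
    """Wrap rows with x/y/z/file_id/chunk (and optional cluster) columns into TSNEData."""
    has_cluster = "cluster" in df.columns
    points = [
        TSNEPoint(
            x=float(row.x), y=float(row.y), z=float(row.z),
            file_id=str(row.file_id), chunk=str(row.chunk),
            cluster=str(row.cluster) if has_cluster else None,
        )
        for row in df.itertuples(index=False)
    ]
    return TSNEData(course_id=course_id, total=len(points), points=points)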
@ -0,0 +1,9 @@
"""Embedding‑related helpers."""
from __future__ import annotations

from .vector_inserter import VectorInserter
from .embedding_generator import EmbeddingGenerator
from .embedding_workflow import EmbeddingWorkflow

__all__ = ["VectorInserter", "EmbeddingGenerator", "EmbeddingWorkflow"]
@ -0,0 +1,155 @@
|
||||
|
||||
"""Parallel‑aware embedding helpers.
|
||||
|
||||
* **embed_single_file()** – embed one file (sync).
|
||||
* **run_embedder()** – embed all files in a course (async, kept for back‑compat).
|
||||
* **_create_hnsw_index()** – helper to (re)build PGVector HNSW index.
|
||||
|
||||
This file contains no Prefect code; it’s pure embedding logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List, Union
|
||||
|
||||
from postgrest import APIResponse
|
||||
|
||||
from librarian_core.temp_payloads.chunk_data import ChunkCourse, ChunkFile
|
||||
from librarian_vspace.vecembed.embedding_generator import EmbeddingGenerator
|
||||
from librarian_vspace.vecembed.vector_inserter import VectorInserter
|
||||
from librarian_vspace.vecembed.embedding_workflow import EmbeddingWorkflow
|
||||
from librarian_vspace.vutils.supabase_singleton import MySupabase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _autodiscover_pg_conn():
|
||||
supa = MySupabase.get_client() # type: ignore
|
||||
if supa is None:
|
||||
raise RuntimeError("MySupabase.get_client() returned None – no DB connection.")
|
||||
return supa
|
||||
|
||||
|
||||
def _create_hnsw_index(
|
||||
supa,
|
||||
table_fqn: str,
|
||||
*,
|
||||
column_name: str = "embedding",
|
||||
query_operator: str = "<=>",
|
||||
m: int = 16,
|
||||
ef: int = 64,
|
||||
) -> None:
|
||||
if "." not in table_fqn:
|
||||
raise ValueError("table_fqn must be schema.table")
|
||||
schema, table = table_fqn.split(".", 1)
|
||||
try:
|
||||
supa.schema(schema).rpc(
|
||||
"create_or_reindex_hnsw",
|
||||
dict(
|
||||
p_schema=schema,
|
||||
p_table=table,
|
||||
p_column=column_name,
|
||||
p_operator=query_operator,
|
||||
p_m=m,
|
||||
p_ef=ef,
|
||||
),
|
||||
).execute()
|
||||
except Exception:
|
||||
logger.exception("Failed to run create_or_reindex_hnsw")
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# single file #
|
||||
# --------------------------------------------------------------------------- #
|
||||
def embed_single_file(
|
||||
*,
|
||||
course_id: str,
|
||||
file_entry: dict | ChunkFile | SimpleNamespace,
|
||||
concat_path: Union[str, Path],
|
||||
db_schema: str = "librarian",
|
||||
db_function: str = "pdf_chunking",
|
||||
interface_name: str = "ollama",
|
||||
model_name: str = "snowflake-arctic-embed2",
|
||||
file_type: str = "md",
|
||||
) -> Path | None:
|
||||
|
||||
if isinstance(file_entry, (dict, SimpleNamespace)):
|
||||
file_name = file_entry["file_name"] if isinstance(file_entry, dict) else file_entry.file_name
|
||||
file_id = file_entry["file_id"] if isinstance(file_entry, dict) else file_entry.file_id
|
||||
else:
|
||||
file_name, file_id = file_entry.file_name, file_entry.file_id
|
||||
|
||||
chunk_path = Path(concat_path) / file_name
|
||||
if not chunk_path.exists():
|
||||
logger.warning("Missing chunk file %s – skipping", chunk_path)
|
||||
return None
|
||||
|
||||
generator = EmbeddingGenerator()
|
||||
inserter = VectorInserter(schema=db_schema, function=db_function, model=model_name)
|
||||
|
||||
wf = EmbeddingWorkflow(
|
||||
chunk_path=chunk_path,
|
||||
course_id=course_id,
|
||||
file_id=file_id,
|
||||
file_type=file_type,
|
||||
interface_name=interface_name,
|
||||
model_name=model_name,
|
||||
generator=generator,
|
||||
inserter=inserter,
|
||||
)
|
||||
wf.process()
|
||||
return chunk_path
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
async def run_embedder(
|
||||
course: ChunkCourse,
|
||||
concat_path: Union[str, Path],
|
||||
*,
|
||||
db_schema: str = "librarian",
|
||||
db_function: str = "pdf_chunking",
|
||||
interface_name: str = "ollama",
|
||||
model_name: str = "snowflake-arctic-embed2",
|
||||
file_type: str = "md",
|
||||
vector_column: str = "embedding",
|
||||
query_operator: str = "<=>",
|
||||
hnsw_m: int = 16,
|
||||
hnsw_ef: int = 64,
|
||||
max_parallel_files: int | None = None,
|
||||
) -> Path:
|
||||
|
||||
supa_client = _autodiscover_pg_conn()
|
||||
root = Path(concat_path)
|
||||
sem = asyncio.Semaphore(max_parallel_files or len(course.files) or 1)
|
||||
|
||||
async def _wrapper(cf):
|
||||
async with sem:
|
||||
return await asyncio.to_thread(
|
||||
embed_single_file,
|
||||
course_id=course.course_id,
|
||||
file_entry=cf,
|
||||
concat_path=root,
|
||||
db_schema=db_schema,
|
||||
db_function=db_function,
|
||||
interface_name=interface_name,
|
||||
model_name=model_name,
|
||||
file_type=file_type,
|
||||
)
|
||||
|
||||
await asyncio.gather(*[asyncio.create_task(_wrapper(cf)) for cf in course.files])
|
||||
|
||||
inserter = VectorInserter(schema=db_schema, function=db_function, model=model_name)
|
||||
_create_hnsw_index(
|
||||
supa_client,
|
||||
inserter.table_fqn(),
|
||||
column_name=vector_column,
|
||||
query_operator=query_operator,
|
||||
m=hnsw_m,
|
||||
ef=hnsw_ef,
|
||||
)
|
||||
return root
|
||||
|
||||
__all__ = ["embed_single_file", "run_embedder", "_create_hnsw_index", "_autodiscover_pg_conn"]
|
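A hedged usage sketch for the synchronous path above. It assumes a configured Supabase client (via MySupabase) and a reachable Ollama server; the course id, file entry, and directory below are illustrative only.

from pathlib import Path

from librarian_vspace.vecembed.embedder import embed_single_file

result = embed_single_file(
    course_id="demo-course",
    file_entry={"file_name": "lecture_01.md", "file_id": 42},  # dict form is accepted alongside ChunkFile
    concat_path=Path("/tmp/librarian/concat"),  # directory that already holds lecture_01.md
    interface_name="ollama",
    model_name="snowflake-arctic-embed2",
)
print("embedded:", result)  # None means the chunk file was missing and the file was skipped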
@ -0,0 +1,67 @@
|
||||
|
||||
"""EmbedderWorker – Prefect‑mapped per file."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List
|
||||
|
||||
from prefect import get_run_logger, task, unmapped
|
||||
from pydantic import BaseModel, Field
|
||||
from librarian_core.workers.base import Worker
|
||||
|
||||
@task(name="embed_file", retries=2, retry_delay_seconds=5, log_prints=True, tags=["embed_file"])
|
||||
def embed_file_task(course_dict: dict | SimpleNamespace, file_entry: dict, concat_path: Path) -> Path | None:
|
||||
from librarian_vspace.vecembed.embedder import embed_single_file
|
||||
cid = course_dict["course_id"] if isinstance(course_dict, dict) else course_dict.course_id
|
||||
return embed_single_file(course_id=cid, file_entry=file_entry, concat_path=concat_path)
|
||||
|
||||
class EmbedderInput(BaseModel):
|
||||
chunk_courses: List[Any] = Field(default_factory=list, alias="chunk_courses")
|
||||
concat_path: Path
|
||||
chunk_course: Any | None = None
|
||||
def model_post_init(self, _):
|
||||
if not self.chunk_courses and self.chunk_course is not None:
|
||||
self.chunk_courses = [self.chunk_course]
|
||||
model_config = dict(populate_by_name=True, extra="allow")
|
||||
|
||||
class EmbedderOutput(BaseModel):
|
||||
result_paths: List[Path]
|
||||
|
||||
class EmbedderWorker(Worker[EmbedderInput, EmbedderOutput]):
|
||||
input_model = EmbedderInput
|
||||
output_model = EmbedderOutput
|
||||
|
||||
async def __run__(self, payload: EmbedderInput) -> EmbedderOutput:
|
||||
log = get_run_logger()
|
||||
total_files = sum(len(c["files"]) if isinstance(c, dict) else len(c.files) for c in payload.chunk_courses)
|
||||
log.info("Embedding %d files", total_files)
|
||||
|
||||
result_paths: List[Path] = []
|
||||
|
||||
# constants – could be parameterised later
|
||||
schema = "librarian"
|
||||
func = "pdf_chunking"
|
||||
model_name = "snowflake-arctic-embed2"
|
||||
|
||||
for course in payload.chunk_courses:
|
||||
files = course["files"] if isinstance(course, dict) else course.files
|
||||
futures = embed_file_task.map(unmapped(course), files, unmapped(payload.concat_path))
|
||||
for fut in futures:
|
||||
path = fut.result()
|
||||
if path:
|
||||
result_paths.append(path)
|
||||
|
||||
# rebuild index once per course
|
||||
from librarian_vspace.vecembed.embedder import _create_hnsw_index, _autodiscover_pg_conn
|
||||
from librarian_vspace.vecembed.vector_inserter import VectorInserter
|
||||
|
||||
supa = _autodiscover_pg_conn()
|
||||
inserter = VectorInserter(schema=schema, function=func, model=model_name)
|
||||
_create_hnsw_index(supa, inserter.table_fqn())
|
||||
|
||||
for p in result_paths:
|
||||
self.stage(p, new_name=p.name)
|
||||
|
||||
return EmbedderOutput(result_paths=result_paths)
|
@ -0,0 +1,21 @@
"""Factory for embedding back‑ends."""
import logging
from typing import Any, List, Optional, Tuple, Dict, Type

from librarian_vspace.vecembed.embedding_interface import EmbeddingInterface
from librarian_vspace.vecembed.ollama_embedder import OllamaEmbedder

logger = logging.getLogger(__name__)


class EmbeddingGenerator:
    _registry: Dict[str, Type[EmbeddingInterface]] = {
        "ollama": OllamaEmbedder,
    }

    def generate_embedding(self, interface_name: str, model_name: str, text_to_embed: str, identifier: Any) -> Tuple[str, Optional[List[float]], Any]:
        cls = self._registry.get(interface_name.lower())
        if not cls:
            raise ValueError(f"Unsupported embedding interface: {interface_name}")
        embedder = cls(model_name=model_name)
        return embedder.embed(text_to_embed, identifier)
@ -0,0 +1,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, List, Optional, Tuple


class EmbeddingInterface(ABC):
    """Contract for any embedding service implementation."""

    def __init__(self, model_name: str, **kwargs: Any) -> None:
        self.model_name = model_name

    @abstractmethod
    def embed(self, text_or_chunk: str, identifier: Any) -> Tuple[str, Optional[List[float]], Any]:
        """Return (original_text, embedding, identifier) — embedding may be None on failure."""
        pass
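Because EmbeddingGenerator dispatches on its _registry mapping, adding a back-end means subclassing EmbeddingInterface and registering the class. Here is a minimal sketch with a deterministic fake embedder (handy for tests); the hashing scheme and the "fake" key are purely illustrative, and mutating _registry from outside is a shortcut rather than an official extension point.

import hashlib
from typing import Any, List, Optional, Tuple

from librarian_vspace.vecembed.embedding_interface import EmbeddingInterface
from librarian_vspace.vecembed.embedding_generator import EmbeddingGenerator


class FakeEmbedder(EmbeddingInterface):
    """Deterministic stand-in that never touches the network."""

    def embed(self, text_or_chunk: str, identifier: Any) -> Tuple[str, Optional[List[float]], Any]:
        digest = hashlib.sha256(text_or_chunk.encode("utf-8")).digest()
        vector = [b / 255.0 for b in digest[:16]]  # 16-dimensional toy vector
        return text_or_chunk, vector, identifier


# Register under a new interface name and use it through the normal factory path.
EmbeddingGenerator._registry["fake"] = FakeEmbedder
text, vec, ident = EmbeddingGenerator().generate_embedding("fake", "any-model", "hello world", identifier=1)
print(len(vec), ident)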
@ -0,0 +1,92 @@
|
||||
"""Orchestrates loading, embedding, and storing a text chunk."""
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Dict, Union
|
||||
|
||||
# Import the worker classes for type hinting
|
||||
from librarian_vspace.vecembed.embedding_generator import EmbeddingGenerator
|
||||
from librarian_vspace.vecembed.vector_inserter import VectorInserter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class EmbeddingWorkflow:
|
||||
# Accept generator and inserter instances in __init__
|
||||
def __init__(self,
|
||||
chunk_path: Union[str, Path],
|
||||
course_id: Any,
|
||||
file_id: Any,
|
||||
file_type: str,
|
||||
interface_name: str, # Still needed for generate_embedding method
|
||||
model_name: str, # Still needed for generate_embedding method
|
||||
generator: EmbeddingGenerator, # Accept pre-instantiated generator
|
||||
inserter: VectorInserter, # Accept pre-instantiated inserter
|
||||
# db_schema and db_function are now implicit via the inserter
|
||||
# db_schema: str = "librarian",
|
||||
# db_function: str = "pdf_chunking",
|
||||
):
|
||||
self.chunk_path = Path(chunk_path)
|
||||
self.course_id = course_id
|
||||
self.file_id = file_id
|
||||
self.file_type = file_type
|
||||
# Keep interface_name and model_name as they are passed to the generator's method
|
||||
self.interface_name = interface_name
|
||||
self.model_name = model_name
|
||||
|
||||
# Assign the passed instances instead of creating new ones
|
||||
self.generator = generator
|
||||
self.inserter = inserter
|
||||
|
||||
# No need to store db_schema/db_function here if inserter handles it
|
||||
|
||||
# ---------------- helpers ----------------
|
||||
def _load_chunk(self) -> Optional[str]:
|
||||
try:
|
||||
text = self.chunk_path.read_text(encoding="utf-8").strip()
|
||||
if not text:
|
||||
logger.warning("Chunk %s is empty", self.chunk_path)
|
||||
return None
|
||||
return text
|
||||
except Exception as exc:
|
||||
logger.error("Failed to read %s: %s", self.chunk_path, exc)
|
||||
return None
|
||||
|
||||
def process(self) -> bool:
|
||||
chunk_text = self._load_chunk()
|
||||
if chunk_text is None:
|
||||
return False
|
||||
|
||||
# Use the shared generator instance
|
||||
original_text, vector, _ = self.generator.generate_embedding(
|
||||
interface_name=self.interface_name, # Pass parameters to the method
|
||||
model_name=self.model_name, # Pass parameters to the method
|
||||
text_to_embed=chunk_text,
|
||||
identifier=self.file_id,
|
||||
)
|
||||
|
||||
if vector is None:
|
||||
# Log failure within generator if not already done, or here
|
||||
logger.error(f"Failed to generate embedding for {self.chunk_path}")
|
||||
return False
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"chunk": original_text,
|
||||
"course_id": self.course_id,
|
||||
"file_id": self.file_id,
|
||||
"file_type": self.file_type,
|
||||
"embedding": vector,
|
||||
}
|
||||
|
||||
# Use the shared inserter instance
|
||||
insert_result = self.inserter.insert_vector(payload)
|
||||
|
||||
if insert_result is None:
|
||||
logger.error(f"Failed to insert vector for {self.chunk_path}")
|
||||
return False
|
||||
|
||||
logger.debug(f"Successfully processed and inserted {self.chunk_path}")
|
||||
return True # Indicate success
|
||||
|
||||
|
||||
# Keep __all__ if needed
|
||||
# __all__ = ["EmbeddingWorkflow"]
|
@ -0,0 +1,44 @@
"""Ollama-based embedding implementation (env handled at application layer)."""
from __future__ import annotations

import logging
import os
from typing import Any, List, Optional, Tuple

import requests

from librarian_vspace.vecembed.embedding_interface import EmbeddingInterface

logger = logging.getLogger(__name__)


class OllamaEmbedder(EmbeddingInterface):
    def __init__(self, model_name: str, **kwargs: Any) -> None:
        super().__init__(model_name=model_name)
        self.base_url = os.getenv("OLLAMA_BASE_URL")
        if not self.base_url:
            raise ValueError("OLLAMA_BASE_URL not configured – ensure env is set in the examples layer")
        self.api_endpoint = f"{self.base_url.rstrip('/')}/api/embeddings"

    def embed(self, text_or_chunk: str, identifier: Any) -> Tuple[str, Optional[List[float]], Any]:
        payload = {"model": self.model_name, "prompt": text_or_chunk}
        vector: Optional[List[float]] = None
        try:
            logger.debug("Requesting embedding for id=%s", identifier)
            resp = requests.post(self.api_endpoint, json=payload, timeout=3600, headers={"Content-Type": "application/json"})
            resp.raise_for_status()
            data = resp.json()
            if isinstance(data.get("embedding"), list):
                vector = data["embedding"]
                logger.debug("Received embedding dim=%d for id=%s", len(vector), identifier)
            else:
                logger.error("Invalid response from Ollama: %s", data)
        except requests.exceptions.Timeout:
            logger.error("Timeout contacting Ollama at %s", self.api_endpoint)
        except requests.exceptions.RequestException as exc:
            logger.error("HTTP error contacting Ollama: %s", exc)
        except Exception:
            logger.exception("Unexpected error during embed for id=%s", identifier)
        return text_or_chunk, vector, identifier
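For completeness, a usage sketch. OLLAMA_BASE_URL must point at a running Ollama instance (the URL below is the usual local default, not something this commit sets), and the model must already be pulled.

import os

from librarian_vspace.vecembed.ollama_embedder import OllamaEmbedder

os.environ.setdefault("OLLAMA_BASE_URL", "http://localhost:11434")  # assumed local default

embedder = OllamaEmbedder(model_name="snowflake-arctic-embed2")
text, vector, ident = embedder.embed("a short test sentence", identifier="demo")
if vector is None:
    print("embedding failed - check the Ollama server and model name")
else:
    print(f"got a {len(vector)}-dimensional embedding for {ident!r}")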
@ -0,0 +1,23 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from librarian_vspace.vutils.vector_class import BaseVectorOperator

logger = logging.getLogger(__name__)


class VectorInserter(BaseVectorOperator):
    """High-level write helper for embeddings."""

    def insert_vector(self, data: Dict[str, Any]) -> Optional[List[Dict[str, Any]]]:
        if not self.table:
            logger.error("Table resolution failed earlier")
            return None
        preview = {k: (f"<vector,len={len(v)}>" if k == "embedding" else v) for k, v in data.items()}
        logger.debug("Insert → %s.%s :: %s", self.schema, self.table, preview)
        try:
            resp = self.spc.schema(self.schema).table(self.table).insert(data).execute()
            return resp.data if isinstance(resp.data, list) else []
        except Exception:
            logger.exception("Insert failed for %s", self.table_fqn())
            return None
@ -0,0 +1,2 @@
def hello() -> str:
    return "Hello from librarian_vspace!"
@ -0,0 +1,264 @@
|
||||
"""Loads vectors from Supabase, reduces dimensions using t-SNE."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import json # Import json for parsing
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.manifold import TSNE
|
||||
|
||||
# Assuming vutils is installed or in the python path
|
||||
try:
|
||||
from librarian_vspace.vutils.vector_class import BaseVectorOperator
|
||||
except ImportError as e:
|
||||
logging.error(f"Failed to import vutils: {e}. Ensure vutils package is installed.")
|
||||
raise
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VectorLoaderError(Exception):
|
||||
"""Custom exception for loader errors."""
|
||||
pass
|
||||
|
||||
|
||||
class VectorLoader:
|
||||
"""Fetches vectors and applies t-SNE."""
|
||||
|
||||
DEFAULT_TSNE_PARAMS = {
|
||||
"n_components": 3,
|
||||
"perplexity": 30.0, # Adjust based on dataset size (5-50 typically)
|
||||
"n_iter": 1000, # Minimum recommended iterations
|
||||
"learning_rate": "auto", # Usually a good default
|
||||
"init": "pca", # PCA initialization is often faster and more stable
|
||||
"random_state": 42, # For reproducibility
|
||||
"n_jobs": -1, # Use all available CPU cores
|
||||
"verbose": 1, # Log progress (controls scikit-learn's verbosity)
|
||||
}
|
||||
|
||||
def __init__(self, schema: str, function: str, model: str, embedding_column: str = "embedding"):
|
||||
"""
|
||||
Initializes the loader.
|
||||
(Constructor remains the same)
|
||||
"""
|
||||
logger.info(f"Initializing VectorLoader for {schema=}, {function=}, {model=}")
|
||||
try:
|
||||
self.operator = BaseVectorOperator(schema=schema, function=function, model=model)
|
||||
self.embedding_column = embedding_column
|
||||
if not self.operator.table:
|
||||
raise VectorLoaderError("BaseVectorOperator failed to resolve table.")
|
||||
logger.info(f"Target table resolved to: {self.operator.table_fqn()}")
|
||||
except (ImportError, ValueError, RuntimeError) as e:
|
||||
logger.exception("Failed to initialize BaseVectorOperator.")
|
||||
raise VectorLoaderError(f"Failed to initialize BaseVectorOperator: {e}") from e
|
||||
|
||||
|
||||
def _parse_vector_string(self, vector_str: Any) -> Optional[List[float]]:
|
||||
"""Safely parses the string representation of a vector into a list of floats."""
|
||||
if not isinstance(vector_str, str):
|
||||
# If it's already a list (less likely now, but safe check), return it if valid
|
||||
if isinstance(vector_str, list) and all(isinstance(n, (int, float)) for n in vector_str):
|
||||
return vector_str # Assume it's already correctly parsed
|
||||
logger.debug(f"Unexpected type for vector parsing: {type(vector_str)}. Skipping.")
|
||||
return None
|
||||
try:
|
||||
# Use json.loads which correctly handles [...] syntax
|
||||
parsed_list = json.loads(vector_str)
|
||||
if isinstance(parsed_list, list) and all(isinstance(n, (int, float)) for n in parsed_list):
|
||||
return [float(n) for n in parsed_list] # Ensure elements are floats
|
||||
else:
|
||||
logger.warning(f"Parsed vector string '{vector_str[:50]}...' but result is not a list of numbers.")
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"Failed to JSON decode vector string: '{vector_str[:50]}...'")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error parsing vector string '{vector_str[:50]}...': {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def fetch_all_vectors(self, limit: Optional[int] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Fetches all vectors and metadata from the resolved table.
|
||||
Parses string representations of vectors into lists.
|
||||
|
||||
Args:
|
||||
limit: Optional limit on the number of rows to fetch (for large tables).
|
||||
|
||||
Returns:
|
||||
A pandas DataFrame with columns like 'file_id', 'chunk', 'embedding' (as list).
|
||||
|
||||
Raises:
|
||||
VectorLoaderError: If fetching fails or no data is found.
|
||||
"""
|
||||
if not self.operator.table or not self.operator.schema:
|
||||
raise VectorLoaderError("Operator not initialized, table name or schema is unknown.")
|
||||
|
||||
table_name = self.operator.table
|
||||
schema_name = self.operator.schema
|
||||
select_columns = f"file_id, chunk, {self.embedding_column}"
|
||||
|
||||
logger.info(f"Fetching data from {schema_name}.{table_name} (columns: {select_columns})...")
|
||||
try:
|
||||
query = self.operator.spc.schema(schema_name).table(table_name).select(select_columns)
|
||||
if limit:
|
||||
logger.info(f"Applying limit: {limit}")
|
||||
query = query.limit(limit)
|
||||
response = query.execute()
|
||||
|
||||
if not response.data:
|
||||
logger.warning(f"No data found in table {self.operator.table_fqn()}.")
|
||||
return pd.DataFrame(columns=['file_id', 'chunk', self.embedding_column])
|
||||
|
||||
logger.info(f"Fetched {len(response.data)} rows.")
|
||||
df = pd.DataFrame(response.data)
|
||||
|
||||
# --- FIX: Parse the embedding string into a list ---
|
||||
logger.info(f"Parsing string representation in '{self.embedding_column}' column...")
|
||||
parsed_embeddings = df[self.embedding_column].apply(self._parse_vector_string)
|
||||
# Overwrite the original string column with the parsed list (or None if parsing failed)
|
||||
df[self.embedding_column] = parsed_embeddings
|
||||
logger.debug(f"Sample '{self.embedding_column}' data after parsing (first 5 rows):\n{df[[self.embedding_column]].head()}")
|
||||
# --- END FIX ---
|
||||
|
||||
|
||||
# === Enhanced Debugging for Embedding Column (Now checks the parsed list) ===
|
||||
logger.info(f"Checking validity of parsed '{self.embedding_column}' column...")
|
||||
if self.embedding_column not in df.columns:
|
||||
raise VectorLoaderError(f"Required embedding column '{self.embedding_column}' missing after processing.")
|
||||
|
||||
# 1. Check for NULLs (includes rows where parsing failed and returned None)
|
||||
initial_count = len(df)
|
||||
null_mask = df[self.embedding_column].isnull()
|
||||
null_count = null_mask.sum()
|
||||
if null_count > 0:
|
||||
logger.warning(f"Found {null_count} rows with NULL or unparsable vectors in '{self.embedding_column}'.")
|
||||
|
||||
df_no_nulls = df.dropna(subset=[self.embedding_column])
|
||||
count_after_null_drop = len(df_no_nulls)
|
||||
logger.debug(f"{count_after_null_drop} rows remaining after dropping NULLs/unparsable.")
|
||||
|
||||
# 2. Check for non-empty list type (This check might be slightly redundant now if parsing worked, but keep for safety)
|
||||
if not df_no_nulls.empty:
|
||||
def is_valid_list(x):
|
||||
# Check should pass if parsing was successful
|
||||
return isinstance(x, list) and len(x) > 0
|
||||
|
||||
valid_list_mask = df_no_nulls[self.embedding_column].apply(is_valid_list)
|
||||
invalid_list_count = len(df_no_nulls) - valid_list_mask.sum()
|
||||
|
||||
if invalid_list_count > 0:
|
||||
# This indicates an issue with the parsing logic or unexpected data format
|
||||
logger.error(f"Found {invalid_list_count} rows where '{self.embedding_column}' is not a non-empty list *after parsing*. This should not happen.")
|
||||
invalid_entries = df_no_nulls[~valid_list_mask][self.embedding_column]
|
||||
for i, entry in enumerate(invalid_entries.head(5)):
|
||||
logger.debug(f" Problematic entry example {i+1}: Type={type(entry)}, Value='{str(entry)[:100]}...'")
|
||||
|
||||
df_filtered = df_no_nulls[valid_list_mask].copy()
|
||||
else:
|
||||
df_filtered = df_no_nulls
|
||||
|
||||
final_count = len(df_filtered)
|
||||
# === End Enhanced Debugging ===
|
||||
|
||||
if final_count < initial_count:
|
||||
logger.warning(f"Filtered out {initial_count - final_count} rows total due to missing/invalid '{self.embedding_column}'.")
|
||||
|
||||
if df_filtered.empty:
|
||||
logger.warning(f"No valid embedding data found after filtering. Check data in table {self.operator.table_fqn()} and parsing logic.")
|
||||
return pd.DataFrame(columns=['file_id', 'chunk', self.embedding_column])
|
||||
|
||||
logger.info(f"Proceeding with {final_count} valid rows.")
|
||||
|
||||
# Validate and potentially add placeholder metadata columns AFTER filtering
|
||||
if 'file_id' not in df_filtered.columns:
|
||||
logger.warning("'file_id' column missing, using index instead.")
|
||||
df_filtered['file_id'] = df_filtered.index
|
||||
if 'chunk' not in df_filtered.columns:
|
||||
logger.warning("'chunk' column missing, hover text will be limited.")
|
||||
df_filtered['chunk'] = "N/A"
|
||||
|
||||
return df_filtered
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to fetch data from {self.operator.table_fqn()}.")
|
||||
if 'relation' in str(e) and 'does not exist' in str(e):
|
||||
raise VectorLoaderError(f"Table/Relation not found error: {e}. Check schema/table name and permissions.") from e
|
||||
else:
|
||||
raise VectorLoaderError(f"Database query failed: {e}") from e
|
||||
|
||||
# reduce_dimensions and load_and_reduce methods remain the same as the previous version
|
||||
# (they expect df with a valid list in the embedding column)
|
||||
|
||||
def reduce_dimensions(self, df: pd.DataFrame, tsne_params: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
|
||||
"""
|
||||
Applies t-SNE to reduce embedding dimensions to 3D.
|
||||
(Code remains the same as previous correct version)
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning("Input DataFrame for reduce_dimensions is empty. Returning empty DataFrame.")
|
||||
empty_df_with_cols = df.copy()
|
||||
for col in ['x', 'y', 'z']:
|
||||
if col not in empty_df_with_cols:
|
||||
empty_df_with_cols[col] = pd.Series(dtype=float)
|
||||
return empty_df_with_cols
|
||||
|
||||
if self.embedding_column not in df.columns:
|
||||
raise VectorLoaderError(f"Embedding column '{self.embedding_column}' missing in DataFrame passed to reduce_dimensions.")
|
||||
|
||||
try:
|
||||
embeddings = np.array(df[self.embedding_column].tolist(), dtype=float)
|
||||
except ValueError as ve:
|
||||
logger.error(f"Failed to convert embedding list to numeric numpy array: {ve}")
|
||||
raise VectorLoaderError(f"Data in '{self.embedding_column}' could not be converted to numeric vectors.") from ve
|
||||
|
||||
if embeddings.ndim != 2:
|
||||
raise VectorLoaderError(f"Embedding data has unexpected dimensions: {embeddings.ndim} (expected 2). Shape: {embeddings.shape}")
|
||||
|
||||
n_samples = embeddings.shape[0]
|
||||
|
||||
if n_samples < 2:
|
||||
logger.warning(f"Found only {n_samples} valid vector(s). t-SNE requires at least 2. Assigning default 3D coordinates.")
|
||||
default_coords = [[0.0, 0.0, 0.0]] * n_samples
|
||||
df[['x', 'y', 'z']] = default_coords
|
||||
return df
|
||||
|
||||
logger.info(f"Applying t-SNE to {n_samples} vectors of dimension {embeddings.shape[1]}...")
|
||||
|
||||
current_tsne_params = self.DEFAULT_TSNE_PARAMS.copy()
|
||||
if tsne_params:
|
||||
current_tsne_params.update(tsne_params)
|
||||
logger.info(f"Using custom t-SNE params: {tsne_params}")
|
||||
|
||||
if n_samples <= current_tsne_params['perplexity']:
|
||||
new_perplexity = max(5.0, float(n_samples - 1))
|
||||
logger.warning(f"Adjusting t-SNE perplexity from {current_tsne_params['perplexity']:.1f} "
|
||||
f"to {new_perplexity:.1f} due to low sample count ({n_samples}).")
|
||||
current_tsne_params['perplexity'] = new_perplexity
|
||||
|
||||
if n_samples * embeddings.shape[1] > 100000 and current_tsne_params['max_iter'] < 1000:
logger.warning(f"Dataset size seems large, increasing t-SNE max_iter from {current_tsne_params['max_iter']} to 1000 for better convergence.")
current_tsne_params['max_iter'] = 1000
|
||||
|
||||
try:
|
||||
logger.debug(f"Final t-SNE parameters: {current_tsne_params}")
|
||||
tsne = TSNE(**current_tsne_params)
|
||||
reduced_embeddings = tsne.fit_transform(embeddings)
|
||||
|
||||
df[['x', 'y', 'z']] = reduced_embeddings
|
||||
logger.info("t-SNE reduction complete.")
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("t-SNE dimensionality reduction failed.")
|
||||
raise VectorLoaderError(f"t-SNE failed: {e}") from e
|
||||
|
||||
|
||||
def load_and_reduce(self, limit: Optional[int] = None, tsne_params: Optional[Dict[str, Any]] = None) -> pd.DataFrame:
|
||||
"""Orchestrates fetching vectors and reducing dimensions."""
|
||||
logger.info("Starting vector load and reduction process...")
|
||||
df_raw_filtered = self.fetch_all_vectors(limit=limit)
|
||||
df_reduced = self.reduce_dimensions(df_raw_filtered, tsne_params=tsne_params)
|
||||
logger.info("Vector load and reduction process finished.")
|
||||
return df_reduced
|
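A usage sketch matching how the visualizer entry point earlier in this commit drives this class. The schema/function/model values are the defaults used elsewhere in the plugin, the module path is assumed, and BaseVectorOperator needs a configured Supabase connection.

import logging

from librarian_vspace.vecview.vector_loader import VectorLoader, VectorLoaderError  # module path assumed; not shown in this diff

logging.basicConfig(level=logging.INFO)

try:
    loader = VectorLoader(
        schema="librarian",
        function="pdf_chunking",
        model="snowflake-arctic-embed2",
        embedding_column="embedding",
    )
    df = loader.load_and_reduce(limit=500, tsne_params={"perplexity": 25.0})
    print(df[["file_id", "x", "y", "z"]].head())
except VectorLoaderError as exc:
    print(f"loading failed: {exc}")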
@ -0,0 +1,776 @@
|
||||
# --- START OF FILE visualizer.py ---
|
||||
|
||||
"""Dash/Plotly based 3D visualizer for vector embeddings with tabs, clustering, filtering, and centroid click interaction."""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import dash
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import plotly.graph_objects as go
|
||||
import plotly.express as px
|
||||
from dash import dcc, html, ctx # Import ctx
|
||||
from dash.dependencies import Input, Output, State
|
||||
from dash.exceptions import PreventUpdate
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
# --- Imports ---
|
||||
try:
|
||||
from librarian_vspace.vecembed.embedding_generator import EmbeddingGenerator
|
||||
except ImportError as e:
|
||||
logging.error(f"Import vecembed failed: {e}. Using Dummy.")
|
||||
|
||||
|
||||
# Define dummy class correctly indented
|
||||
class EmbeddingGenerator:
|
||||
"""Dummy class if vecembed import fails."""
|
||||
|
||||
def generate_embedding(*args, **kwargs) -> Tuple[
|
||||
str, None, Any]: # Match expected output type Optional[List[float]]
|
||||
logging.error("Dummy EmbeddingGenerator called.")
|
||||
text_to_embed = kwargs.get("text_to_embed", args[3] if len(args) > 3 else "unknown")
|
||||
identifier = kwargs.get("identifier", args[4] if len(args) > 4 else "unknown")
|
||||
logger.debug(f"Dummy generate_embedding called for text='{text_to_embed}', id='{identifier}'")
|
||||
# Return None for the vector part to match expected type
|
||||
return text_to_embed, None, identifier
|
||||
|
||||
try:
|
||||
from librarian_vspace.vutils.vector_query_loader import VectorQueryLoader as VectorLoader, VectorQueryLoaderError as VectorLoaderError
|
||||
except ImportError as e:
|
||||
logging.error(f"Import loader failed: {e}. Using Dummy.")
|
||||
|
||||
|
||||
# Define dummy classes correctly indented
|
||||
class VectorLoader:
|
||||
"""Dummy class if loader import fails."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
logging.error("Dummy VectorLoader initialized.")
|
||||
pass
|
||||
|
||||
def load_and_reduce(self, *args, **kwargs) -> pd.DataFrame:
|
||||
logging.error("Dummy VectorLoader load_and_reduce called.")
|
||||
return pd.DataFrame() # Return empty DataFrame
|
||||
|
||||
|
||||
class VectorLoaderError(Exception):
|
||||
"""Dummy exception if loader import fails."""
|
||||
pass
|
||||
# --- End Imports ---
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_N_CLUSTERS = 8
|
||||
|
||||
# Opacity constants
|
||||
OPACITY_DEFAULT = 0.8
|
||||
OPACITY_SEARCH_DIMMED = 0.1 # Dimmed opacity for points outside the current search results (lowered from the original 0.6)
|
||||
OPACITY_SELECTED_CLUSTER = 0.9
|
||||
|
||||
|
||||
class VectorVisualizer:
|
||||
def __init__(self, initial_data: pd.DataFrame,
|
||||
db_schema: str, db_function: str,
|
||||
interface_name: str, model_name: str,
|
||||
embedding_column: str = "embedding",
|
||||
initial_limit: Optional[int] = None,
|
||||
initial_perplexity: float = 30.0,
|
||||
n_clusters: int = DEFAULT_N_CLUSTERS
|
||||
):
|
||||
required_cols = ['x', 'y', 'z', 'file_id', 'chunk', embedding_column]
|
||||
processed_data_json: Optional[str] = None
|
||||
processed_color_map: Dict = {}
|
||||
processed_original_embeddings: np.ndarray = np.array([])
|
||||
processed_cluster_centroids: Dict[str, List[float]] = {}
|
||||
|
||||
self.embedding_column = embedding_column
|
||||
self.n_clusters = n_clusters
|
||||
self.db_schema = db_schema
|
||||
self.db_function = db_function
|
||||
self.model_name = model_name
|
||||
self.limit = initial_limit
|
||||
self.perplexity = initial_perplexity
|
||||
self.interface_name = interface_name
|
||||
# Use the correctly defined EmbeddingGenerator (either real or dummy)
|
||||
self.app = dash.Dash(__name__, suppress_callback_exceptions=True)
|
||||
self.embedding_generator = EmbeddingGenerator() # Instantiated here
|
||||
|
||||
if initial_data.empty or not all(col in initial_data.columns for col in required_cols):
|
||||
logger.warning("Initial DataFrame empty/invalid.")
|
||||
base_cols = required_cols + ['cluster', 'hover_text']
|
||||
initial_df_processed = pd.DataFrame(columns=base_cols)
|
||||
else:
|
||||
try:
|
||||
logger.info("Processing initial data...")
|
||||
df_copy = initial_data.copy()
|
||||
df_after_kmeans, kmeans_color_map = self._run_kmeans(df_copy, self.n_clusters)
|
||||
if not isinstance(df_after_kmeans, pd.DataFrame): raise TypeError("KMeans failed.")
|
||||
processed_color_map = kmeans_color_map
|
||||
df_after_prepare = self._prepare_plot_data(df_after_kmeans)
|
||||
if not isinstance(df_after_prepare, pd.DataFrame): raise TypeError("Prep data failed.")
|
||||
initial_df_processed = df_after_prepare
|
||||
if not initial_df_processed.empty and all(
|
||||
c in initial_df_processed for c in ['x', 'y', 'z', 'cluster']):
|
||||
processed_cluster_centroids = self._calculate_centroids(initial_df_processed)
|
||||
else:
|
||||
logger.warning("Could not calculate initial centroids.")
|
||||
if not initial_df_processed.empty:
|
||||
processed_data_json = initial_df_processed.to_json(date_format='iso', orient='split')
|
||||
else:
|
||||
logger.warning("DataFrame empty after processing.")
|
||||
if not initial_df_processed.empty and self.embedding_column in initial_df_processed.columns:
|
||||
try:
|
||||
emb = initial_df_processed[self.embedding_column].iloc[0]
|
||||
if isinstance(emb, np.ndarray):
|
||||
processed_original_embeddings = np.stack(initial_df_processed[self.embedding_column].values)
|
||||
elif isinstance(emb, list):
|
||||
processed_original_embeddings = np.array(
|
||||
initial_df_processed[self.embedding_column].tolist(), dtype=float)
|
||||
else:
|
||||
raise TypeError("Unsupported embedding type.")
|
||||
except Exception as emb_err:
|
||||
logger.error(f"Embed processing error: {emb_err}"); processed_original_embeddings = np.array([])
|
||||
else:
|
||||
logger.warning("Could not extract original embeddings.")
|
||||
except Exception as e:
|
||||
logger.error(f"Initial processing error: {e}", exc_info=True)
|
||||
processed_data_json, processed_color_map, processed_original_embeddings, processed_cluster_centroids = None, {}, np.array(
|
||||
[]), {}
|
||||
initial_df_processed = pd.DataFrame()
|
||||
|
||||
self.initial_data_json = processed_data_json
|
||||
self.initial_cluster_color_map = processed_color_map
|
||||
self.initial_cluster_centroids = processed_cluster_centroids
|
||||
self.original_embeddings = processed_original_embeddings
|
||||
# Determine slider limits and elbow‑based default
|
||||
self.max_clusters = max(1, len(initial_data))
|
||||
try:
|
||||
self.optimal_clusters = self._estimate_optimal_clusters(processed_original_embeddings,
|
||||
max_k=min(10, self.max_clusters))
|
||||
except Exception:
|
||||
self.optimal_clusters = self.n_clusters
|
||||
# Use elbow result as the current cluster count
|
||||
self.n_clusters = self.optimal_clusters
|
||||
self._build_layout();
|
||||
self._register_callbacks()
|
||||
|
||||
def _run_kmeans(self, df: pd.DataFrame, n_clusters: int) -> Tuple[pd.DataFrame, Dict[str, str]]:
|
||||
"""Runs K-Means, assigns string cluster labels."""
|
||||
default_map = {"-1": "grey"}
|
||||
if df.empty or self.embedding_column not in df.columns: df['cluster'] = "-1"; return df, default_map
|
||||
try:
|
||||
emb_col = df[self.embedding_column]
|
||||
if isinstance(emb_col.iloc[0], np.ndarray):
|
||||
embeddings = np.stack(emb_col.values)
|
||||
elif isinstance(emb_col.iloc[0], list):
|
||||
embeddings = np.array(emb_col.tolist(), dtype=float)
|
||||
else:
|
||||
raise TypeError("Unsupported embedding type.")
|
||||
if embeddings.ndim != 2: raise ValueError("Embeddings must be 2D.")
|
||||
|
||||
eff_clusters = min(n_clusters, embeddings.shape[0])
|
||||
if embeddings.shape[0] < 2 or eff_clusters < 1:
|
||||
lbl = "0" if embeddings.shape[0] > 0 else "-1";
|
||||
df['cluster'] = lbl
|
||||
colors = px.colors.qualitative.Plotly;
|
||||
return df, {lbl: colors[0 % len(colors)]} if lbl == "0" else default_map
|
||||
if eff_clusters == 1: df['cluster'] = "0"; colors = px.colors.qualitative.Plotly; return df, {
|
||||
"0": colors[0 % len(colors)]}
|
||||
|
||||
kmeans = KMeans(n_clusters=eff_clusters, random_state=42, n_init='auto')
|
||||
df['cluster'] = kmeans.fit_predict(embeddings).astype(str)
|
||||
unique_labels = sorted(df['cluster'].unique())
|
||||
colors = px.colors.qualitative.Plotly
|
||||
color_map = {lbl: colors[i % len(colors)] for i, lbl in enumerate(unique_labels)}
|
||||
return df, color_map
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"KMeans input error: {e}"); df['cluster'] = "-1"; return df, default_map
|
||||
except Exception as e:
|
||||
logger.exception("KMeans failed."); df['cluster'] = "-1"; return df, default_map
|
||||
|
||||
def _prepare_plot_data(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Prepares hover text."""
|
||||
if df.empty: return df
|
||||
if 'cluster' not in df.columns: df['cluster'] = 'N/A'
|
||||
df_copy = df.copy()
|
||||
|
||||
def gen_hover(row):
|
||||
try:
|
||||
return f"ID: {row.get('file_id', 'N/A')}<br>Cluster: {str(row.get('cluster', 'N/A'))}<br>Chunk: {str(row.get('chunk', ''))[:200]}{'...' if len(str(row.get('chunk', ''))) > 200 else ''}"
|
||||
except Exception:
|
||||
return "Hover gen error"
|
||||
|
||||
try:
|
||||
df_copy['hover_text'] = df_copy.apply(gen_hover, axis=1); return df_copy
|
||||
except Exception as e:
|
||||
logger.error(f"Hover gen failed: {e}"); return df
|
||||
|
||||
def _calculate_centroids(self, df: pd.DataFrame) -> Dict[str, List[float]]:
|
||||
"""Calculates 3D centroids."""
|
||||
centroids = {}
|
||||
required = ['x', 'y', 'z', 'cluster'];
|
||||
numeric_cols = ['x', 'y', 'z']
|
||||
if df.empty or not all(col in df.columns for col in required): return centroids
|
||||
df_copy = df.copy();
|
||||
df_copy['cluster'] = df['cluster'].astype(str)
|
||||
for col in numeric_cols:
|
||||
if not pd.api.types.is_numeric_dtype(df_copy[col]):
|
||||
try:
|
||||
df_copy[col] = pd.to_numeric(df_copy[col], errors='coerce')
|
||||
except Exception:
|
||||
logger.error(f"Centroid calc: conv error '{col}'"); return {}
|
||||
if df_copy[col].isnull().any(): logger.warning(f"Centroid calc: NaNs in '{col}'")
|
||||
try:
|
||||
# Calculate mean, drop rows where ALL numeric_cols are NaN, then drop rows where the resulting mean is NaN
|
||||
centroid_data = df_copy.dropna(subset=numeric_cols, how='all').groupby('cluster')[
|
||||
numeric_cols].mean().dropna()
|
||||
return {str(idx): row.tolist() for idx, row in centroid_data.iterrows()}
|
||||
except Exception as e:
|
||||
logger.exception("Centroid calc failed."); return {}
|
||||
|
||||
def _create_base_figure(self) -> go.Figure:
|
||||
"""Creates base Plotly figure."""
|
||||
fig = go.Figure()
|
||||
fig.update_layout(title='3D t-SNE', margin=dict(l=0, r=0, b=0, t=40),
|
||||
scene_camera_eye=dict(x=1.5, y=1.5, z=0.5),
|
||||
scene=dict(xaxis_title='TSNE-1', yaxis_title='TSNE-2', zaxis_title='TSNE-3',
|
||||
aspectmode='data'),
|
||||
legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01, bgcolor='rgba(255,255,255,0.7)'),
|
||||
hovermode='closest')
|
||||
return fig
|
||||
|
||||
def _build_layout(self) -> None:
|
||||
"""Builds the Dash layout."""
|
||||
self.app.layout = html.Div([
|
||||
dcc.Store(id='stored-data', data=self.initial_data_json),
|
||||
dcc.Store(id='cluster-color-map-store', data=self.initial_cluster_color_map),
|
||||
dcc.Store(id='cluster-centroids-store', data=self.initial_cluster_centroids),
|
||||
dcc.Store(id='search-results-store', data=None),
|
||||
dcc.Store(id='selected-cluster-store', data=None), # Store for click state
|
||||
html.H1("Vector Embedding Visualizer"),
|
||||
dcc.Tabs(id="main-tabs", value='tab-vis', children=[
|
||||
dcc.Tab(label='Visualization', value='tab-vis', children=[
|
||||
html.Div([ # Controls
|
||||
html.Div(
|
||||
[html.Button('Reload Data', id='reload-button', n_clicks=0, style={'marginRight': '10px'}),
|
||||
dcc.Input(id='search-input', type='text', placeholder='Search term...', debounce=True,
|
||||
style={'width': '40%', 'marginRight': '5px'}),
|
||||
html.Button('Search', id='search-button', n_clicks=0)],
|
||||
style={'padding': '10px', 'display': 'flex'}),
|
||||
html.Div([html.Label("Similarity:", style={'marginRight': '10px'}),
|
||||
dcc.Slider(id='similarity-slider', min=0, max=1, step=0.01, value=0.0,
|
||||
marks={i / 10: f'{i / 10:.1f}' for i in range(11)},
|
||||
tooltip={"placement": "bottom", "always_visible": True}, disabled=True)],
|
||||
id='slider-container', style={'display': 'none', 'padding': '10px 20px'}),
|
||||
html.Div([
|
||||
html.Label("Clusters:", style={'marginRight': '10px'}),
|
||||
dcc.Slider(
|
||||
id='cluster-slider',
|
||||
min=1,
|
||||
max=self.max_clusters,
|
||||
step=1,
|
||||
value=self.optimal_clusters,
|
||||
marks=self._cluster_marks(),
|
||||
tooltip={'placement': 'bottom', 'always_visible': True}
|
||||
)
|
||||
], style={'padding': '10px 20px'}),
|
||||
html.Div(id='status-output', style={'padding': '10px', 'color': 'blue', 'minHeight': '20px'}),
|
||||
dcc.Loading(id="loading-graph", type="circle",
|
||||
children=dcc.Graph(id='vector-graph', style={'height': '70vh'}))
|
||||
])
|
||||
]),
|
||||
dcc.Tab(label='Settings', value='tab-settings', children=[
|
||||
html.Div([html.H3("Settings"), html.Div([html.Label("Marker Size:", style={'marginRight': '10px'}),
|
||||
dcc.Slider(id='size-slider', min=1, max=15, step=1,
|
||||
value=4,
|
||||
marks={i: str(i) for i in range(1, 16)},
|
||||
tooltip={'placement': 'bottom',
|
||||
'always_visible': True})],
|
||||
style={'padding': '10px 20px'})], style={'padding': '20px'})
|
||||
]),
|
||||
]),
|
||||
])
|
||||
|
||||
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> np.ndarray | float:
|
||||
"""Calculates cosine similarity."""
|
||||
if not isinstance(vec1, np.ndarray): vec1 = np.array(vec1, dtype=float)
|
||||
if not isinstance(vec2, np.ndarray): vec2 = np.array(vec2, dtype=float)
|
||||
if vec1.ndim == 1: vec1 = vec1.reshape(1, -1)
|
||||
if vec2.ndim == 1: vec2 = vec2.reshape(1, -1)
|
||||
if vec1.shape[1] != vec2.shape[1]: raise ValueError("Vector dimension mismatch")
|
||||
norm1 = np.linalg.norm(vec1, axis=1, keepdims=True);
|
||||
norm2 = np.linalg.norm(vec2, axis=1, keepdims=True)
|
||||
z1 = (norm1 == 0).flatten();
|
||||
z2 = (norm2 == 0).flatten()
|
||||
# Handle potential division by zero for zero vectors
|
||||
norm1[z1] = 1.0;
|
||||
norm2[z2] = 1.0
|
||||
sim = np.dot(vec1 / norm1, (vec2 / norm2).T)
|
||||
# Ensure zero vectors result in zero similarity
|
||||
sim[z1, :] = 0.0  # zero similarity for rows whose query vector had zero norm
sim[:, z2] = 0.0  # and for columns whose stored vector had zero norm
|
||||
sim = np.clip(sim, -1.0, 1.0)
|
||||
return sim.item() if sim.size == 1 else sim.flatten()
|
||||
|
||||
def _find_neighbors(self, search_vector: List[float], k: int = 10) -> Optional[Tuple[np.ndarray, np.ndarray]]:
|
||||
"""Finds k nearest neighbors."""
|
||||
if not search_vector or not isinstance(search_vector, list): return None
|
||||
if self.original_embeddings is None or self.original_embeddings.size == 0: return None
|
||||
try:
|
||||
vec = np.array(search_vector, dtype=float)
|
||||
if vec.ndim != 1: raise ValueError("Search vector != 1D.")
|
||||
if self.original_embeddings.ndim != 2: raise ValueError("Embeddings != 2D.")
|
||||
if self.original_embeddings.shape[1] != vec.shape[0]: raise ValueError("Dimension mismatch.")
|
||||
sims = self._cosine_similarity(vec, self.original_embeddings)
|
||||
if not isinstance(sims, np.ndarray) or sims.ndim != 1 or sims.shape[0] != self.original_embeddings.shape[
|
||||
0]: raise TypeError("Similarity calc failed.")
|
||||
k_actual = min(k, len(sims));
|
||||
if k_actual <= 0: return None
|
||||
idx = np.argpartition(sims, -k_actual)[-k_actual:] # Get indices of top k
|
||||
sorted_idx = idx[np.argsort(sims[idx])][::-1] # Sort top k indices by similarity
|
||||
return sorted_idx, sims[sorted_idx]
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Neighbor input error: {e}"); return None
|
||||
except Exception as e:
|
||||
logger.exception(f"Neighbor search error: {e}"); return None
|
||||
|
||||
# --- Callbacks ---
|
||||
def _register_callbacks(self) -> None:
|
||||
"""Sets up Dash callbacks."""
|
||||
|
||||
# --- Callback 1: Reload Button ---
|
||||
@self.app.callback(
|
||||
Output('stored-data', 'data', allow_duplicate=True), Output('cluster-color-map-store', 'data'),
|
||||
Output('cluster-centroids-store', 'data'),
|
||||
Output('status-output', 'children'), Output('search-results-store', 'data', allow_duplicate=True),
|
||||
Output('selected-cluster-store', 'data', allow_duplicate=True),
|
||||
Input('reload-button', 'n_clicks'), prevent_initial_call=True)
|
||||
def handle_reload(n_clicks: int) -> Tuple[Optional[str], Dict, Dict, str, None, None]:
|
||||
if n_clicks == 0: raise PreventUpdate
|
||||
logger.info("Reload triggered...")
|
||||
status = "Reloading...";
|
||||
color_map, centroids, data_json = {}, {}, None;
|
||||
self.original_embeddings = np.array([])
|
||||
try:
|
||||
# Ensure VectorLoader is properly imported or defined (dummy used if import fails)
|
||||
loader = VectorLoader(self.db_schema, self.db_function, self.model_name, self.embedding_column)
|
||||
reduced_data = loader.load_and_reduce(limit=self.limit, tsne_params={"perplexity": self.perplexity})
|
||||
if not isinstance(reduced_data, pd.DataFrame) or reduced_data.empty: raise VectorLoaderError("No data.")
|
||||
|
||||
df_clustered, color_map = self._run_kmeans(reduced_data.copy(), self.n_clusters)
|
||||
if not isinstance(df_clustered, pd.DataFrame): raise TypeError(
|
||||
"KMeans failed post-reload.") # Add check
|
||||
df_final = self._prepare_plot_data(df_clustered)
|
||||
if not isinstance(df_final, pd.DataFrame): raise TypeError(
|
||||
"Prepare plot failed post-reload.") # Add check
|
||||
|
||||
if not df_final.empty and all(c in df_final for c in ['x', 'y', 'z', 'cluster']):
|
||||
centroids = self._calculate_centroids(df_final)
|
||||
else:
|
||||
logger.warning("Could not calculate centroids after reload (missing cols or empty).")
|
||||
|
||||
if not reduced_data.empty and self.embedding_column in reduced_data.columns:
|
||||
try:
|
||||
emb_col = reduced_data[self.embedding_column]
|
||||
# Check type of first element before processing
|
||||
if not emb_col.empty:
|
||||
first_emb = emb_col.iloc[0]
|
||||
if isinstance(first_emb, np.ndarray):
|
||||
self.original_embeddings = np.stack(emb_col.values)
|
||||
elif isinstance(first_emb, list):
|
||||
self.original_embeddings = np.array(emb_col.tolist(), dtype=float)
|
||||
else:
|
||||
raise TypeError(f"Unsupported reloaded embed type: {type(first_emb)}")
|
||||
logger.info(f"Stored reloaded embeddings (shape: {self.original_embeddings.shape}).")
|
||||
else:
|
||||
logger.warning("Embedding column empty during reload storage.")
|
||||
except Exception as e:
|
||||
logger.error(f"Store embed fail: {e}"); self.original_embeddings = np.array([])
|
||||
else:
|
||||
logger.warning("Embedding column missing or df empty during reload storage.")
|
||||
|
||||
if not df_final.empty:
|
||||
data_json = df_final.to_json(date_format='iso',
|
||||
orient='split'); status = f"Reloaded ({len(df_final)} pts)."
|
||||
else:
|
||||
status = "Warning: Reload empty post-process."
|
||||
except (VectorLoaderError, TypeError, Exception) as e:
|
||||
logger.exception(
|
||||
f"Reload error: {e}"); status = f"Error: {e}"; data_json, color_map, centroids = None, {}, {}; self.original_embeddings = np.array(
|
||||
[])
|
||||
return data_json, color_map, centroids, status, None, None
|
||||
|
||||
# --- Callback 1 b: Cluster‑count Slider ---
|
||||
@self.app.callback(
|
||||
Output('stored-data', 'data', allow_duplicate=True),
|
||||
Output('cluster-color-map-store', 'data', allow_duplicate=True),
|
||||
Output('cluster-centroids-store', 'data', allow_duplicate=True),
|
||||
Output('status-output', 'children', allow_duplicate=True),
|
||||
Input('cluster-slider', 'value'),
|
||||
State('stored-data', 'data'),
|
||||
prevent_initial_call=True
|
||||
)
|
||||
def update_n_clusters(k: int, stored_json: str):
|
||||
if not stored_json:
|
||||
raise PreventUpdate
|
||||
|
||||
# Update the visualizer state
|
||||
self.n_clusters = k
|
||||
|
||||
try:
|
||||
df = pd.read_json(StringIO(stored_json), orient='split')
|
||||
df, color_map = self._run_kmeans(df, k)
|
||||
df = self._prepare_plot_data(df)
|
||||
centroids = self._calculate_centroids(df)
|
||||
|
||||
status = f"Cluster count set to {k}."
|
||||
return (df.to_json(date_format='iso', orient='split'),
|
||||
color_map,
|
||||
centroids,
|
||||
status)
|
||||
except Exception as err:
|
||||
logger.error(f"Clustering update failed: {err}")
|
||||
raise PreventUpdate
|
||||
|
||||
# --- Callback 2: Search Button ---
|
||||
@self.app.callback(
|
||||
Output('search-results-store', 'data', allow_duplicate=True),
|
||||
Output('status-output', 'children', allow_duplicate=True),
|
||||
Input('search-button', 'n_clicks'), State('search-input', 'value'), prevent_initial_call=True)
|
||||
def handle_search(n_clicks: int, term: str) -> Tuple[Optional[Dict], str]:
|
||||
if n_clicks == 0 or not term: return None, "Enter search term."
|
||||
logger.info(f"Search: '{term}'");
|
||||
status = f"Embedding '{term}'..."
|
||||
try:
|
||||
if self.original_embeddings is None or self.original_embeddings.size == 0: return None, "Error: No data."
|
||||
_, vec, _ = self.embedding_generator.generate_embedding(self.interface_name, self.model_name, term,
|
||||
"search")
|
||||
if vec is None: return None, f"Error: Embed failed."
|
||||
status = f"Finding neighbors...";
|
||||
neighbors = self._find_neighbors(vec, k=20)
|
||||
if neighbors is None: return None, f"No neighbors found."
|
||||
idx, sims = neighbors;
|
||||
results = {"indices": idx.tolist(), "similarities": sims.tolist(), "term": term}
|
||||
status = f"Found {len(idx)} neighbors.";
|
||||
return results, status
|
||||
except Exception as e:
|
||||
logger.exception("Search error."); return None, f"Error: {e}"
|
||||
|
||||
# --- Callback 3: Slider Visibility ---
|
||||
@self.app.callback(
|
||||
Output('slider-container', 'style'), Output('similarity-slider', 'disabled'),
|
||||
Output('similarity-slider', 'value'),
|
||||
Input('search-results-store', 'data'), prevent_initial_call=True)
|
||||
def update_slider_visibility(res: Optional[Dict]) -> Tuple[Dict, bool, float]:
|
||||
show = res and isinstance(res, dict) and "indices" in res
|
||||
style = {'display': 'block' if show else 'none', 'padding': '10px 20px'}
|
||||
return style, not show, 0.0
|
||||
|
||||
# --- Callback 4: Graph Update (Main Logic with clickData fix and logging) ---
|
||||
@self.app.callback(
|
||||
Output('vector-graph', 'figure'), Output('selected-cluster-store', 'data'),
|
||||
Output('status-output', 'children', allow_duplicate=True),
|
||||
Input('stored-data', 'data'), Input('cluster-color-map-store', 'data'),
|
||||
Input('cluster-centroids-store', 'data'),
|
||||
Input('search-results-store', 'data'), Input('similarity-slider', 'value'), Input('size-slider', 'value'),
|
||||
Input('vector-graph', 'clickData'), # Input for clicks
|
||||
State('selected-cluster-store', 'data'), # Get current selection
|
||||
prevent_initial_call='initial_duplicate' # Allow initial run
|
||||
)
|
||||
def update_graph(stored_data_json: Optional[str], cluster_color_map: Optional[Dict],
|
||||
cluster_centroids: Optional[Dict[str, List[float]]],
|
||||
search_results: Optional[Dict], similarity_threshold: float, size_value: int,
|
||||
click_data: Optional[Dict],
|
||||
current_selected_cluster: Optional[str]) -> Tuple[go.Figure, Optional[str], str]:
|
||||
|
||||
fig = self._create_base_figure();
|
||||
status_msg = "";
|
||||
new_selected_cluster = current_selected_cluster
|
||||
trigger = ctx.triggered_id if ctx.triggered else "Initial"
|
||||
logger.debug(f"--- Graph Update | Trigger: {trigger} | CurrentSel: {current_selected_cluster} ---")
|
||||
|
||||
# --- Data Load & Validation ---
|
||||
if not stored_data_json: return fig, None, "No data."
|
||||
try:
|
||||
df = pd.read_json(StringIO(stored_data_json), orient='split')
|
||||
if df.empty: return fig, None, "Empty data."
|
||||
required = ['x', 'y', 'z', 'cluster', 'hover_text']
missing = [col for col in required if col not in df.columns]
if missing:
    raise ValueError(f"Missing required columns: {missing}")
|
||||
df['cluster'] = df['cluster'].astype(str)
|
||||
color_map = cluster_color_map if isinstance(cluster_color_map, dict) else {}
|
||||
centroids = cluster_centroids if isinstance(cluster_centroids, dict) else {}
|
||||
if not color_map:
|
||||
logger.warning("Missing color map, generating default.")
|
||||
unique_clusters = df['cluster'].unique()
colors = px.colors.qualitative.Plotly
color_map = {str(c): colors[i % len(colors)] for i, c in enumerate(unique_clusters)} or {'0': 'grey'}
|
||||
|
||||
# Calculate overall data range (handle potential NaNs/Infs in full data)
|
||||
df_finite = df[['x', 'y', 'z']].replace([np.inf, -np.inf], np.nan).dropna()
|
||||
if not df_finite.empty:
|
||||
overall_x_min, overall_x_max = df_finite['x'].min(), df_finite['x'].max()
|
||||
overall_y_min, overall_y_max = df_finite['y'].min(), df_finite['y'].max()
|
||||
overall_z_min, overall_z_max = df_finite['z'].min(), df_finite['z'].max()
|
||||
logger.debug(
|
||||
f"Overall Finite Range: X=[{overall_x_min:.2f}, {overall_x_max:.2f}], Y=[{overall_y_min:.2f}, {overall_y_max:.2f}], Z=[{overall_z_min:.2f}, {overall_z_max:.2f}]")
|
||||
else:
|
||||
logger.warning("No finite data points found in the dataset to calculate overall range.")
|
||||
overall_x_min, overall_x_max = -10, 10 # Default ranges if no finite data
|
||||
overall_y_min, overall_y_max = -10, 10
|
||||
overall_z_min, overall_z_max = -10, 10
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Graph data error."); return fig, current_selected_cluster, f"Error: {e}"
|
||||
|
||||
# --- Click Processing ---
|
||||
if trigger == 'vector-graph':
|
||||
logger.debug(f"Click Data Received: {click_data}")
|
||||
if click_data and 'points' in click_data and click_data['points']:
|
||||
point_data = click_data['points'][0]
|
||||
clicked_customdata = point_data.get('customdata');
|
||||
clicked_text = point_data.get('text', '')
|
||||
logger.debug(f"Clicked Point Customdata: {clicked_customdata}");
|
||||
logger.debug(f"Clicked Point Text: '{clicked_text}'")
|
||||
is_centroid_click = False;
|
||||
clicked_cluster_id = None
|
||||
if isinstance(clicked_customdata, list) and len(clicked_customdata) > 0: clicked_customdata = \
|
||||
clicked_customdata[0]
|
||||
if isinstance(clicked_customdata, (str, int)): # Accept string or int cluster IDs
|
||||
is_centroid_click = True;
|
||||
clicked_cluster_id = str(clicked_customdata);
|
||||
logger.info(f"Centroid Click Parsed via customdata: Cluster '{clicked_cluster_id}'")
|
||||
elif isinstance(clicked_text, str) and clicked_text.startswith("Centroid: Cluster "):
|
||||
try:
|
||||
    clicked_cluster_id = clicked_text.split("Centroid: Cluster ")[1]
    is_centroid_click = True
    logger.info(f"Centroid Click Parsed via text: Cluster '{clicked_cluster_id}'")
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"Failed text parse: {parse_err}")
|
||||
if is_centroid_click and clicked_cluster_id is not None:
|
||||
if current_selected_cluster == clicked_cluster_id:
new_selected_cluster = None
status_msg = "Cluster view reset."
logger.info("Deselecting.")
else:
new_selected_cluster = clicked_cluster_id
status_msg = f"Showing Cluster {new_selected_cluster}."
logger.info(f"Selecting {new_selected_cluster}.")
elif not is_centroid_click and current_selected_cluster is not None:
new_selected_cluster = None
status_msg = "Cluster view reset."
logger.info("Deselecting.")
else:  # Click background
if current_selected_cluster is not None:
    new_selected_cluster = None
    status_msg = "Cluster view reset."
    logger.info("Deselecting.")
|
||||
logger.debug(f"Click Result: new_selected_cluster = {new_selected_cluster}")
|
||||
else:
|
||||
logger.debug("No click trigger.")
|
||||
|
||||
# --- Data Filtering ---
|
||||
active_selection_id = new_selected_cluster
|
||||
df_to_plot = df.copy();
|
||||
centroids_to_plot = centroids.copy()
|
||||
logger.debug(f"Filtering based on active_selection_id: {active_selection_id}")
|
||||
if active_selection_id is not None:
|
||||
df_to_plot = df_to_plot[df_to_plot['cluster'] == active_selection_id]
|
||||
centroids_to_plot = {cid: coords for cid, coords in centroids_to_plot.items() if
|
||||
cid == active_selection_id}
|
||||
logger.debug(f"Filtered DF rows: {len(df_to_plot)}")
|
||||
if not df_to_plot.empty:
|
||||
logger.debug(f"Coordinates of filtered points:\n{df_to_plot[['x', 'y', 'z']]}")
|
||||
else:
|
||||
logger.warning("Filtered DataFrame is empty.")
|
||||
|
||||
# --- Search Highlighting ---
|
||||
search_highlight_mask = np.zeros(len(df_to_plot), dtype=bool)
|
||||
search_term = None;
|
||||
is_search_active = False;
|
||||
highlight_sims = {}
|
||||
if search_results and isinstance(search_results, dict) and "indices" in search_results:
|
||||
is_search_active = True;
|
||||
search_term = search_results.get("term", "N/A")
|
||||
orig_indices = search_results.get("indices", []);
|
||||
orig_sims = search_results.get("similarities", [])
|
||||
if not df_to_plot.empty:
|
||||
orig_to_current_map = {orig_idx: current_idx for current_idx, orig_idx in
|
||||
enumerate(df_to_plot.index)}
|
||||
current_indices_hl = [
    orig_to_current_map[oi]
    for i, oi in enumerate(orig_indices)
    if i < len(orig_sims) and orig_sims[i] >= similarity_threshold and oi in orig_to_current_map
]
|
||||
if current_indices_hl:
|
||||
search_highlight_mask[current_indices_hl] = True
|
||||
for i, orig_idx in enumerate(orig_indices):
|
||||
if i < len(orig_sims) and orig_sims[i] >= similarity_threshold and orig_idx in orig_to_current_map:
|
||||
highlight_sims[orig_to_current_map[orig_idx]] = orig_sims[i]
|
||||
else:
|
||||
logger.warning("Cannot apply search highlighting - filtered df empty.")
|
||||
|
||||
# --- Plotting ---
|
||||
df_search_hl = df_to_plot[search_highlight_mask];
|
||||
df_normal = df_to_plot[~search_highlight_mask]
|
||||
base_size = size_value;
|
||||
normal_op = OPACITY_SELECTED_CLUSTER if active_selection_id else (
|
||||
OPACITY_SEARCH_DIMMED if is_search_active else OPACITY_DEFAULT)
|
||||
|
||||
# --- Add Dummy Points if needed ---
|
||||
num_points_to_plot = len(df_normal) + len(df_search_hl)
|
||||
if active_selection_id is not None and num_points_to_plot <= 2:
|
||||
logger.info(
|
||||
f"Adding dummy invisible points to aid auto-ranging for cluster {active_selection_id} (points={num_points_to_plot}).")
|
||||
# Use overall range calculated earlier
|
||||
dummy_x = [overall_x_min, overall_x_max]
|
||||
dummy_y = [overall_y_min, overall_y_max]
|
||||
dummy_z = [overall_z_min, overall_z_max]
|
||||
# Ensure dummy points are valid numbers (in case overall range calc failed)
|
||||
if np.isfinite(dummy_x + dummy_y + dummy_z).all():
|
||||
fig.add_trace(go.Scatter3d(
|
||||
x=dummy_x, y=dummy_y, z=dummy_z,
|
||||
mode='markers', marker=dict(size=1, opacity=0), # Invisible
|
||||
hoverinfo='skip', showlegend=False, name='_dummy_'
|
||||
))
|
||||
else:
|
||||
logger.warning("Could not add dummy points because overall range contained non-finite values.")
|
||||
|
||||
# Plot Normal Points
|
||||
if not df_normal.empty:
|
||||
finite_mask_normal = np.isfinite(df_normal[['x', 'y', 'z']]).all(axis=1)
|
||||
df_normal_finite = df_normal[finite_mask_normal]
|
||||
if not df_normal_finite.empty:
|
||||
logger.debug(f"Plotting df_normal (len={len(df_normal_finite)}).")
|
||||
colors = df_normal_finite['cluster'].map(color_map).fillna('darkgrey')
|
||||
name = 'Embeddings' if active_selection_id is None else f'Cluster {active_selection_id}'
|
||||
fig.add_trace(
|
||||
go.Scatter3d(x=df_normal_finite['x'], y=df_normal_finite['y'], z=df_normal_finite['z'],
|
||||
mode='markers',
|
||||
marker=dict(color=colors, size=base_size, opacity=normal_op, line=dict(width=0.5)),
|
||||
text=df_normal_finite['hover_text'], hoverinfo='text', name=name))
|
||||
else:
|
||||
logger.warning("No finite normal points to plot.")
|
||||
|
||||
# Plot Search Highlighted Points
|
||||
if not df_search_hl.empty:
|
||||
finite_mask_search = np.isfinite(df_search_hl[['x', 'y', 'z']]).all(axis=1)
|
||||
df_search_hl_finite = df_search_hl[finite_mask_search]
|
||||
if not df_search_hl_finite.empty:
|
||||
hl_size = max(base_size * 1.5, base_size + 2);
|
||||
hl_texts = []
|
||||
# Need mapping from df_search_hl_finite index back to df_to_plot positional index for sims
|
||||
positions_in_df_to_plot = df_to_plot.index.get_indexer_for(df_search_hl_finite.index)
|
||||
|
||||
for i, (global_index, row) in enumerate(df_search_hl_finite.iterrows()):
|
||||
pos = positions_in_df_to_plot[i] # Get original position in df_to_plot
|
||||
sim = highlight_sims.get(pos, float('nan'))
|
||||
sim_txt = f"{sim:.4f}" if not np.isnan(sim) else "N/A"
|
||||
hl_texts.append(f"{row['hover_text']}<br><b>Sim: {sim_txt}</b>")
|
||||
|
||||
fig.add_trace(
|
||||
go.Scatter3d(x=df_search_hl_finite['x'], y=df_search_hl_finite['y'], z=df_search_hl_finite['z'],
|
||||
mode='markers',
|
||||
marker=dict(color='red', size=hl_size, opacity=1.0, symbol='diamond',
|
||||
line=dict(color='black', width=1)), text=hl_texts, hoverinfo='text',
|
||||
name=f'Search Neighbors'))
|
||||
if not df_search_hl_finite[['x', 'y', 'z']].isnull().values.any(): # Search Centroid
|
||||
try:
|
||||
sc = df_search_hl_finite[['x', 'y', 'z']].mean().values
fig.add_trace(go.Scatter3d(
    x=[sc[0]], y=[sc[1]], z=[sc[2]], mode='markers',
    marker=dict(color='magenta', size=max(hl_size, 10), symbol='cross', line=dict(width=1)),
    text=f"Search: '{search_term}' Centroid", hoverinfo='text', name='Search Centroid'))
|
||||
except Exception as e:
|
||||
logger.warning(f"Search centroid plot fail: {e}")
|
||||
else:
|
||||
logger.warning("No finite search highlighted points to plot.")
|
||||
|
||||
# Plot Centroids (filtered)
|
||||
if centroids_to_plot:
|
||||
cent_size = base_size + 1;
|
||||
logger.debug(f"Plotting centroids: {list(centroids_to_plot.keys())}")
|
||||
for cid, coords in centroids_to_plot.items():
|
||||
if isinstance(coords, list) and len(coords) == 3:
|
||||
logger.debug(f"Plotting Centroid {cid} at coords: {coords}")
|
||||
if np.isnan(coords).any() or np.isinf(coords).any():
    logger.error(f"!!! Centroid {cid} NaN/Inf coords !!!")
    continue
|
||||
color = color_map.get(str(cid), 'grey');
|
||||
name = f"Centroid {cid}";
|
||||
hover_txt = f"Centroid: Cluster {cid}"
|
||||
fig.add_trace(go.Scatter3d(
|
||||
x=[coords[0]], y=[coords[1]], z=[coords[2]], mode='markers',
|
||||
marker=dict(color=color, size=cent_size, symbol='circle', opacity=0.9,
|
||||
line=dict(color='black', width=1.5)),
|
||||
customdata=[str(cid)], text=hover_txt, hoverinfo='text', name=name,
|
||||
legendgroup="centroids", showlegend=True
|
||||
))
|
||||
else:
|
||||
logger.warning(f"Invalid centroid data for {cid}")
|
||||
|
||||
# --- Final Layout & Status ---
|
||||
title = f"3D t-SNE ({len(df)} points)"
|
||||
if active_selection_id is not None:
|
||||
title = f"Cluster {active_selection_id} ({len(df_to_plot)} points)"
|
||||
elif is_search_active:
|
||||
title = f"3D t-SNE - Search: '{search_term}'"
|
||||
if active_selection_id and is_search_active: title += f" - Search: '{search_term}'"
|
||||
base_layout = self._create_base_figure().layout
|
||||
fig.update_layout(
|
||||
title=title, legend_title_text='Legend', legend=base_layout.legend,
|
||||
scene=base_layout.scene # Use base scene settings (includes aspectmode='data')
|
||||
# Rely on auto-ranging (potentially helped by dummy points if added)
|
||||
)
|
||||
|
||||
final_status = status_msg
|
||||
if not final_status: # Default status
|
||||
base = f"{len(df_to_plot)} points shown."
|
||||
if active_selection_id: base = f"Cluster {active_selection_id}: {len(df_to_plot)} points."
|
||||
final_status = base;
|
||||
if is_search_active: final_status += f" (Search: '{search_term}')"
|
||||
|
||||
return fig, new_selected_cluster, final_status
|
||||
|
||||
def run(self, host: str = "127.0.0.1", port: int = 8050, debug: bool = False) -> None:
|
||||
"""Starts the Dash server."""
|
||||
logger.info(f"Starting Dash server on http://{host}:{port}")
|
||||
try:
|
||||
self.app.run(host=host, port=port, debug=debug)
|
||||
except OSError as e:
|
||||
logger.error(f"Server start failed: {e}. Port {port} busy?")
|
||||
except Exception as e:
|
||||
logger.exception(f"Server error: {e}")
|
||||
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
# >>> Helpers for automatic cluster‐count selection <<<
|
||||
# ──────────────────────────────────────────────────────────────
|
||||
def _estimate_optimal_clusters(self, embeddings: np.ndarray, max_k: int = 10) -> int:
|
||||
"""
|
||||
Estimate an optimal number of clusters using a quick elbow heuristic.
|
||||
Computes K‑means inertia for k = 1…max_k and picks the k that is farthest
|
||||
from the straight line connecting (1, inertia₁) and (max_k, inertiaₘₐₓ).
|
||||
"""
|
||||
if embeddings is None or embeddings.size == 0:
|
||||
return 1
|
||||
n_samples = embeddings.shape[0]
|
||||
if n_samples < 3:
|
||||
return 1
|
||||
|
||||
max_k = min(max_k, n_samples)
|
||||
inertias = []
|
||||
for k in range(1, max_k + 1):
|
||||
km = KMeans(n_clusters=k, random_state=42, n_init="auto").fit(embeddings)
|
||||
inertias.append(km.inertia_)
|
||||
|
||||
# distance from each point to the line between first and last
|
||||
x = np.arange(1, max_k + 1)
|
||||
x1, y1 = 1, inertias[0]
|
||||
x2, y2 = max_k, inertias[-1]
|
||||
numerator = np.abs((y2 - y1) * x - (x2 - x1) * np.array(inertias) + x2 * y1 - y2 * x1)
|
||||
denominator = np.sqrt((y2 - y1) ** 2 + (x2 - x1) ** 2)
|
||||
elbow_idx = int(np.argmax(numerator / denominator))
|
||||
return elbow_idx + 1 # since k starts at 1
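# Hedged worked example (illustrative numbers, not from the original source):
# with hypothetical inertias [100, 60, 30, 25, 22, 20] for k = 1..6, the line runs
# from (1, 100) to (6, 20); the numerator above evaluates to
# [0, 120, 190, 135, 70, 0], so argmax lands on index 2, i.e. k = 3.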
|
||||
|
||||
def _cluster_marks(self) -> Dict[int, str]:
|
||||
"""Generate tick marks for the cluster-count slider."""
|
||||
if self.max_clusters <= 15:
|
||||
return {i: str(i) for i in range(1, self.max_clusters + 1)}
|
||||
# Show first, optimal, and max for large data sets
|
||||
return {1: "1", self.optimal_clusters: str(self.optimal_clusters), self.max_clusters: str(self.max_clusters)}
|
||||
|
||||
# --- END OF FILE visualizer.py ---
|
@ -0,0 +1,2 @@
|
||||
def hello() -> str:
|
||||
return "Hello from librarian_vspace!"
|
@ -0,0 +1,93 @@
|
||||
|
||||
"""TsneExportWorker – Prefect worker that generates a t‑SNE JSON export.
|
||||
|
||||
It wraps Vspace.get_tnse → vecview.get_tsne_json, writes the JSON to a file,
|
||||
stages it, and returns the file path.
|
||||
|
||||
Minimal Pydantic payload models are defined locally to avoid extra deps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prefect import get_run_logger
|
||||
from pydantic import BaseModel
|
||||
from librarian_core.workers.base import Worker
|
||||
|
||||
from librarian_vspace.vecview.vecview import get_tsne_json
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
def _safe_get_logger(name: str):
|
||||
try:
|
||||
return get_run_logger()
|
||||
except Exception:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Pydantic payloads
|
||||
# ------------------------------------------------------------------ #
|
||||
class TsneExportInput(BaseModel):
|
||||
course_id: int
|
||||
limit: Optional[int] = None
|
||||
perplexity: float = 30.0
|
||||
db_schema: str = "librarian"
|
||||
rpc_function: str = "pdf_chunking"
|
||||
embed_model: str = "snowflake-arctic-embed2"
|
||||
embedding_column: str = "embedding"
|
||||
base_output_dir: Optional[Path] = None # where to place JSON file
|
||||
|
||||
|
||||
class TsneExportOutput(BaseModel):
|
||||
json_path: Path
|
||||
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
class TsneExportWorker(Worker[TsneExportInput, TsneExportOutput]):
|
||||
"""Runs the t‑SNE export inside a Prefect worker.""" # noqa: D401
|
||||
|
||||
input_model = TsneExportInput
|
||||
output_model = TsneExportOutput
|
||||
|
||||
async def __run__(self, payload: TsneExportInput) -> TsneExportOutput:
|
||||
logger = _safe_get_logger(self.worker_name)
|
||||
logger.info("🔨 %s startet (payload=%r)", self.worker_name, payload)
|
||||
|
||||
# Run get_tsne_json in a thread
|
||||
data_json = await asyncio.to_thread(
|
||||
get_tsne_json,
|
||||
db_schema=payload.db_schema,
|
||||
db_function=payload.rpc_function,
|
||||
model_name=payload.embed_model,
|
||||
limit=payload.limit,
|
||||
course_id=payload.course_id,
|
||||
perplexity=payload.perplexity,
|
||||
embedding_column=payload.embedding_column,
|
||||
)
|
||||
|
||||
# Determine output file
|
||||
if payload.base_output_dir:
|
||||
out_dir = Path(payload.base_output_dir).expanduser()
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
json_path = out_dir / f"{payload.course_id}_tsne.json"
|
||||
else:
|
||||
tf = tempfile.NamedTemporaryFile(
|
||||
mode="w+", suffix="_tsne.json", prefix="vspace_", delete=False
|
||||
)
|
||||
json_path = Path(tf.name)
|
||||
|
||||
# Write JSON to file
|
||||
json_path.write_text(data_json, encoding="utf-8")
|
||||
|
||||
# Stage file for Prefect
|
||||
self.stage(json_path, new_name=json_path.name)
|
||||
|
||||
result = TsneExportOutput(json_path=json_path)
|
||||
logger.info("✅ %s fertig: %r", self.worker_name, result)
|
||||
return result
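# Hedged usage sketch (field values are hypothetical; how the worker is actually
# scheduled depends on the librarian_core / Prefect wiring outside this module):
#
#     payload = TsneExportInput(course_id=42, limit=500, perplexity=30.0)
#     # Running TsneExportWorker with this payload yields a TsneExportOutput whose
#     # json_path points at the staged export, e.g. "<course_id>_tsne.json" when
#     # base_output_dir is set.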
|
@ -0,0 +1,104 @@
|
||||
|
||||
"""Utility functions to fetch vectors from Supabase, apply t‑SNE, add simple K‑means
|
||||
clustering and hover text – prepared exactly like the `VectorVisualizer` expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.cluster import KMeans
|
||||
|
||||
from librarian_vspace.vutils.vector_query_loader import VectorQueryLoader, VectorQueryLoaderError
|
||||
from librarian_vspace.models.tsne_model import TSNEPoint, TSNEData
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_N_CLUSTERS = 8
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Internal helpers (kept minimal – no extra bells & whistles)
|
||||
# --------------------------------------------------------------------- #
|
||||
def _run_kmeans(df: pd.DataFrame, *, embedding_column: str, k: int = DEFAULT_N_CLUSTERS) -> pd.DataFrame:
|
||||
"""Adds a 'cluster' column using K‑means (string labels)."""
|
||||
if df.empty or embedding_column not in df.columns:
|
||||
df['cluster'] = "-1"
|
||||
return df
|
||||
|
||||
embeddings = np.array(df[embedding_column].tolist(), dtype=float)
|
||||
n_samples = embeddings.shape[0]
|
||||
k = max(1, min(k, n_samples)) # ensure 1 ≤ k ≤ n_samples
|
||||
if n_samples < 2:
|
||||
df['cluster'] = "0"
|
||||
return df
|
||||
|
||||
km = KMeans(n_clusters=k, random_state=42, n_init='auto')
|
||||
df['cluster'] = km.fit_predict(embeddings).astype(str)
|
||||
return df
|
||||
|
||||
|
||||
def _add_hover(df: pd.DataFrame) -> pd.DataFrame:
|
||||
if df.empty:
|
||||
return df
|
||||
|
||||
df = df.copy()
|
||||
|
||||
def _hover(row):
|
||||
preview = str(row.get('chunk', ''))[:200]
|
||||
if len(str(row.get('chunk', ''))) > 200:
|
||||
preview += "..."
|
||||
return (
|
||||
f"ID: {row.get('file_id', 'N/A')}<br>"
|
||||
f"Cluster: {row.get('cluster', 'N/A')}<br>"
|
||||
f"Chunk: {preview}"
|
||||
)
|
||||
|
||||
df['hover_text'] = df.apply(_hover, axis=1)
|
||||
return df
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Public helpers
|
||||
# --------------------------------------------------------------------- #
|
||||
def get_tsne_dataframe(
|
||||
db_schema: str,
|
||||
db_function: str,
|
||||
model_name: str,
|
||||
*,
|
||||
limit: Optional[int] = None,
|
||||
course_id: Optional[int] = None,
|
||||
perplexity: float = 30.0,
|
||||
embedding_column: str = "embedding",
|
||||
n_clusters: int = DEFAULT_N_CLUSTERS,
|
||||
) -> pd.DataFrame:
|
||||
"""Returns a pandas DataFrame with tsne (x,y,z) & metadata ready for plotting."""
|
||||
loader = VectorQueryLoader(db_schema, db_function, model_name, embedding_column)
|
||||
df = loader.load_and_reduce(
|
||||
limit=limit,
|
||||
course_id=course_id,
|
||||
tsne_params={"perplexity": perplexity},
|
||||
)
|
||||
|
||||
if df.empty:
|
||||
return df
|
||||
|
||||
df = _run_kmeans(df, embedding_column=embedding_column, k=n_clusters)
|
||||
df = _add_hover(df)
|
||||
return df
|
||||
|
||||
|
||||
def get_tsne_json(**kwargs) -> str:
|
||||
"""Convenience wrapper returning DataFrame as JSON (orient='split')."""
|
||||
df = get_tsne_dataframe(**kwargs)
|
||||
return df.to_json(date_format='iso', orient='split')
|
||||
|
||||
|
||||
def get_tsne_response(**kwargs) -> TSNEData:
|
||||
"""Returns a validated `TSNEResponse` Pydantic model."""
|
||||
df = get_tsne_dataframe(**kwargs)
|
||||
points: List[TSNEPoint] = [TSNEPoint(**row.dropna().to_dict()) for _, row in df.iterrows()]
|
||||
return TSNEData(course_id=kwargs.get('course_id'), total=len(points), points=points)
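# Hedged usage sketch (illustrative; course_id/limit values are hypothetical and
# Supabase credentials are assumed to be configured for VectorQueryLoader):
#
#     df = get_tsne_dataframe(
#         "librarian", "pdf_chunking", "snowflake-arctic-embed2",
#         course_id=123, limit=1000, perplexity=30.0, n_clusters=8,
#     )
#     payload_json = get_tsne_json(
#         db_schema="librarian", db_function="pdf_chunking",
#         model_name="snowflake-arctic-embed2", course_id=123,
#     )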
|
@ -0,0 +1,10 @@
|
||||
"""
|
||||
vquery package for high-level read operations against vector tables.
|
||||
"""
|
||||
from .query import VectorQuery
|
||||
|
||||
__all__ = ["VectorQuery"] # Defines the public interface of the package
|
||||
|
||||
# Optional: Add package-level logging setup if desired, but often handled by the application
|
||||
# import logging
|
||||
# logging.getLogger(__name__).addHandler(logging.NullHandler())
|
@ -0,0 +1,438 @@
|
||||
# --- START OF FILE cluster_export.py (Refactored & Workaround - Import Updated) ---
|
||||
|
||||
"""
|
||||
cluster_export.py – Generate IVFFlat‑equivalent clusters from Supabase/Vectorbase
|
||||
pgvector data and export each cluster’s chunks to Markdown.
|
||||
|
||||
This version fetches vectors filtered by course ID at the database level using
|
||||
VectorQueryLoader, performs k-means clustering, and exports to Markdown.
|
||||
|
||||
Includes automatic k-downsizing.
|
||||
|
||||
Environment variables (used by the script entry point)
|
||||
---------------------
|
||||
* **Vectorbase credentials** (auto‑mapped to Supabase):
|
||||
* `VECTORBASE_URL` → `SUPABASE_URL`
|
||||
* `VECTORBASE_API_KEY` → `SUPABASE_KEY`
|
||||
* `VECTORBASE_USER_UUID` → `SUPABASE_USER_UUID` (optional)
|
||||
* **Embedding/table config**
|
||||
* `VECTOR_SCHEMA` – Postgres schema (default `librarian`)
|
||||
* `VECTOR_FUNCTION` – RPC / Postgres function name (optional)
|
||||
* `EMBED_MODEL` – embedding model label (default `snowflake-arctic-embed2`)
|
||||
* **Clustering hyper‑parameters**
|
||||
* `K` – requested number of clusters / IVFFlat *nlist* (default 128)
|
||||
* `TRAIN_SAMPLE` – how many rows to feed into k‑means (default 20 000, but
|
||||
capped at the table size)
|
||||
* **Export**
|
||||
* `OUTPUT_DIR` – directory for the generated Markdown files (default
|
||||
`./cluster_md`)
|
||||
* `CLUSTER_COURSE_ID` - Optional course ID to filter vectors (used by script)
|
||||
|
||||
|
||||
Usage
|
||||
~~~~~
|
||||
# Via script entry point
|
||||
export VECTORBASE_URL="https://xyz.vectorbase.co"
|
||||
export VECTORBASE_API_KEY="service_role_key"
|
||||
export VECTOR_SCHEMA=librarian
|
||||
export EMBED_MODEL=snowflake-arctic-embed2
|
||||
export CLUSTER_COURSE_ID=123 # Optional filtering
|
||||
export K=64
|
||||
python -m librarian_vspace.vquery.cluster_export
|
||||
|
||||
# As a callable function
|
||||
from librarian_vspace.vquery.cluster_export import run_cluster_export_job
|
||||
output_path = run_cluster_export_job(course_id=456, output_dir="/tmp/clusters_456", ...)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any, Union # Added Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn.metrics import pairwise_distances_argmin_min
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Map Vectorbase credential names → Supabase names expected by loader code
|
||||
# ---------------------------------------------------------------------------
|
||||
_ALIAS_ENV_MAP = {
|
||||
"VECTORBASE_URL": "SUPABASE_URL",
|
||||
"VECTORBASE_API_KEY": "SUPABASE_KEY",
|
||||
"VECTORBASE_USER_UUID": "SUPABASE_USER_UUID", # optional
|
||||
}
|
||||
for src, dest in _ALIAS_ENV_MAP.items():
|
||||
if dest not in os.environ and src in os.environ:
|
||||
os.environ[dest] = os.environ[src]
|
||||
|
||||
# Import the NEW data‑loading helper with filtering capabilities
|
||||
try:
|
||||
# --- FIX: Import VectorQueryLoader from vutils ---
|
||||
from librarian_vspace.vutils.vector_query_loader import VectorQueryLoader, VectorQueryLoaderError
|
||||
# VectorLoaderError is now VectorQueryLoaderError
|
||||
# --- END FIX ---
|
||||
except ImportError as e:
|
||||
# Keep the original script's error handling for standalone use
|
||||
sys.stderr.write(
|
||||
"\n[ERROR] Could not import VectorQueryLoader – check PYTHONPATH. "
|
||||
f"Original error: {e}\n"
|
||||
)
|
||||
# For callable use, we should raise an ImportError or custom exception
|
||||
raise ImportError(f"Could not import VectorQueryLoader: {e}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Logging setup (used by both script and callable function)
|
||||
# ---------------------------------------------------------------------------
|
||||
# This basicConfig runs when the module is imported.
|
||||
# Callers might want to configure logging before importing.
|
||||
# If logging is already configured, basicConfig does nothing.
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
logger = logging.getLogger(__name__) # Use __name__ for module-specific logger
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper – JSON dump for centroid in YAML front‑matter
|
||||
# ---------------------------------------------------------------------------
|
||||
def centroid_to_json(vec: np.ndarray) -> str:
|
||||
"""Converts a numpy vector to a JSON string suitable for YAML frontmatter."""
|
||||
return json.dumps([float(x) for x in vec], ensure_ascii=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main clustering and export logic as a callable function
|
||||
# ---------------------------------------------------------------------------
|
||||
def run_cluster_export_job(
|
||||
course_id: Optional[int] = None, # Added course_id parameter
|
||||
output_dir: Union[str, Path] = "./cluster_md", # Output directory parameter
|
||||
schema: str = "librarian",
|
||||
rpc_function: str = "pdf_chunking", # Default to actual function name
|
||||
model: str = "snowflake-arctic-embed2",
|
||||
k_clusters: int = 128, # Requested number of clusters (k)
|
||||
train_sample_size: int = 20000, # Sample size for K-means training
|
||||
embedding_column: str = "embedding" # Added embedding column parameter
|
||||
) -> Path:
|
||||
"""
|
||||
Fetches vectors, performs K-means clustering, and exports clustered chunks to Markdown.
|
||||
|
||||
Args:
|
||||
course_id: Optional ID to filter vectors belonging to a specific course.
|
||||
output_dir: Directory path where the cluster Markdown files will be saved.
|
||||
schema: Postgres schema containing the vector table.
|
||||
rpc_function: Optional RPC function name used by VectorQueryLoader (needed for table lookup).
|
||||
model: Embedding model label used by VectorQueryLoader (needed for table lookup).
|
||||
k_clusters: The requested number of clusters (k). Will be downsized if fewer
|
||||
vectors are available.
|
||||
train_sample_size: The maximum number of vectors to use for K-means training.
|
||||
Capped by the total number of vectors fetched.
|
||||
embedding_column: The name of the column containing the vector embeddings.
|
||||
|
||||
Returns:
|
||||
The absolute path to the output directory.
|
||||
|
||||
Raises:
|
||||
VectorQueryLoaderError: If vector loading fails.
|
||||
RuntimeError: If no embeddings are retrieved or training sample is empty after filtering.
|
||||
Exception: For other errors during clustering or export.
|
||||
"""
|
||||
output_path = Path(output_dir).expanduser().resolve() # Resolve path early
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Writing Markdown files to %s", output_path)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fetch embeddings - Now using VectorQueryLoader with filtering
|
||||
# ---------------------------------------------------------------------------
|
||||
try:
|
||||
# Use parameters for loader config
|
||||
# --- FIX: Instantiate VectorQueryLoader ---
|
||||
loader = VectorQueryLoader(schema=schema, function=rpc_function, model=model, embedding_column=embedding_column)
|
||||
# --- END FIX ---
|
||||
|
||||
# --- FIX: Call fetch_vectors WITH the course_id argument ---
|
||||
# VectorQueryLoader.fetch_vectors handles the DB-level filtering
|
||||
df = loader.fetch_vectors(limit=None, course_id=course_id)
|
||||
# --- END FIX ---
|
||||
|
||||
# --- REMOVE: In-memory filtering logic is no longer needed ---
|
||||
# initial_rows = len(df)
|
||||
# if course_id is not None and not df.empty:
|
||||
# ... (removed filtering code) ...
|
||||
# elif course_id is not None and df.empty:
|
||||
# ... (removed warning) ...
|
||||
# --- END REMOVE ---
|
||||
|
||||
# --- FIX: Catch VectorQueryLoaderError ---
|
||||
except VectorQueryLoaderError as e:
|
||||
logger.error("Vector loading failed: %s", e)
|
||||
raise e # Re-raise the specific exception for the caller
|
||||
# --- END FIX ---
|
||||
except Exception as e:
|
||||
# Catch other unexpected errors during loading
|
||||
logger.exception("An unexpected error occurred during vector loading.")
|
||||
raise RuntimeError(f"An unexpected error occurred during vector loading: {e}") from e
|
||||
|
||||
|
||||
# --- Check if DataFrame is empty *after* fetching (which includes DB filtering) ---
|
||||
if df.empty:
|
||||
logger.error("No embeddings retrieved or found for course_id %s – aborting.", course_id)
|
||||
# Raise a RuntimeError as no clustering can be done
|
||||
raise RuntimeError(f"No embeddings retrieved or found for course_id {course_id} – nothing to cluster.")
|
||||
# ----------------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Use the actual embedding column name from the loader instance
|
||||
# This check is crucial *after* fetching
|
||||
if not hasattr(loader, 'embedding_column') or loader.embedding_column not in df.columns:
|
||||
# This should ideally be caught by VectorQueryLoader's internal checks, but double-check
|
||||
logger.error("Embedding column '%s' not found in fetched data.", embedding_column) # Use the input param name for error msg
|
||||
raise RuntimeError(f"Embedding column '{embedding_column}' not found in fetched data.")
|
||||
|
||||
|
||||
# --- Ensure embeddings are numeric lists before stacking ---
|
||||
# The VectorQueryLoader.fetch_vectors method now handles parsing and dropping invalid rows.
|
||||
# We just need to safely stack the potentially filtered/cleaned data.
|
||||
try:
|
||||
# Ensure data is list of floats before stacking
|
||||
# This check might be redundant if VectorQueryLoader guarantees cleaned data,
|
||||
# but it adds safety.
|
||||
if not all(isinstance(x, list) and all(isinstance(n, float) for n in x) for x in df[embedding_column]):
|
||||
logger.error(f"Data in '{embedding_column}' is not strictly list[float] format after fetching. Attempting conversion.")
|
||||
# This might catch issues the loader missed or unexpected data structures
|
||||
try:
|
||||
# Attempt robust conversion similar to the loader's parse method
|
||||
embeddings_list = []
|
||||
for item in df[embedding_column]:
|
||||
parsed_item = None
|
||||
if isinstance(item, str):
|
||||
try: parsed_item = json.loads(item)
|
||||
except json.JSONDecodeError: pass
|
||||
elif isinstance(item, (list, tuple, np.ndarray)):
|
||||
parsed_item = item
|
||||
elif isinstance(item, dict) and 'vector' in item and isinstance(item['vector'], (list, tuple, np.ndarray)):
|
||||
parsed_item = item['vector']
|
||||
|
||||
if isinstance(parsed_item, (list, tuple, np.ndarray)) and all(isinstance(val, (int, float, np.number)) for val in parsed_item):
|
||||
embeddings_list.append([float(n) for n in parsed_item])
|
||||
else:
|
||||
logger.debug(f"Skipping problematic embedding during secondary clean: {str(item)[:100]}...")
|
||||
|
||||
|
||||
if not embeddings_list:
|
||||
logger.error("No valid embeddings remained after secondary cleaning.")
|
||||
raise ValueError("No valid embeddings for stacking.")
|
||||
embeddings = np.array(embeddings_list, dtype=float)
|
||||
logger.warning("Successfully converted problematic embedding data for stacking.")
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed secondary attempt to convert embeddings for stacking: {e}")
|
||||
raise RuntimeError(f"Failed to process embedding data for stacking: {e}") from e
|
||||
else:
|
||||
# Data is in the expected list of float format, proceed directly
|
||||
embeddings = np.stack(df[embedding_column].to_list()).astype(float)
|
||||
|
||||
logger.info("Prepared %d embeddings for clustering.", embeddings.shape[0])
|
||||
|
||||
|
||||
except ValueError as ve:
|
||||
logger.exception(f"Failed to stack embeddings into a numpy array: {ve}. Ensure '{embedding_column}' contains valid vector data.")
|
||||
raise RuntimeError(f"Failed to process embedding data: {ve}") from ve
|
||||
except Exception as e:
|
||||
logger.exception(f"An unexpected error occurred while processing '{embedding_column}' column for stacking.")
|
||||
raise RuntimeError(f"An unexpected error occurred while processing embedding data for stacking: {e}") from e
|
||||
# -------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prepare training sample and determine effective k
|
||||
# ---------------------------------------------------------------------------
|
||||
# Use the parameter train_sample_size
|
||||
train_vecs = embeddings[:train_sample_size]
|
||||
|
||||
if train_vecs.shape[0] == 0:
|
||||
# If course_id filtering resulted in 0 vectors, this check prevents the crash
|
||||
# but the df.empty check earlier should already handle this.
|
||||
# Keep this check for robustness in case train_sample_size is 0 or negative.
|
||||
logger.error("Training sample is empty – nothing to cluster.")
|
||||
raise RuntimeError("Training sample is empty – nothing to cluster.")
|
||||
|
||||
# Use the parameter k_clusters
|
||||
K = min(k_clusters, train_vecs.shape[0])
|
||||
if K < k_clusters:
|
||||
logger.warning(
|
||||
"Requested k=%d but only %d training vectors available; "
|
||||
"using k=%d.",
|
||||
k_clusters,
|
||||
train_vecs.shape[0],
|
||||
K,
|
||||
)
|
||||
# Ensure K is at least 1 if there's any data
|
||||
if K == 0 and train_vecs.shape[0] > 0:
|
||||
K = 1
|
||||
logger.warning("Adjusted k to 1 as requested k resulted in 0 but data exists.")
|
||||
|
||||
if K == 0:
|
||||
# If after adjustments K is still 0 (meaning train_vecs.shape[0] was 0)
|
||||
logger.error("Effective k is 0. Cannot train k-means.")
|
||||
raise RuntimeError("Effective k is 0. Cannot train k-means.")
|
||||
|
||||
|
||||
logger.info("Training k‑means (k=%d) on %d vectors", K, train_vecs.shape[0])
|
||||
|
||||
try:
|
||||
kmeans = KMeans(
|
||||
n_clusters=K,
|
||||
init="k-means++",
|
||||
n_init="auto", # Use 'auto' for better handling of small k/n_samples
|
||||
algorithm="lloyd", # 'lloyd' is the standard, 'elkan' can be faster but has limitations
|
||||
max_iter=300,
|
||||
random_state=0,
|
||||
)
|
||||
kmeans.fit(train_vecs)
|
||||
centroids: np.ndarray = kmeans.cluster_centers_
|
||||
logger.info("K‑means converged in %d iterations", kmeans.n_iter_)
|
||||
except Exception as e:
|
||||
logger.exception("K-means clustering failed.")
|
||||
raise RuntimeError(f"K-means clustering failed: {e}") from e
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assign every vector to its nearest centroid (full table)
|
||||
# ---------------------------------------------------------------------------
|
||||
logger.info("Assigning vectors to centroids...")
|
||||
try:
|
||||
# Use the determined embedding column for assignment as well
|
||||
labels_full, _ = pairwise_distances_argmin_min(embeddings, centroids, metric="euclidean")
|
||||
df["cluster_id"] = labels_full
|
||||
logger.info("Assigned cluster labels to all embeddings.")
|
||||
except Exception as e:
|
||||
logger.exception("Failed to assign vectors to centroids.")
|
||||
raise RuntimeError(f"Failed to assign vectors to centroids: {e}") from e
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Write one Markdown file per cluster
|
||||
# ---------------------------------------------------------------------------
|
||||
files_written_count = 0
|
||||
|
||||
logger.info("Writing cluster Markdown files to %s", output_path)
|
||||
try:
|
||||
# Only iterate up to the number of actual clusters found by KMeans
|
||||
# KMeans might return fewer clusters than K if there are issues or identical points
|
||||
num_actual_clusters = len(centroids)
|
||||
if num_actual_clusters < K:
|
||||
logger.warning(f"KMeans returned only {num_actual_clusters} centroids, expected {K}. Iterating over actual centroids.")
|
||||
|
||||
|
||||
for cid in range(num_actual_clusters): # Iterate over actual cluster IDs
|
||||
# Find all data points assigned to this cluster ID
|
||||
subset = df[df.cluster_id == cid]
|
||||
|
||||
# Ensure centroid_vec corresponds to the centroid of the *current* cluster ID (cid)
|
||||
# This check is more robust now iterating up to num_actual_clusters
|
||||
if cid < len(centroids):
|
||||
centroid_vec = centroids[cid]
|
||||
else:
|
||||
# This case should theoretically not be reached with the loop range
|
||||
logger.error(f"Centroid for cluster ID {cid} missing! Using zero vector.")
|
||||
centroid_vec = np.zeros(embeddings.shape[1])
|
||||
|
||||
|
||||
# Use .get() and .fillna("") defensively in case 'chunk' column is missing
|
||||
# Ensure chunk column exists - it should if SELECT * worked
|
||||
if 'chunk' not in subset.columns:
|
||||
logger.warning("'chunk' column missing in subset data for cluster %d. Using empty strings.", cid)
|
||||
chunks = [""] * len(subset)
|
||||
else:
|
||||
chunks = subset['chunk'].fillna("").tolist()
|
||||
|
||||
|
||||
md_lines = [
|
||||
#"---",
|
||||
#f"cluster_id: {cid}",
|
||||
#f"centroid: {centroid_to_json(centroid_vec)}",
|
||||
#"---\n", # Separator between frontmatter and content
|
||||
]
|
||||
# Add chunks, ensuring each chunk is on a new line or separated by blank lines
|
||||
md_lines.extend(chunks)
|
||||
|
||||
outfile = output_path / f"cluster_{cid:03d}.md"
|
||||
# Use a different separator for chunks within the file if needed,
|
||||
# currently just joins with newline, but chunks might contain newlines.
|
||||
# Joining with "\n\n" provides separation *between* chunks.
|
||||
try:
|
||||
outfile.write_text("\n\n".join(md_lines), encoding="utf-8")
|
||||
files_written_count += 1
|
||||
logger.debug("Wrote %s (%d chunks)", outfile.name, len(chunks)) # Use debug for per-file
|
||||
except Exception as write_exc:
|
||||
logger.error(f"Failed to write cluster file {outfile}: {write_exc}", exc_info=True)
|
||||
# Decide whether to continue or raise here. Continuing allows other clusters to be saved.
|
||||
# For robustness in script, maybe continue. For library function, maybe raise.
|
||||
# For now, we'll just log and continue.
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed during Markdown file writing loop.")
|
||||
raise RuntimeError(f"Failed during Markdown file writing: {e}") from e
|
||||
|
||||
|
||||
logger.info("Done. %d Markdown files created in %s", files_written_count, output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Script entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
if __name__ == "__main__":
|
||||
# Configuration via environment for script
|
||||
script_output_dir = Path(os.environ.get("OUTPUT_DIR", "./cluster_md")).expanduser()
|
||||
script_schema = os.environ.get("VECTOR_SCHEMA", "librarian")
|
||||
script_rpc_function = os.environ.get("VECTOR_FUNCTION", "pdf_chunking") # Default to actual function name
|
||||
script_model = os.environ.get("EMBED_MODEL", "snowflake-arctic-embed2")
|
||||
script_k_req = int(os.environ.get("K", "128"))
|
||||
script_train_sample = int(os.environ.get("TRAIN_SAMPLE", "20000"))
|
||||
# Added course ID specific to script entry point
|
||||
script_course_id_str = os.environ.get("CLUSTER_COURSE_ID")
|
||||
script_course_id = int(script_course_id_str) if script_course_id_str and script_course_id_str.isdigit() else None # Added isdigit check
|
||||
|
||||
|
||||
# Configure basic logging for the script entry point
|
||||
# (The module-level config above might not run if imported in specific ways)
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
# Re-get the logger after basicConfig to ensure it's configured
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
try:
|
||||
logger.info("Starting cluster export script...")
|
||||
final_output_path = run_cluster_export_job(
|
||||
course_id=script_course_id,
|
||||
output_dir=script_output_dir,
|
||||
schema=script_schema,
|
||||
rpc_function=script_rpc_function,
|
||||
model=script_model,
|
||||
k_clusters=script_k_req,
|
||||
train_sample_size=script_train_sample,
|
||||
# embedding_column defaults to 'embedding' in the function
|
||||
)
|
||||
logger.info("Script finished successfully. Output in %s", final_output_path)
|
||||
sys.exit(0) # Explicit success exit
|
||||
# --- FIX: Catch VectorQueryLoaderError ---
|
||||
except (VectorQueryLoaderError, RuntimeError) as e: # Catch the new error type
|
||||
# --- END FIX ---
|
||||
# Specific errors we raised
|
||||
logger.error("Script failed: %s", e)
|
||||
sys.exit(1) # Indicate failure
|
||||
except Exception as e:
|
||||
# Catch any other unexpected errors
|
||||
logger.exception("An unhandled error occurred during script execution.")
|
||||
sys.exit(1) # Indicate failure
|
@ -0,0 +1,73 @@
|
||||
|
||||
"""ClusterExportWorker – Prefect worker that wraps run_cluster_export_job."""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from prefect import get_run_logger
|
||||
from pydantic import BaseModel
|
||||
from librarian_core.workers.base import Worker
|
||||
|
||||
from librarian_vspace.vquery.cluster_export import run_cluster_export_job
|
||||
|
||||
def _safe_get_logger(name: str):
|
||||
try:
|
||||
return get_run_logger()
|
||||
except Exception:
|
||||
return logging.getLogger(name)
|
||||
|
||||
class ClusterExportInput(BaseModel):
|
||||
course_id: int
|
||||
k_clusters: int = 128
|
||||
train_sample_size: int = 20_000
|
||||
db_schema: str = "librarian"
|
||||
rpc_function: str = "pdf_chunking"
|
||||
model: str = "snowflake-arctic-embed2"
|
||||
embedding_column: str = "embedding"
|
||||
base_output_dir: Optional[Path] = None
|
||||
|
||||
class ClusterExportOutput(BaseModel):
|
||||
output_dir: Path
|
||||
|
||||
class ClusterExportWorker(Worker[ClusterExportInput, ClusterExportOutput]):
|
||||
input_model = ClusterExportInput
|
||||
output_model = ClusterExportOutput
|
||||
|
||||
async def __run__(self, payload: ClusterExportInput) -> ClusterExportOutput:
|
||||
logger = _safe_get_logger(self.worker_name)
|
||||
logger.info("🔨 %s startet (payload=%r)", self.worker_name, payload)
|
||||
|
||||
# Prepare output directory
|
||||
if payload.base_output_dir:
|
||||
base_dir = Path(payload.base_output_dir).expanduser()
|
||||
base_dir.mkdir(parents=True, exist_ok=True)
|
||||
tmp_base = tempfile.mkdtemp(dir=base_dir)
|
||||
else:
|
||||
tmp_base = tempfile.mkdtemp()
|
||||
|
||||
output_dir = Path(tmp_base) / str(payload.course_id)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.debug("Output directory: %s", output_dir)
|
||||
|
||||
final_dir = await asyncio.to_thread(
|
||||
run_cluster_export_job,
|
||||
course_id=payload.course_id,
|
||||
output_dir=output_dir,
|
||||
schema=payload.db_schema,
|
||||
rpc_function=payload.rpc_function,
|
||||
model=payload.model,
|
||||
k_clusters=payload.k_clusters,
|
||||
train_sample_size=payload.train_sample_size,
|
||||
embedding_column=payload.embedding_column,
|
||||
)
|
||||
|
||||
self.stage(final_dir, new_name=final_dir.name)
|
||||
|
||||
result = ClusterExportOutput(output_dir=final_dir)
|
||||
logger.info("✅ %s fertig: %r", self.worker_name, result)
|
||||
return result
|
@ -0,0 +1,127 @@
|
||||
|
||||
"""VectorQuery – helper for vector searches against chunklet tables.
|
||||
|
||||
This module provides:
|
||||
* A Pydantic‑powered request / response API (see ``librarian_vspace.models.query_model``).
|
||||
* A single public method :py:meth:`VectorQuery.search` that returns a
|
||||
:class:`~librarian_vspace.models.query_model.VectorSearchResponse`.
|
||||
* A thin legacy wrapper ``get_chucklets_by_vector`` that produces the
|
||||
historical ``List[Dict[str, Any]]`` format, built on top of ``search``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from librarian_vspace.vutils.vector_class import BaseVectorOperator
|
||||
from librarian_vspace.vecembed.embedding_generator import EmbeddingGenerator
|
||||
except ImportError as exc: # pragma: no cover
|
||||
logging.error(
|
||||
"Failed to import vutils or vecembed sub‑packages: %s. " "Ensure they are on PYTHONPATH.", exc
|
||||
)
|
||||
|
||||
class BaseVectorOperator: # type: ignore
|
||||
"""Minimal stub if real class is unavailable (runtime error later)."""
|
||||
|
||||
class EmbeddingGenerator: # type: ignore
|
||||
"""Minimal stub; will raise at runtime if used."""
|
||||
|
||||
from librarian_vspace.models.query_model import (
|
||||
VectorSearchRequest,
|
||||
VectorSearchResponse,
|
||||
Chunklet,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------- #
|
||||
# Main helper
|
||||
# --------------------------------------------------------------------- #
|
||||
class VectorQuery(BaseVectorOperator):
|
||||
"""High‑level helper for vector searches via Supabase RPC."""
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Public – modern API
|
||||
# -----------------------------------------------------------------
|
||||
def search(self, request: VectorSearchRequest) -> VectorSearchResponse:
|
||||
"""Perform a similarity search and return structured results."""
|
||||
|
||||
if not getattr(self, "table", None):
|
||||
logger.error("VectorQuery: target table not determined (self.table is None).")
|
||||
return VectorSearchResponse(total=0, results=[])
|
||||
|
||||
# 1) Generate query embedding
|
||||
try:
|
||||
_tts, query_vec, _ = EmbeddingGenerator().generate_embedding(
|
||||
interface_name=request.interface_name,
|
||||
model_name=request.model_name,
|
||||
text_to_embed=request.search_string,
|
||||
identifier="query",
|
||||
)
|
||||
if query_vec is None:
|
||||
logger.error("Embedding generation returned None.")
|
||||
return VectorSearchResponse(total=0, results=[])
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("Embedding generation failed: %s", exc)
|
||||
return VectorSearchResponse(total=0, results=[])
|
||||
|
||||
# 2) Build RPC parameters
|
||||
rpc_params = {
|
||||
"p_query_embedding": query_vec,
|
||||
"p_target_table": self.table,
|
||||
"p_embedding_column": request.embedding_column,
|
||||
"p_match_count": request.top_k,
|
||||
"p_filters": request.filters or {},
|
||||
}
|
||||
|
||||
# 3) Execute RPC
|
||||
try:
|
||||
if not getattr(self, "spc", None):
|
||||
logger.error("Supabase client (self.spc) not available.")
|
||||
return VectorSearchResponse(total=0, results=[])
|
||||
|
||||
resp = (
|
||||
self.spc
|
||||
.schema(self.schema)
|
||||
.rpc("vector_search", rpc_params)
|
||||
.execute()
|
||||
)
|
||||
data = resp.data or []
|
||||
results = [
|
||||
Chunklet(chunk=row.get("chunk"), file_id=row.get("file_id")) if isinstance(row.get("file_id"), str) else Chunklet(chunk=row.get("chunk"), file_id=str(row.get("file_id")))
|
||||
for row in data
|
||||
]
|
||||
return VectorSearchResponse(total=len(results), results=results)
|
||||
|
||||
except Exception as exc: # pragma: no cover
|
||||
logger.exception("RPC 'vector_search' failed: %s", exc)
|
||||
return VectorSearchResponse(total=0, results=[])
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Public – legacy compatibility
|
||||
# -----------------------------------------------------------------
|
||||
def get_chucklets_by_vector(
|
||||
self,
|
||||
*,
|
||||
interface_name: str,
|
||||
model_name: str,
|
||||
search_string: str,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
top_k: int = 10,
|
||||
embedding_column: str = "embedding",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Backward‑compatible wrapper returning ``{'chunk', 'file_id'}`` dicts."""
|
||||
|
||||
req = VectorSearchRequest(
|
||||
interface_name=interface_name,
|
||||
model_name=model_name,
|
||||
search_string=search_string,
|
||||
filters=filters,
|
||||
top_k=top_k,
|
||||
embedding_column=embedding_column,
|
||||
)
|
||||
resp = self.search(req)
|
||||
return [ck.dict() for ck in resp.results]
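# Hedged usage sketch (illustrative; the constructor keywords mirror how QueryWorker
# builds VectorQuery elsewhere in this repo, and the search values are hypothetical):
#
#     vq = VectorQuery(schema="librarian", function="pdf_chunking",
#                      model="snowflake-arctic-embed2", embedding_column="embedding")
#     req = VectorSearchRequest(
#         interface_name="ollama", model_name="snowflake-arctic-embed2",
#         search_string="gradient descent", top_k=5, embedding_column="embedding",
#     )
#     resp = vq.search(req)  # VectorSearchResponse(total=..., results=[Chunklet, ...])
#     legacy = vq.get_chucklets_by_vector(
#         interface_name="ollama", model_name="snowflake-arctic-embed2",
#         search_string="gradient descent", top_k=5,
#     )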
|
@ -0,0 +1,62 @@
|
||||
|
||||
"""QueryWorker – Prefect worker that performs a vector search.
|
||||
|
||||
It instantiates VectorQuery directly (no vspace dependency) and returns the
|
||||
VectorSearchResponse.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from prefect import get_run_logger
|
||||
from pydantic import BaseModel
|
||||
from librarian_core.workers.base import Worker
|
||||
|
||||
from librarian_vspace.vquery.query import VectorQuery
|
||||
from librarian_vspace.models.query_model import VectorSearchRequest, VectorSearchResponse
|
||||
|
||||
def _safe_get_logger(name: str):
|
||||
try:
|
||||
return get_run_logger()
|
||||
except Exception:
|
||||
return logging.getLogger(name)
|
||||
|
||||
|
||||
class QueryInput(BaseModel):
|
||||
request: VectorSearchRequest
|
||||
db_schema: str = "librarian"
|
||||
rpc_function: str = "pdf_chunking"
|
||||
embed_model: str = "snowflake-arctic-embed2"
|
||||
embedding_column: str = "embedding"
|
||||
|
||||
|
||||
class QueryWorker(Worker[QueryInput, VectorSearchResponse]):
|
||||
"""Runs a Supabase vector search via VectorQuery."""
|
||||
|
||||
input_model = QueryInput
|
||||
output_model = VectorSearchResponse
|
||||
|
||||
async def __run__(self, payload: QueryInput) -> VectorSearchResponse:
|
||||
logger = _safe_get_logger(self.worker_name)
|
||||
logger.info("🔨 %s startet (payload=%r)", self.worker_name, payload)
|
||||
|
||||
def _do_search() -> VectorSearchResponse:
|
||||
try:
|
||||
vq = VectorQuery(
|
||||
schema=payload.db_schema,
|
||||
function=payload.rpc_function,
|
||||
model=payload.embed_model,
|
||||
embedding_column=payload.embedding_column,
|
||||
)
|
||||
except TypeError:
|
||||
# fallback to positional signature
|
||||
vq = VectorQuery(payload.db_schema, payload.rpc_function, payload.embed_model)
|
||||
return vq.search(payload.request)
|
||||
|
||||
response = await asyncio.to_thread(_do_search)
|
||||
|
||||
logger.info("✅ %s fertig: %s results", self.worker_name, response.total)
|
||||
return response
|
@ -0,0 +1,2 @@
|
||||
def hello() -> str:
|
||||
return "Hello from librarian_vspace!"
|