2025-05-24 12:15:48 +02:00

67 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""examples/demo_run_query.py
Runs QueryWorker via FlowArtifact wrapper (mirrors cluster export demo).
"""
from __future__ import annotations
import asyncio
import logging
import os
from pathlib import Path
from librarian_vspace.vquery.query_worker import QueryWorker, QueryInput
from librarian_vspace.models.query_model import VectorSearchRequest
from librarian_core.workers.base import FlowArtifact
# ------------------------------------------------------------------ #
# Config
# ------------------------------------------------------------------ #
# Text that is embedded and used as the vector-search query.
SEARCH_STRING: str = "integration"
# Threshold passed as the ("gt", value) filter on "file_id" in _main; adjust if needed.
# NOTE(review): the name says "course" but the filter targets file_id — confirm intent.
COURSE_FILTER_GT: int = 900

logger = logging.getLogger(__name__)
def _load_env(path: Path) -> None:
    """Populate ``os.environ`` from a simple ``KEY=VALUE`` dotenv file.

    Variables that are already set win: values are applied with
    ``os.environ.setdefault``, so anything exported in the shell is never
    overwritten. A missing file is silently ignored.

    Accepted syntax, one assignment per line:
      * blank lines and lines whose first non-space character is ``#`` are skipped;
      * an optional leading ``export `` before the key is accepted;
      * the value is everything after the first ``=``, whitespace-stripped, with
        one pair of matching surrounding quotes (``"`` or ``'``) removed;
      * lines without ``=`` or with an empty key are ignored.

    Args:
        path: Location of the dotenv file.
    """
    if not path.is_file():
        return
    # Decode explicitly instead of depending on the platform's locale encoding.
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        # Test for comments on the stripped line so indented "# KEY=..." is skipped too.
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = (part.strip() for part in line.split("=", 1))
        if key.startswith("export "):
            key = key[len("export "):].strip()
        if not key:
            # An empty name is not a valid environment variable (os.environ may
            # raise ValueError), so skip the malformed line instead of crashing.
            continue
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        os.environ.setdefault(key, value)
# ------------------------------------------------------------------ #
async def _main() -> None:
    """Run one vector search through ``QueryWorker`` and log the hits.

    Settings come from environment variables with demo defaults:
    ``EMBED_INTERFACE`` (``"ollama"``), ``EMBED_MODEL``
    (``"snowflake-arctic-embed2"``), ``VECTOR_SCHEMA`` (``"librarian"``) and
    ``VECTOR_FUNCTION`` (``"pdf_chunking"``). The query text and the file-id
    filter come from the module constants ``SEARCH_STRING`` and
    ``COURSE_FILTER_GT``.

    Logs the result total, then an 80-character preview of each result chunk.
    """
    # Read the model name once so the request and the worker input always agree.
    embed_model = os.getenv("EMBED_MODEL", "snowflake-arctic-embed2")

    # Vector search request
    vs_req = VectorSearchRequest(
        interface_name=os.getenv("EMBED_INTERFACE", "ollama"),
        model_name=embed_model,
        search_string=SEARCH_STRING,
        # NOTE(review): ("gt", N) presumably means file_id > N — confirm in VectorSearchRequest.
        filters={"file_id": ("gt", COURSE_FILTER_GT)},
        top_k=10,
    )
    payload = QueryInput(
        request=vs_req,
        db_schema=os.getenv("VECTOR_SCHEMA", "librarian"),
        rpc_function=os.getenv("VECTOR_FUNCTION", "pdf_chunking"),
        embed_model=embed_model,
    )

    worker = QueryWorker()
    # Wrap the payload in an artifact rooted at the current working directory.
    art = FlowArtifact.new(run_id="", dir=Path.cwd(), data=payload)
    # worker.flow() yields an awaitable callable that is applied to the artifact.
    result_artifact = await worker.flow()(art)  # FlowArtifact
    response = result_artifact.data  # VectorSearchResponse

    logger.info("✅ Worker finished: received %s results", response.total)
    for idx, ck in enumerate(response.results, 1):
        # chunk may be None; treat it as empty so slicing cannot raise TypeError.
        text = ck.chunk or ""
        # Append an ellipsis only when the preview was actually truncated.
        preview = text[:80] + ("..." if len(text) > 80 else "")
        logger.info("%s: %s", idx, preview)
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Set up root logging before anything else emits records.
    logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)

    # The .env file is looked up next to this script, not in the working directory.
    APP_DIR = Path(__file__).resolve().parent
    env_file = APP_DIR / ".env"
    _load_env(env_file)

    asyncio.run(_main())