67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
|
||
#!/usr/bin/env python3
|
||
"""examples/demo_run_query.py
|
||
|
||
Runs QueryWorker via FlowArtifact wrapper (mirrors cluster export demo).
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
from pathlib import Path
|
||
|
||
from librarian_vspace.vquery.query_worker import QueryWorker, QueryInput
|
||
from librarian_vspace.models.query_model import VectorSearchRequest
|
||
from librarian_core.workers.base import FlowArtifact
|
||
|
||
# ------------------------------------------------------------------ #
# Config
# ------------------------------------------------------------------ #
# Free-text query that is embedded and sent to the vector search.
SEARCH_STRING: str = "integration"
# Threshold paired with the "gt" operator on the ``file_id`` filter; adjust
# if needed.
# NOTE(review): the name says COURSE_* but it is applied to file_id —
# confirm which column is intended.
COURSE_FILTER_GT: int = 900

# Module-level logger; handlers/format are configured only when run as a script.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
def _load_env(path: Path) -> None:
|
||
if not path.is_file():
|
||
return
|
||
for line in path.read_text().splitlines():
|
||
if line.strip() and not line.startswith("#") and "=" in line:
|
||
k, v = [p.strip() for p in line.split("=", 1)]
|
||
os.environ.setdefault(k, v)
|
||
|
||
# ------------------------------------------------------------------ #
|
||
async def _main() -> None:
    """Run one vector search through ``QueryWorker`` and log the hits.

    The query text and the ``file_id`` filter threshold come from the
    module-level ``SEARCH_STRING`` / ``COURSE_FILTER_GT`` constants; backend
    settings come from environment variables:

        EMBED_INTERFACE: embedding interface name (default ``"ollama"``).
        EMBED_MODEL: embedding model name, used for both the search request
            and the worker input (default ``"snowflake-arctic-embed2"``).
        VECTOR_SCHEMA: database schema (default ``"librarian"``).
        VECTOR_FUNCTION: RPC function name (default ``"pdf_chunking"``).
    """
    # Read the model name once so the request and the worker input can never
    # disagree about which embedding model is in use.
    embed_model = os.getenv("EMBED_MODEL", "snowflake-arctic-embed2")

    # Vector search request
    vs_req = VectorSearchRequest(
        interface_name=os.getenv("EMBED_INTERFACE", "ollama"),
        model_name=embed_model,
        search_string=SEARCH_STRING,
        filters={"file_id": ("gt", COURSE_FILTER_GT)},
        top_k=10,
    )

    payload = QueryInput(
        request=vs_req,
        db_schema=os.getenv("VECTOR_SCHEMA", "librarian"),
        rpc_function=os.getenv("VECTOR_FUNCTION", "pdf_chunking"),
        embed_model=embed_model,
    )

    worker = QueryWorker()
    # NOTE(review): run_id is left empty for this demo — confirm FlowArtifact
    # accepts that outside of a real flow run.
    art = FlowArtifact.new(run_id="", dir=Path.cwd(), data=payload)
    result_artifact = await worker.flow()(art)  # FlowArtifact

    response = result_artifact.data  # VectorSearchResponse
    logger.info("✅ Worker finished – received %s results", response.total)
    for idx, ck in enumerate(response.results, 1):
        # A chunk may be None; treat it as empty text so slicing cannot raise.
        text = ck.chunk or ""
        preview = text[:80] + ("…" if len(text) > 80 else "")
        logger.info("• %s: %s", idx, preview)
|
||
|
||
# ------------------------------------------------------------------ #
|
||
if __name__ == "__main__":
    # Configure root logging for the demo output.
    logging.basicConfig(format="[%(levelname)s] %(message)s", level=logging.INFO)

    # Look for the .env next to this script rather than in the working
    # directory; already-set environment variables take precedence.
    APP_DIR = Path(__file__).resolve().parents[0]
    _env_file = APP_DIR / ".env"
    _load_env(_env_file)

    asyncio.run(_main())
|