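"""PDF-to-Markdown pipeline with embeddings, t-SNE, clustering, and Gemini.

Stages: extract text from a PDF, split it into overlapping chunks, embed the
chunks with sentence-transformers, simulate pgvector storage/retrieval,
reduce the embeddings with t-SNE, visualize them in 3D with Plotly, cluster
them with K-Means, rewrite each cluster into Markdown (with LaTeX math) via
the Gemini API, and finally run an interactive semantic-search loop.
"""
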
import os
import re
import json
from typing import List, Dict, Tuple, Union, Optional, Any, Literal

# Used for timestamped output folders
from datetime import datetime

# Gemini SDK imports
from google import genai
from google.genai import types
import dotenv  # Load environment variables from .env file

import plotly.graph_objects as go
import pandas as pd

from sklearn.manifold import TSNE

# Clustering and similarity search
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

from sentence_transformers import SentenceTransformer


# --- Stage 1: PDF Processing ---
from pypdf import PdfReader


# Load environment variables from .env file
dotenv.load_dotenv()


def extract_text_from_pdf(pdf_path: str) -> str:
    """Extracts text content from a PDF file."""
    print(f"Processing PDF: {pdf_path}")
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    try:
        reader = PdfReader(pdf_path)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text:
                # Basic cleaning: replace multiple newlines/spaces
                cleaned_text = re.sub(r"\s+", " ", page_text).strip()
                text += cleaned_text + "\n"  # Add newline between pages
            print(f"  Extracted text from page {page_num + 1}")
        print(f"Finished extracting text. Total length: {len(text)} characters.")
        return text
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {e}")
        raise


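# Note: pypdf's extract_text() can return an empty string for image-only or
# scanned pages; such pages are skipped above. OCR would be required to
# recover their content.

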
# --- Stage 2: Text Chunking ---
def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Splits text into overlapping chunks."""
    print(f"Chunking text (size={chunk_size}, overlap={chunk_overlap})...")
    if not text:
        return []
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size.")

    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunks.append(text[start:end])
        if end >= len(text):
            break  # The last chunk already reaches the end of the text
        start += chunk_size - chunk_overlap  # Step forward, keeping the overlap

    print(f"Generated {len(chunks)} chunks.")
    return chunks


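# Example: with the defaults (chunk_size=500, chunk_overlap=50), chunk starts
# advance by 450 characters (0, 450, 900, ...), so consecutive chunks share
# 50 characters of context.

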
# --- Stage 3: Embedding Generation ---

# Load a relatively small but effective model.
# This will download the model the first time it's run.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Embedding model loaded.")


def generate_embeddings(chunks: List[str]) -> Tuple[List[str], List[List[float]]]:
    """Generates vector embeddings (as lists of floats) for a list of text chunks."""
    if not chunks:
        return [], []
    print(f"Generating embeddings for {len(chunks)} chunks...")
    # The model's encode function returns a NumPy array directly
    embeddings = embedding_model.encode(chunks, show_progress_bar=True)
    print(f"Generated embeddings of shape: {embeddings.shape}")
    # Convert NumPy array rows to lists for the storage simulation downstream
    embeddings_list: List[List[float]] = [emb.tolist() for emb in embeddings]
    return chunks, embeddings_list


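# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the arrays
# above have shape (n_chunks, 384).

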
# --- Stage 4: Simulate Supabase/pgvector Storage & Retrieval ---
|
||
# This simulates storing data, including converting vectors to pgvector's string format
|
||
# and then fetching it back, parsing the string.
|
||
|
||
# In-memory "database"
|
||
mock_db: List[Dict[str, Any]] = []
|
||
|
||
|
||
def simulate_pgvector_storage(chunks: List[str], embeddings: List[List[float]]):
|
||
"""Simulates storing chunks and embeddings in a DB like Supabase."""
|
||
print("Simulating storage...")
|
||
global mock_db
|
||
mock_db = [] # Clear previous data
|
||
for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
|
||
# Simulate pgvector string format "[0.1,0.2,...]"
|
||
embedding_str = json.dumps(embedding)
|
||
mock_db.append(
|
||
{
|
||
"id": i,
|
||
"content": chunk,
|
||
"embedding_str": embedding_str, # Store as string
|
||
}
|
||
)
|
||
print(f"Simulated storing {len(mock_db)} items.")
|
||
|
||
|
||
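# For a real Supabase backend, the same data could be written with the
# supabase-py client. A minimal sketch, assuming a hypothetical "documents"
# table with "content" (text) and "embedding" (vector) columns:
#
#   from supabase import create_client
#   client = create_client(SUPABASE_URL, SUPABASE_KEY)
#   client.table("documents").insert(
#       {"content": chunk, "embedding": embedding}
#   ).execute()

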
def parse_pgvector_string(vector_string: str) -> List[float]:
    """Parses the string representation from pgvector into a list of floats."""
    try:
        return json.loads(vector_string)
    except (json.JSONDecodeError, TypeError) as e:
        print(f"Error parsing vector string '{vector_string[:50]}...': {e}")
        return []  # Return an empty list on error


def simulate_fetch_from_db() -> Tuple[List[str], List[List[float]]]:
    """Simulates fetching data and parsing vector strings."""
    print("Simulating fetching data from DB...")
    fetched_chunks = []
    fetched_embeddings = []
    for item in mock_db:
        content = item.get("content")
        embedding_str = item.get("embedding_str")
        if content and embedding_str:
            parsed_embedding = parse_pgvector_string(embedding_str)
            if parsed_embedding:  # Only add if parsing was successful
                fetched_chunks.append(content)
                fetched_embeddings.append(parsed_embedding)
            else:
                print(f"Warning: Failed to parse embedding for item id {item.get('id')}")

    print(f"Simulated fetching {len(fetched_chunks)} items with valid embeddings.")
    return fetched_chunks, fetched_embeddings


# --- Stage 5: Dimensionality Reduction (t-SNE) ---


def apply_tsne(
    embeddings: Union[np.ndarray, List[List[float]]],
    n_components: int = 2,
    perplexity: float = 30.0,  # Adjust based on the number of samples
    learning_rate: Union[float, Literal["auto"]] = "auto",
    n_iter: int = 1000,
    random_state: int = 42,
    verbose: int = 1,
) -> np.ndarray:
    """Applies t-SNE to reduce the dimensionality of vector embeddings."""
    print("\nApplying t-SNE...")
    if isinstance(embeddings, list):
        embeddings_np = np.array(embeddings, dtype=np.float32)
    elif isinstance(embeddings, np.ndarray):
        embeddings_np = embeddings.astype(np.float32)
    else:
        raise TypeError("Embeddings must be a NumPy array or a list of lists.")

    if embeddings_np.ndim != 2:
        raise ValueError(f"Input embeddings must be 2D, got shape {embeddings_np.shape}")

    n_samples = embeddings_np.shape[0]
    if n_samples == 0:
        print("Warning: No embeddings to process with t-SNE.")
        return np.empty((0, n_components))

    # t-SNE requires perplexity < n_samples; clamp it for small inputs
    effective_perplexity = min(perplexity, max(1.0, n_samples - 1.0))
    if effective_perplexity != perplexity:
        print(
            f"Warning: Perplexity adjusted from {perplexity} to {effective_perplexity} "
            f"due to low sample count ({n_samples})."
        )

    tsne = TSNE(
        n_components=n_components,
        perplexity=effective_perplexity,
        learning_rate=learning_rate,
        n_iter=n_iter,  # Renamed to max_iter in newer scikit-learn releases
        init="pca",  # Often more stable than random initialization
        random_state=random_state,
        verbose=verbose,
    )

    reduced_embeddings = tsne.fit_transform(embeddings_np)
    print(f"t-SNE finished. Output shape: {reduced_embeddings.shape}")
    return reduced_embeddings


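# Note: t-SNE preserves local neighborhoods, not global geometry. Distances
# between well-separated groups in the reduced space are not meaningful, so
# treat the 3D layout as a qualitative map of the embedding space.

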
# --- Stage 6: Visualization (using Plotly) ---

def plot_embeddings_interactive(
    reduced_embeddings: np.ndarray,
    texts: List[str],
):
    """Creates an interactive 3D Plotly scatter plot and displays it."""
    if reduced_embeddings.shape[0] != len(texts):
        raise ValueError("Number of embeddings and texts must match.")
    if reduced_embeddings.shape[1] != 3:
        raise ValueError("Reduced embeddings must be 3D for this plot.")
    if reduced_embeddings.shape[0] == 0:
        print("No data to plot.")
        return

    print("\nGenerating interactive 3D plot...")

    df = pd.DataFrame(
        {
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "z": reduced_embeddings[:, 2],
            "text": texts,
        }
    )

    # Create hover text (limit length for readability)
    hover_texts = [t[:200] + "..." if len(t) > 200 else t for t in df["text"]]

    fig = go.Figure(
        data=go.Scatter3d(
            x=df["x"],
            y=df["y"],
            z=df["z"],
            mode="markers",
            marker=dict(
                size=5,  # Adjust marker size for 3D if needed
                # color=df['z'], colorscale='Viridis', showscale=True  # Optional: color by z
            ),
            text=hover_texts,  # Text shown on hover
            hoverinfo="text",  # Display only the hover text
        )
    )

    fig.update_layout(
        title="3D t-SNE Visualization of Text Chunk Embeddings",
        scene=dict(
            xaxis_title="t-SNE Dimension 1",
            yaxis_title="t-SNE Dimension 2",
            zaxis_title="t-SNE Dimension 3",
        ),
        hovermode="closest",
        margin=dict(r=0, b=0, l=0, t=40),  # Adjust margins if needed
    )

    # Optionally persist the plot: fig.write_html("embeddings_3d.html")
    fig.show()  # Display the plot in an interactive window
    print("Plot window opened.")


# --- Gemini API Helper ---
def call_gemini_api(
    prompt_text: str, model_name: str = "gemini-1.5-flash"
) -> Optional[str]:
    """Calls the Gemini API with the provided text and returns the generated content."""
    try:
        # Get the API key from environment variables
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            print("Error: GEMINI_API_KEY environment variable not set.")
            raise ValueError("Missing GEMINI_API_KEY")

        # Initialize the client directly with the API key
        client = genai.Client(api_key=api_key)

        # The SDK converts a plain string prompt into the proper Content structure
        contents = prompt_text

        # Generation configuration; formatting instructions live in the prompt,
        # so plain-text output is sufficient
        generation_config_obj = types.GenerateContentConfig(
            temperature=0.7,
            response_mime_type="text/plain",
        )

        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=generation_config_obj,
        )

        # Check for a valid response and return its text
        if response and response.text:
            return response.text.strip()
        elif response and response.prompt_feedback and response.prompt_feedback.block_reason:
            print(
                f"Warning: Gemini API call blocked. Reason: {response.prompt_feedback.block_reason}"
            )
            return None
        else:
            # Runtime check for a 'parts' attribute without upsetting static typing
            if response and hasattr(response, "parts"):
                parts_attr = getattr(response, "parts")  # type: ignore[attr-defined]
                print(
                    f"Warning: Gemini API response has parts but no direct text attribute. Parts: {parts_attr}"
                )
                try:
                    # Join text from all parts that expose a 'text' attribute
                    return " ".join(
                        part.text for part in parts_attr if hasattr(part, "text")
                    ).strip()
                except Exception as part_error:
                    print(f"Error extracting text from parts: {part_error}")
                    return None  # Fallback if the parts structure is unexpected
            else:
                print(f"Warning: Gemini API response format unexpected or empty: {response}")
                return None

    except ValueError as ve:  # Catch the missing API key specifically
        if "Missing GEMINI_API_KEY" in str(ve):
            raise  # Re-raise to stop execution; the message was printed above
        print(f"An unexpected value error occurred during Gemini API call: {ve}")
        return None
    except Exception as e:
        print(f"Error calling Gemini API: {e}")
        # Consider catching specific google.api_core exceptions here if
        # finer-grained handling (e.g. PermissionDenied) is needed
        return None


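# Example usage (assumes GEMINI_API_KEY is set in the environment or a .env file):
#   topic = call_gemini_api("Give a concise 3-5 word topic for: <text>")

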
# --- Stage 7: Clustering and Markdown Generation ---
def generate_clustered_markdown(
    reduced_embeddings: np.ndarray,
    texts: List[str],
    output_filename: str,  # Full path including the timestamped subdirectory
    n_clusters: int = 5,
):
    """Performs K-Means clustering, uses Gemini for topics and rewriting (Markdown/LaTeX), and saves to Markdown."""
    print("\n--- Clustering Texts and Generating Markdown with Gemini Rewriting ---")
    # Clusters are computed in the reduced t-SNE space, so they correspond to
    # the 3D visualization; clustering the original embeddings is an alternative.
    kmeans = KMeans(
        n_clusters=n_clusters, random_state=42, n_init=10  # explicit n_init suppresses a warning
    )
    cluster_labels = kmeans.fit_predict(reduced_embeddings)

    # Group texts by cluster
    clustered_texts: Dict[int, List[str]] = {i: [] for i in range(n_clusters)}
    for text, label in zip(texts, cluster_labels):
        clustered_texts[label].append(text)

    # Build Markdown content
    markdown_content = "# Clustered and Rewritten Text Document\n\n"
    markdown_content += (
        "This document groups text chunks based on semantic similarity. Topics and "
        "rewritten text (formatted in Markdown with LaTeX for equations) are "
        "generated by the Gemini API.\n\n"
    )

    for i in range(n_clusters):
        cluster_topic = f"Cluster {i+1}"  # Default topic
        rewritten_content = "(Failed to generate rewritten text for this cluster.)"  # Default on failure

        if clustered_texts[i]:
            # Combine the cluster's chunks and cap the context size
            combined_text = "\n\n---CHUNK SEPARATOR---\n\n".join(clustered_texts[i])
            context_limit = 15000  # Adjust to the model's context window and typical chunk size
            if len(combined_text) > context_limit:
                print(
                    f"Warning: Combined text for cluster {i+1} exceeds {context_limit} chars, truncating for API call."
                )
                combined_text = combined_text[:context_limit] + "..."

            # 1. Ask Gemini for a topic
            topic_prompt = f"Analyze the following text excerpts separated by '---CHUNK SEPARATOR---'. Provide only a concise topic title (3-5 words maximum) that captures the main theme. Do not add any explanation or introductory text.\n\nText Excerpts:\n{combined_text}"
            print(f"  Generating topic for Cluster {i+1}...")
            generated_topic = call_gemini_api(topic_prompt)
            if generated_topic:
                cluster_topic = generated_topic.replace('"', "").replace("Topic:", "").strip()
            else:
                print(f"  Failed to generate topic for Cluster {i+1}.")

            # 2. Ask Gemini to rewrite the text with Markdown and LaTeX formatting
            rewrite_prompt = f"""Rewrite and reorder the following text chunks, separated by '---CHUNK SEPARATOR---', into a single, coherent, and grammatically correct text.
Preserve all the original information and meaning, but improve the flow and readability.
Format the entire output as Markdown.
Use LaTeX delimiters for all mathematical equations: '$' for inline equations (e.g., $E=mc^2$) and '$$' for display equations (e.g., $$a^2 + b^2 = c^2$$).
Do not add any commentary, introduction, or conclusion beyond the rewritten text itself.

The output language should be German, apart from technical terms.
You do not have to include citations for people or works mentioned in the text, unless they are essential to the meaning of the text.

Text Chunks:
{combined_text}"""
            print(f"  Rewriting text for Cluster {i+1}...")
            generated_rewrite = call_gemini_api(rewrite_prompt)
            if generated_rewrite:
                rewritten_content = generated_rewrite
            else:
                print(f"  Failed to rewrite text for Cluster {i+1}. Using original chunks.")
                # Fall back to the original chunks if the rewrite fails
                rewritten_content = (
                    "**Original Chunks (Rewrite Failed):**\n\n"
                    + "\n\n---\n\n".join(clustered_texts[i])
                )

        # Add to Markdown
        markdown_content += f"## {cluster_topic}\n\n"
        markdown_content += f"{rewritten_content}\n\n"  # Rewritten content (or fallback)

        # Separator between clusters in the Markdown file
        if i < n_clusters - 1:
            markdown_content += "\n---\n\n"

    # Write to the Markdown file (the directory is created in main)
    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(markdown_content)
        print(f"Clustered and rewritten text saved to: {output_filename}")
    except IOError as e:
        print(f"Error writing Markdown file {output_filename}: {e}")

    print("\n--- End of Markdown Generation ---")


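# The generated Markdown uses $...$ and $$...$$ LaTeX delimiters; these render
# in Markdown viewers with MathJax or KaTeX support.

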
# --- Stage 8: Querying / Semantic Search ---
def find_similar_chunks(
    query: str,
    texts: List[str],
    embeddings: np.ndarray,  # Use the original (full-dimensional) embeddings
    model: SentenceTransformer,
    top_n: int = 5,
) -> List[Tuple[str, float]]:
    """Finds the text chunks most similar to the query."""
    print(f"\nSearching for chunks similar to: '{query}'")
    if embeddings.shape[0] == 0:
        print("No embeddings available for search.")
        return []

    # Embed the query; encode() expects a list and returns shape (1, n_features)
    query_embedding = model.encode([query])

    # Cosine similarity between the query (1, n_features) and all chunk
    # embeddings (n_samples, n_features); take the first (only) row
    similarities = cosine_similarity(query_embedding, embeddings)[0]

    # Take at most top_n results; if fewer are available, take them all
    num_results = min(top_n, len(similarities))
    if num_results <= 0:
        return []

    # argsort gives ascending order; reverse for descending similarity
    sorted_indices = np.argsort(similarities)[::-1]
    top_indices = sorted_indices[:num_results]

    results = [(texts[i], float(similarities[i])) for i in top_indices]
    print(f"Found {len(results)} relevant chunks.")
    return results


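# Cosine similarity between the query vector q and a chunk vector d:
#   sim(q, d) = (q . d) / (||q|| * ||d||)
# Scores range from -1 to 1; higher means more semantically similar.

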
# --- Main Execution ---
if __name__ == "__main__":
    pdf_file = "in/example.pdf"
    # Build a unique, timestamped output subdirectory, e.g. output/20240101_120000
    base_output_dir = "output"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_subdir = os.path.join(base_output_dir, timestamp)
    # Final Markdown output file path
    markdown_output_file = os.path.join(output_subdir, "clustered_rewritten_gemini.md")

    try:
        # Create the output directories if they don't exist
        os.makedirs(output_subdir, exist_ok=True)
        print(f"Output will be saved in: {output_subdir}")

        # 1. Extract Text
        full_text = extract_text_from_pdf(pdf_file)

        # 2. Chunk Text
        text_chunks = chunk_text(full_text, chunk_size=400, chunk_overlap=40)

        if not text_chunks:
            print("No text chunks generated. Exiting.")
            exit()

        # 3. Generate Embeddings
        original_chunks, embeddings_list = generate_embeddings(text_chunks)

        # 4. Simulate Storage & Retrieval
        simulate_pgvector_storage(original_chunks, embeddings_list)
        fetched_chunks, fetched_embeddings_list = simulate_fetch_from_db()

        if not fetched_embeddings_list:
            print("No valid embeddings fetched. Cannot proceed. Exiting.")
            exit()

        # Convert the fetched lists back to a NumPy array
        original_embeddings_np = np.array(fetched_embeddings_list)

        # 5. Apply t-SNE
        # Heuristic: scale perplexity with the chunk count, clamped to [5, 30]
        num_chunks = len(fetched_chunks)
        tsne_perplexity = min(30.0, max(5.0, num_chunks / 4.0))
        reduced_embeddings_3d = apply_tsne(
            original_embeddings_np,
            n_components=3,
            perplexity=tsne_perplexity,
            random_state=42,
            verbose=1,
        )

        # 6. Visualize
        if reduced_embeddings_3d.shape[0] > 0:
            plot_embeddings_interactive(reduced_embeddings_3d, fetched_chunks)
        else:
            print("No reduced embeddings generated for plotting.")

        # 7. Cluster and generate the Markdown file using Gemini
        if reduced_embeddings_3d.shape[0] > 0:
            # Heuristic: roughly one cluster per 10 chunks, clamped to [2, 8]
            num_clusters = min(8, max(2, num_chunks // 10))
            generate_clustered_markdown(
                reduced_embeddings_3d,
                fetched_chunks,
                markdown_output_file,  # Full timestamped path
                n_clusters=num_clusters,
            )
        else:
            print("No reduced embeddings available for clustering.")

        # 8. Interactive Querying Loop
        print("\n--- Interactive Query Mode ---")
        print("Enter a search query (or type 'quit' to exit):")
        while True:
            user_query = input("> ")
            if user_query.lower() == "quit":
                break
            if not user_query.strip():
                continue

            # Search against the original (full-dimensional) embeddings
            search_results = find_similar_chunks(
                user_query,
                fetched_chunks,
                original_embeddings_np,
                embedding_model,
                top_n=3,  # Show the top 3 results
            )

            if search_results:
                print("--- Top Results ---")
                for i, (text, score) in enumerate(search_results):
                    print(f"[{i+1}] Score: {score:.4f}\n  Text: {text[:300]}...\n")
            else:
                print("No relevant chunks found.")

    except FileNotFoundError as e:
        print(f"Error: {e}")
        print(f"Please ensure the input PDF exists at '{pdf_file}'.")
    except ImportError as e:
        print(f"Import Error: {e}")
        print("Please ensure all required libraries are installed:")
        print(
            "pip install pypdf sentence-transformers torch scikit-learn numpy plotly pandas google-genai python-dotenv"
        )
    except ValueError as e:  # e.g. the missing API key raised by call_gemini_api
        if "Missing GEMINI_API_KEY" in str(e):
            print("Please set the GEMINI_API_KEY environment variable and try again.")
            exit(1)
        print(f"An unexpected value error occurred: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # import traceback; traceback.print_exc()  # Uncomment for a full stack trace