import os
import re
import json
from typing import List, Dict, Tuple, Union, Optional, Any, Literal
# Add datetime import for timestamped folders
from datetime import datetime
# Add Gemini imports
from google import genai
from google.genai import types
import dotenv # Load environment variables from .env file
import plotly.graph_objects as go
import pandas as pd
from sklearn.manifold import TSNE
# Add imports for Clustering and Similarity Search
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np # Ensure numpy is imported if not already done earlier
from sentence_transformers import SentenceTransformer
# --- Stage 1: PDF Processing ---
from pypdf import PdfReader
# Load environment variables from .env file
dotenv.load_dotenv()
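# The .env file is expected to provide the Gemini key used by call_gemini_api
# below, e.g. a single line: GEMINI_API_KEY=your-key-here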
def extract_text_from_pdf(pdf_path: str) -> str:
    """Extracts text content from a PDF file."""
    print(f"Processing PDF: {pdf_path}")
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"PDF file not found: {pdf_path}")
    try:
        reader = PdfReader(pdf_path)
        text = ""
        for page_num, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text:
                # Basic cleaning: collapse repeated newlines/spaces
                cleaned_text = re.sub(r"\s+", " ", page_text).strip()
                text += cleaned_text + "\n"  # Add newline between pages
                print(f"  Extracted text from page {page_num + 1}")
        print(f"Finished extracting text. Total length: {len(text)} characters.")
        return text
    except Exception as e:
        print(f"Error reading PDF {pdf_path}: {e}")
        raise

# --- Stage 2: Text Chunking ---
def chunk_text(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    """Splits text into overlapping chunks."""
    print(f"Chunking text (size={chunk_size}, overlap={chunk_overlap})...")
    if not text:
        return []
    if chunk_overlap >= chunk_size:
        # Guard against a non-positive step size, which would loop forever
        raise ValueError("chunk_overlap must be smaller than chunk_size.")
    chunks = []
    start = 0
    while start < len(text):
        end = start + chunk_size
        chunk = text[start:end]
        chunks.append(chunk)
        start += chunk_size - chunk_overlap  # Move start forward for overlap
        # Ensure we don't go past the end if overlap is large
        if start >= len(text) - chunk_overlap and start < len(text):
            # Add the last remaining part if it wasn't fully covered
            final_chunk = text[start:]
            if final_chunk and (not chunks or chunks[-1] != final_chunk):  # Avoid duplicates
                chunks.append(final_chunk)
            break  # Exit loop after handling the end
    print(f"Generated {len(chunks)} chunks.")
    return chunks
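
# Worked example of the chunking arithmetic: with chunk_size=500 and
# chunk_overlap=50 the window advances 450 characters per iteration, so
# consecutive chunks share 50 characters. A 1,000-character text yields
# chunks starting at offsets 0, 450, and 900.
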
# --- Stage 3: Embedding Generation ---
# Load a relatively small but effective model
# This will download the model the first time it's run
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
print("Embedding model loaded.")
def generate_embeddings(chunks: List[str]) -> Tuple[List[str], List[List[float]]]:
    """Generates vector embeddings (as lists of floats) for a list of text chunks."""
    if not chunks:
        return [], []
    print(f"Generating embeddings for {len(chunks)} chunks...")
    # The model's encode function returns numpy arrays directly
    embeddings = embedding_model.encode(chunks, show_progress_bar=True)
    print(f"Generated embeddings of shape: {embeddings.shape}")
    # Convert numpy array rows to lists for easier simulation downstream
    embeddings_list: List[List[float]] = [emb.tolist() for emb in embeddings]
    return chunks, embeddings_list
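
# Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so for n chunks
# the encoded array has shape (n, 384).
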
# --- Stage 4: Simulate Supabase/pgvector Storage & Retrieval ---
# This simulates storing data, including converting vectors to pgvector's string format
# and then fetching it back, parsing the string.
# In-memory "database"
mock_db: List[Dict[str, Any]] = []
def simulate_pgvector_storage(chunks: List[str], embeddings: List[List[float]]):
    """Simulates storing chunks and embeddings in a DB like Supabase."""
    print("Simulating storage...")
    global mock_db
    mock_db = []  # Clear previous data
    for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
        # Simulate pgvector string format "[0.1,0.2,...]"
        embedding_str = json.dumps(embedding)
        mock_db.append(
            {
                "id": i,
                "content": chunk,
                "embedding_str": embedding_str,  # Store as string
            }
        )
    print(f"Simulated storing {len(mock_db)} items.")
def parse_pgvector_string(vector_string: str) -> List[float]:
    """Parses the string representation from pgvector into a list of floats."""
    try:
        return json.loads(vector_string)
    except (json.JSONDecodeError, TypeError) as e:
        print(f"Error parsing vector string '{vector_string[:50]}...': {e}")
        return []  # Return empty list on error
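
# pgvector's text representation ("[0.1,0.2,...]") happens to be valid JSON,
# which is why json.loads works both for this simulation and for vectors
# fetched back as strings from a real database.
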
def simulate_fetch_from_db() -> Tuple[List[str], List[List[float]]]:
    """Simulates fetching data and parsing vector strings."""
    print("Simulating fetching data from DB...")
    fetched_chunks = []
    fetched_embeddings = []
    for item in mock_db:
        content = item.get("content")
        embedding_str = item.get("embedding_str")
        if content and embedding_str:
            parsed_embedding = parse_pgvector_string(embedding_str)
            if parsed_embedding:  # Only add if parsing was successful
                fetched_chunks.append(content)
                fetched_embeddings.append(parsed_embedding)
            else:
                print(f"Warning: Failed to parse embedding for item id {item.get('id')}")
    print(f"Simulated fetching {len(fetched_chunks)} items with valid embeddings.")
    return fetched_chunks, fetched_embeddings
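
# For reference, a real fetch could use supabase-py (a sketch, assuming a
# "documents" table with "content" and "embedding" columns; SUPABASE_URL and
# SUPABASE_KEY are placeholders):
#   from supabase import create_client
#   supabase = create_client(SUPABASE_URL, SUPABASE_KEY)
#   rows = supabase.table("documents").select("content, embedding").execute().data
#   chunks = [row["content"] for row in rows]
#   embeddings = [parse_pgvector_string(row["embedding"]) for row in rows]
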
# --- Stage 5: Dimensionality Reduction (t-SNE) ---
def apply_tsne(
    embeddings: Union[np.ndarray, List[List[float]]],
    n_components: int = 2,
    perplexity: float = 30.0,  # Adjust based on number of samples
    learning_rate: Union[float, Literal["auto"]] = "auto",
    n_iter: int = 1000,
    random_state: int = 42,
    verbose: int = 1,
) -> np.ndarray:
    """Applies t-SNE to reduce the dimensionality of vector embeddings."""
    print("\nApplying t-SNE...")
    if isinstance(embeddings, list):
        embeddings_np = np.array(embeddings, dtype=np.float32)
    elif isinstance(embeddings, np.ndarray):
        embeddings_np = embeddings.astype(np.float32)
    else:
        raise TypeError("Embeddings must be a NumPy array or a list of lists.")
    if embeddings_np.ndim != 2:
        raise ValueError(f"Input embeddings must be 2D, got shape {embeddings_np.shape}")
    n_samples = embeddings_np.shape[0]
    if n_samples == 0:
        print("Warning: No embeddings to process with t-SNE.")
        return np.empty((0, n_components))
    # Adjust perplexity if it's too high for the number of samples
    effective_perplexity = min(perplexity, max(1.0, n_samples - 1.0))
    if effective_perplexity != perplexity:
        print(
            f"Warning: Perplexity adjusted from {perplexity} to {effective_perplexity} "
            f"due to low sample count ({n_samples})."
        )
    tsne = TSNE(
        n_components=n_components,
        perplexity=effective_perplexity,
        learning_rate=learning_rate,
        n_iter=n_iter,  # Note: renamed to max_iter in scikit-learn >= 1.5
        init="pca",  # Often more stable than random initialization
        random_state=random_state,
        verbose=verbose,
    )
    reduced_embeddings = tsne.fit_transform(embeddings_np)
    print(f"t-SNE finished. Output shape: {reduced_embeddings.shape}")
    return reduced_embeddings
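
# Perplexity is roughly the number of effective neighbors each point considers;
# scikit-learn requires perplexity < n_samples, hence the clamping above.
# Example call (variable name is illustrative):
#   coords_3d = apply_tsne(embeddings_np, n_components=3, perplexity=15.0)
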
# --- Stage 6: Visualization (using Plotly) ---
def plot_embeddings_interactive(
    reduced_embeddings: np.ndarray,
    texts: List[str],
):
    """Creates an interactive 3D Plotly scatter plot and displays it."""
    if reduced_embeddings.shape[0] != len(texts):
        raise ValueError("Number of embeddings and texts must match.")
    if reduced_embeddings.shape[1] != 3:  # Expect 3 dimensions
        raise ValueError("Reduced embeddings must be 3D for this plot.")
    if reduced_embeddings.shape[0] == 0:
        print("No data to plot.")
        return
    print("\nGenerating interactive 3D plot...")
    df = pd.DataFrame(
        {
            "x": reduced_embeddings[:, 0],
            "y": reduced_embeddings[:, 1],
            "z": reduced_embeddings[:, 2],
            "text": texts,
        }
    )
    # Create hover text (limit length for readability)
    hover_texts = [t[:200] + "..." if len(t) > 200 else t for t in df["text"]]
    fig = go.Figure(
        data=go.Scatter3d(
            x=df["x"],
            y=df["y"],
            z=df["z"],
            mode="markers",
            marker=dict(
                size=5,  # Adjust marker size for 3D if needed
                # color=df['z'], colorscale='Viridis', showscale=True  # e.g. color by z
            ),
            text=hover_texts,  # Text shown on hover
            hoverinfo="text",  # Display only the hover text
        )
    )
    fig.update_layout(
        title="3D t-SNE Visualization of Text Chunk Embeddings",
        scene=dict(  # Use scene for the 3D layout
            xaxis_title="t-SNE Dimension 1",
            yaxis_title="t-SNE Dimension 2",
            zaxis_title="t-SNE Dimension 3",
        ),
        hovermode="closest",
        margin=dict(r=0, b=0, l=0, t=40),  # Adjust margins if needed
    )
    fig.show()  # Display the plot in an interactive window
    print("Plot window opened.")
# --- Gemini API Helper ---
def call_gemini_api(
    prompt_text: str, model_name: str = "gemini-1.5-flash"
) -> Optional[str]:
    """Calls the Gemini API with the provided text and returns the generated content."""
    try:
        # Get the API key from environment variables
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            print("Error: GEMINI_API_KEY environment variable not set.")
            raise ValueError("Missing GEMINI_API_KEY")
        # Initialize the client directly with the API key
        client = genai.Client(api_key=api_key)
        # The SDK accepts a plain string prompt and converts it to the
        # correct Content structure internally
        contents = prompt_text
        # Set up the generation configuration object
        generation_config_obj = types.GenerateContentConfig(
            temperature=0.7,
            response_mime_type="text/plain",  # Formatting instructions live in the prompt
        )
        # No safety settings are passed; the API defaults apply
        response = client.models.generate_content(
            model=model_name,
            contents=contents,
            config=generation_config_obj,
        )
        # Check for a valid response and return its text
        if response and response.text:
            return response.text.strip()
        elif response.prompt_feedback and response.prompt_feedback.block_reason:
            print(
                f"Warning: Gemini API call blocked. Reason: {response.prompt_feedback.block_reason}"
            )
            return None
        else:
            # Runtime check for a 'parts' attribute while avoiding static type errors
            if response and hasattr(response, "parts"):
                parts_attr = getattr(response, "parts")  # type: ignore[attr-defined]
                print(
                    f"Warning: Gemini API response has parts but no direct text attribute. Parts: {parts_attr}"
                )
                try:
                    # Join text from all parts that expose a 'text' attribute
                    return " ".join(
                        part.text for part in parts_attr if hasattr(part, "text")
                    ).strip()
                except Exception as part_error:
                    print(f"Error extracting text from parts: {part_error}")
                    return None  # Fallback if parts structure is unexpected
            else:
                print(f"Warning: Gemini API response format unexpected or empty: {response}")
                return None
    except ValueError as ve:  # Catch missing API key specifically
        if "Missing GEMINI_API_KEY" in str(ve):
            raise ve  # Re-raise to stop execution (message already printed above)
        print(f"An unexpected value error occurred during Gemini API call: {ve}")
        return None
    except Exception as e:
        # Consider catching specific google.api_core exceptions if finer-grained
        # handling is needed, e.g. google.api_core.exceptions.PermissionDenied
        print(f"Error calling Gemini API: {e}")
        return None
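
# Example usage (prompt is illustrative):
#   topic = call_gemini_api("Give a 3-5 word topic for: ...")
#   if topic is None:
#       ...  # blocked, malformed response, or API error
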
# --- Stage 7: Clustering and Markdown Generation ---
def generate_clustered_markdown(
    reduced_embeddings: np.ndarray,
    texts: List[str],
    output_filename: str,  # Full path, including the timestamped subdirectory
    n_clusters: int = 5,
):
    """Performs K-Means clustering, uses Gemini for topics and rewriting (Markdown/LaTeX), and saves to Markdown."""
    print("\n--- Clustering Texts and Generating Markdown with Gemini Rewriting ---")
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)  # explicit n_init suppresses a warning
    cluster_labels = kmeans.fit_predict(reduced_embeddings)
    # Group texts by cluster
    clustered_texts: Dict[int, List[str]] = {i: [] for i in range(n_clusters)}
    for text, label in zip(texts, cluster_labels):
        clustered_texts[label].append(text)
    # Build Markdown content
    markdown_content = "# Clustered and Rewritten Text Document\n\n"
    markdown_content += (
        "This document groups text chunks based on semantic similarity. "
        "Topics and rewritten text (formatted in Markdown with LaTeX for equations) "
        "are generated by the Gemini API.\n\n"
    )
    for i in range(n_clusters):
        cluster_topic = f"Cluster {i+1}"  # Default topic
        rewritten_content = "(Failed to generate rewritten text for this cluster.)"  # Default on failure
        if clustered_texts[i]:
            # Combine the cluster's chunks and limit the context size
            combined_text = "\n\n---CHUNK SEPARATOR---\n\n".join(clustered_texts[i])
            context_limit = 15000  # Adjust based on model context window and typical chunk size
            if len(combined_text) > context_limit:
                print(
                    f"Warning: Combined text for cluster {i+1} exceeds {context_limit} chars, truncating for API call."
                )
                combined_text = combined_text[:context_limit] + "..."
            # 1. Ask Gemini for a topic
            topic_prompt = (
                "Analyze the following text excerpts separated by '---CHUNK SEPARATOR---'. "
                "Provide only a concise topic title (3-5 words maximum) that captures the main theme. "
                "Do not add any explanation or introductory text.\n\n"
                f"Text Excerpts:\n{combined_text}"
            )
            print(f"  Generating topic for Cluster {i+1}...")
            generated_topic = call_gemini_api(topic_prompt)
            if generated_topic:
                cluster_topic = generated_topic.replace('"', "").replace("Topic:", "").strip()
            else:
                print(f"  Failed to generate topic for Cluster {i+1}.")
            # 2. Ask Gemini to rewrite the text with Markdown and LaTeX formatting
            rewrite_prompt = f"""Rewrite and reorder the following text chunks, separated by '---CHUNK SEPARATOR---', into a single, coherent, and grammatically correct text.
Preserve all the original information and meaning, but improve the flow and readability.
Format the entire output as Markdown.
Use LaTeX delimiters for all mathematical equations: '$' for inline equations (e.g., $E=mc^2$) and '$$' for display equations (e.g., $$a^2 + b^2 = c^2$$).
Do not add any commentary, introduction, or conclusion beyond the rewritten text itself.
The output language should be German, except for technical terms.
You do not have to include citations for people or works mentioned in the text, unless they are essential to the meaning of the text.

Text Chunks:
{combined_text}"""
            print(f"  Rewriting text for Cluster {i+1}...")
            generated_rewrite = call_gemini_api(rewrite_prompt)
            if generated_rewrite:
                rewritten_content = generated_rewrite
            else:
                print(f"  Failed to rewrite text for Cluster {i+1}. Using original chunks.")
                # Fall back to the original chunks if the rewrite fails
                rewritten_content = (
                    "**Original Chunks (Rewrite Failed):**\n\n"
                    + "\n\n---\n\n".join(clustered_texts[i])
                )
        # Add the topic heading and the rewritten content (or fallback)
        markdown_content += f"## {cluster_topic}\n\n"
        markdown_content += f"{rewritten_content}\n\n"
        # Add a separator between clusters in the markdown file
        if i < n_clusters - 1:
            markdown_content += "\n---\n\n"
    # Write to Markdown file (the output directory is created in main)
    try:
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write(markdown_content)
        print(f"Clustered and rewritten text saved to: {output_filename}")
    except IOError as e:
        print(f"Error writing Markdown file {output_filename}: {e}")
    print("\n--- End of Markdown Generation ---")
# --- Stage 8: Querying / Semantic Search ---
def find_similar_chunks(
    query: str,
    texts: List[str],
    embeddings: np.ndarray,  # Use original (full-dimensional) embeddings for similarity
    model: SentenceTransformer,
    top_n: int = 5,
) -> List[Tuple[str, float]]:
    """Finds the text chunks most similar to the query."""
    print(f"\nSearching for chunks similar to: '{query}'")
    if embeddings.shape[0] == 0:
        print("No embeddings available for search.")
        return []
    # Generate the embedding for the query (encode expects a list)
    query_embedding = model.encode([query])
    # Calculate cosine similarities: embeddings is (n_samples, n_features),
    # query_embedding is (1, n_features), so the result is (1, n_samples);
    # take the first row
    similarities = cosine_similarity(query_embedding, embeddings)[0]
    # If fewer results than top_n are available, take all of them
    num_results = min(top_n, len(similarities))
    if num_results <= 0:
        return []
    # argsort ascending, reversed for descending similarity
    sorted_indices = np.argsort(similarities)[::-1]
    top_indices = sorted_indices[:num_results]
    results = [(texts[i], float(similarities[i])) for i in top_indices]
    print(f"Found {len(results)} relevant chunks.")
    return results
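
# Cosine similarity is dot(a, b) / (|a| * |b|) and lies in [-1, 1]; higher
# scores mean more semantically similar chunks. Example (query is illustrative):
#   hits = find_similar_chunks("neural networks", chunks, embeddings_np, embedding_model)
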
# --- Main Execution ---
if __name__ == "__main__":
    pdf_file = "in/example.pdf"
    # Define the base output directory and a timestamped subdirectory
    base_output_dir = "output"
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_subdir = os.path.join(base_output_dir, timestamp)
    # Define the final markdown output file path
    markdown_output_file = os.path.join(output_subdir, "clustered_rewritten_gemini.md")
    try:
        # Create the output directories if they don't exist
        os.makedirs(output_subdir, exist_ok=True)
        print(f"Output will be saved in: {output_subdir}")
        # 1. Extract Text
        full_text = extract_text_from_pdf(pdf_file)
        # 2. Chunk Text
        text_chunks = chunk_text(full_text, chunk_size=400, chunk_overlap=40)
        if not text_chunks:
            print("No text chunks generated. Exiting.")
            exit()
        # 3. Generate Embeddings
        original_chunks, embeddings_list = generate_embeddings(text_chunks)
        # 4. Simulate Storage & Retrieval
        simulate_pgvector_storage(original_chunks, embeddings_list)
        fetched_chunks, fetched_embeddings_list = simulate_fetch_from_db()
        if not fetched_embeddings_list:
            print("No valid embeddings fetched. Cannot proceed. Exiting.")
            exit()
        # Convert the fetched lists back to a NumPy array
        original_embeddings_np = np.array(fetched_embeddings_list)
        # 5. Apply t-SNE (scale perplexity with the number of chunks)
        num_chunks = len(fetched_chunks)
        tsne_perplexity = min(30.0, max(5.0, num_chunks / 4.0))
        reduced_embeddings_3d = apply_tsne(
            original_embeddings_np,
            n_components=3,
            perplexity=tsne_perplexity,
            random_state=42,
            verbose=1,
        )
        # 6. Visualize
        if reduced_embeddings_3d.shape[0] > 0:
            plot_embeddings_interactive(reduced_embeddings_3d, fetched_chunks)
        else:
            print("No reduced embeddings generated for plotting.")
        # 7. Cluster and Generate Markdown File using Gemini
        if reduced_embeddings_3d.shape[0] > 0:
            num_clusters = min(8, max(2, num_chunks // 10))
            generate_clustered_markdown(
                reduced_embeddings_3d,
                fetched_chunks,
                markdown_output_file,  # Full timestamped path
                n_clusters=num_clusters,
            )
        else:
            print("No reduced embeddings available for clustering.")
        # 8. Interactive Querying Loop
        print("\n--- Interactive Query Mode ---")
        print("Enter a search query (or type 'quit' to exit):")
        while True:
            user_query = input("> ")
            if user_query.lower() == "quit":
                break
            if not user_query.strip():
                continue
            # Search against the original full-dimensional embeddings
            search_results = find_similar_chunks(
                user_query,
                fetched_chunks,
                original_embeddings_np,
                embedding_model,
                top_n=3,  # Show top 3 results
            )
            if search_results:
                print("--- Top Results ---")
                for i, (text, score) in enumerate(search_results):
                    print(f"[{i+1}] Score: {score:.4f}\n    Text: {text[:300]}...\n")
            else:
                print("No relevant chunks found.")
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure 'example.pdf' exists in the 'in/' directory.")
    except ImportError as e:
        print(f"Import Error: {e}")
        print("Please ensure all required libraries are installed:")
        print(
            "pip install pypdf sentence-transformers torch scikit-learn numpy "
            "plotly pandas google-genai python-dotenv"
        )
    except ValueError as e:  # Catch the missing API key error from the helper
        if "Missing GEMINI_API_KEY" in str(e):
            print("Please set the GEMINI_API_KEY environment variable and try again.")
            exit(1)  # Exit if the API key is missing
        print(f"An unexpected value error occurred: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # import traceback; traceback.print_exc()  # Uncomment for a full stack trace