""" Skapar eller uppdaterar talk_embeddings-tabellen med chunkade embeddings. Kör: python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb """ import os import numpy as np import sys import argparse from pathlib import Path from typing import Iterable, Mapping, List, Any from chromadb import Collection from time import sleep from random import random import backoff # Set /home/lasse/riksdagen as working directory os.chdir("/home/lasse/riksdagen") # Add the project root to Python path to locate local modules sys.path.append("/home/lasse/riksdagen") from arango_client import arango from utils import TextChunker from _chromadb.chroma_client import chroma_db from colorprinter import print_green, print_red, print_yellow, print_purple, print_blue CHROMA_COLLECTION = "talks" CHROMA_PATH = "/home/lasse/riksdagen/chromadb_data" def get_existing_chunk_ids(collection, _ids: List[str]) -> set: """ Fetch all existing chunk IDs for the given talk IDs in one query. Returns a set of chunk_id strings for fast O(1) lookup. """ try: # Get all documents where metadata._id is in our batch # ChromaDB might return error if collection is empty, so handle that result = collection.get( where={"_id": {"$in": _ids}}, include=[], # We only need IDs, not documents/embeddings ) return set(result.get("ids", [])) except Exception as e: print_red(f"Warning: Could not fetch existing IDs: {e}") return set() def get_all_dates() -> List[str]: """ Return a sorted list of unique 'datum' values found in the 'talks' collection. This uses AQL COLLECT to produce unique dates so we iterate date-by-date later. """ query = """ FOR doc IN talks COLLECT datum = doc.datum RETURN datum """ from arango_client import arango cursor = arango.db.aql.execute(query) return list(cursor) def embedding_exists(collection, chunk_id: str) -> bool: """ Check if an embedding with the given chunk_id already exists in the collection. """ res = collection.get(ids=[chunk_id]) return bool(res and res.get("ids")) def assign_debate_ids( docs: List[Mapping[str, Any]], date: str ) -> List[Mapping[str, Any]]: """ Assigns a debate id to each talk in a list of talks for a given date. A new debate starts when replik == False. The debate id is a string: "{date}:{debate_index}". Returns a new list of docs with a 'debate' field added. """ debate_index = 0 current_debate_id = f"{date}:{debate_index}" updated_docs = [] for doc in docs: # If this talk is not a reply, start a new debate if not doc.get("replik", False): debate_index += 1 current_debate_id = f"{date}:{debate_index}" # Add the debate id to the doc doc_with_debate = dict(doc) doc_with_debate["debate"] = current_debate_id updated_docs.append(doc_with_debate) return updated_docs def process_batch( arango_docs: Iterable[Mapping[str, object]], collection: Collection, use_existing_chunks: bool = False, replace_chromadb: bool = False, ) -> int: """ Processes a batch of ArangoDB documents (talks), chunks their text, and adds embeddings. Each talk is assigned a debate id. Ensures that only unique chunk IDs are added to ChromaDB to avoid DuplicateIDError. Args: arango_docs: Iterable of ArangoDB documents (talks). collection: ChromaDB collection object. use_existing_chunks: If True, use chunks stored in ArangoDB docs (if available). If False, always regenerate chunks from text. replace_chromadb: If True, replace existing ChromaDB entries with same ID. If False, skip documents that already exist in ChromaDB. Note: Must be True when use_existing_chunks is True. 


def process_batch(
    arango_docs: Iterable[Mapping[str, object]],
    collection: Collection,
    use_existing_chunks: bool = False,
    replace_chromadb: bool = False,
) -> int:
    """
    Process a batch of ArangoDB documents (talks), chunk their text, and add embeddings.

    Each talk is assigned a debate id. Ensures that only unique chunk IDs are
    added to ChromaDB to avoid DuplicateIDError.

    Args:
        arango_docs: Iterable of ArangoDB documents (talks).
        collection: ChromaDB collection object.
        use_existing_chunks: If True, use chunks stored in ArangoDB docs (if available).
            If False, always regenerate chunks from text.
        replace_chromadb: If True, replace existing ChromaDB entries with the same ID.
            If False, skip documents that already exist in ChromaDB.
            Note: Must be True when use_existing_chunks is True.

    Returns:
        int: Number of chunks generated and added.
    """
    arango_docs = list(arango_docs)  # Materialize so we can index and iterate twice
    print_blue(f"Starting process_batch for {len(arango_docs)} docs...")

    # Extract the date from the first doc (all docs in a batch share the same date)
    date = arango_docs[0].get("datum") if arango_docs else None
    if date:
        # Assign debate ids to all docs in this batch
        arango_docs = assign_debate_ids(arango_docs, date)

    # Fetch all existing chunk IDs for this batch at once. They are needed both for
    # skipping already-indexed chunks and for delete-before-replace.
    _ids = [row.get("_id") for row in arango_docs]
    existing_ids = get_existing_chunk_ids(collection, _ids)

    ids: List[str] = []
    documents: List[str] = []
    metadatas: List[Mapping[str, Any]] = []
    chunks_generated = 0
    updated_docs = []
    delete_ids = []

    for doc in arango_docs:  # All talks for this date
        _id = doc.get("_id")
        text_body = (doc.get("anforandetext") or "").strip()
        if not text_body:
            print_yellow(f"Skipping empty talk {_id}")
            continue

        # Decide whether to use existing chunks or regenerate
        arango_chunks: list[dict] = []
        should_update_arango = False

        if use_existing_chunks and doc.get("chunks"):
            # Use existing chunks from ArangoDB
            arango_chunks = doc["chunks"]
        else:
            # Generate new chunks from the text
            text_chunks = TextChunker(chunk_limit=500).chunk(text_body)
            for chunk_index, content in enumerate(text_chunks):
                if not content or not content.strip():
                    print_yellow(
                        f"Skipping empty chunk for talk {_id} index {chunk_index}"
                    )
                    continue
                chunk_id = f"{_id}:{chunk_index}"
                chunk_doc = {
                    "text": content,
                    "index": chunk_index,
                    "chroma_id": chunk_id,
                    "chroma_collecton": CHROMA_COLLECTION,
                    "debate": doc.get("debate"),
                }
                arango_chunks.append(chunk_doc)
            # Mark that we need to update ArangoDB with the new chunks
            should_update_arango = True

        # Process each chunk for ChromaDB
        for chunk in arango_chunks:
            chunk_id = chunk.get("chroma_id")
            content = chunk.get("text")

            # Skip if the chunk already exists and we are not replacing
            if not replace_chromadb and chunk_id in existing_ids:
                continue

            # If replacing and the chunk exists, mark it for deletion first
            if replace_chromadb and chunk_id in existing_ids:
                delete_ids.append(chunk_id)

            ids.append(chunk_id)
            documents.append(content)
            metadatas.append(
                {
                    "_id": _id,
                    "chunk_index": chunk.get("index"),
                    "debate": doc.get("debate"),
                }
            )
            chunks_generated += 1

        # Only update ArangoDB if we generated new chunks
        if should_update_arango:
            updated_docs.append(
                {
                    "_key": doc["_key"],
                    "chunks": arango_chunks,
                    "debate": doc.get("debate"),
                }
            )

    # Add to ChromaDB 200 at a time to avoid using too much memory
    if not ids:
        print_red("No new chunks to add in this batch.")
    else:
        print_green(f"Adding {chunks_generated} chunks to ChromaDB {collection.name}...")
        for i in range(0, len(ids), 200):
            batch_ids = ids[i : i + 200]
            batch_docs = documents[i : i + 200]
            batch_metas = metadatas[i : i + 200]

            # Delete old entries first if replacing
            if delete_ids:
                batch_delete = delete_ids[:200]
                delete_ids = delete_ids[200:]
                collection.delete(ids=batch_delete)

            collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metas,
            )

    # Helper to update ArangoDB in smaller batches to avoid HTTP 413 errors
    def update_arango_in_batches(docs: list[dict], batch_size: int = 100) -> None:
        """
        Update ArangoDB in batches to avoid 'Payload Too Large' errors.

        Args:
            docs: List of documents to update.
            batch_size: Number of documents per batch.
        """
        for i in range(0, len(docs), batch_size):
            batch = docs[i : i + batch_size]
            arango.db.collection("talks").update_many(batch, merge=False)

    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def safe_update():
        if updated_docs:
            # Use the batch update helper instead of one large update
            update_arango_in_batches(updated_docs, batch_size=100)

    safe_update()

    return chunks_generated
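

# For reference, a chunk document written back to ArangoDB by process_batch has
# roughly this shape (values are illustrative, not real data):
#
#   {
#       "text": "<chunk text>",
#       "index": 0,
#       "chroma_id": "talks/12345:0",
#       "chroma_collecton": "talks",
#       "debate": "2020-01-01:1",
#   }
#
# The matching ChromaDB entry uses "chroma_id" as its id and stores
# {"_id", "chunk_index", "debate"} as metadata, so chunks can later be filtered
# by talk or by debate.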
""" for i in range(0, len(docs), batch_size): batch = docs[i : i + batch_size] arango.db.collection("talks").update_many(batch, merge=False) @backoff.on_exception(backoff.expo, Exception, max_tries=5) def safe_update(): if updated_docs: # Call the batch update helper instead of a single large update update_arango_in_batches(updated_docs, batch_size=100) safe_update() return chunks_generated def process_date( date: str, use_existing_chunks: bool, replace_chromadb: bool, use_local_imports: bool = False ) -> int: """ Worker function to process all talks for a given date. If use_local_imports is True, re-import arango and chroma_db locally (for multiprocessing). If False, uses the top-level imports (for sequential mode). Returns the number of chunks processed for this date. Args: date (str): The date to process. use_existing_chunks (bool): Whether to use existing chunks from ArangoDB. replace_chromadb (bool): Whether to replace existing ChromaDB entries. use_local_imports (bool): If True, re-import client objects for multiprocessing. Returns: int: Number of chunks processed for this date. """ import time if use_local_imports: # For multiprocessing: re-import client objects in each worker process. print_yellow(f"[{date}] Worker starting: re-importing DB clients...") from arango_client import arango as local_arango from _chromadb.chroma_client import chroma_db as local_chroma_db db = local_arango chroma = local_chroma_db time.sleep(random() * 2) # Stagger start times slightly else: # Use top-level imports for sequential mode db = arango chroma = chroma_db print_blue(f"[{date}] Processing date...") collection: Collection = chroma.get_collection(name=CHROMA_COLLECTION) query = """ FOR doc IN talks FILTER doc.datum == @date SORT doc.anforande_nummer RETURN { _id: doc._id, _key: doc._key, anforandetext: doc.anforandetext, anforande_nummer: doc.anforande_nummer, replik: doc.replik, datum: doc.datum, chunks: doc.chunks } """ cursor = db.db.aql.execute(query, bind_vars={"date": date}) docs = list(cursor) print_purple(f"[{date}] Processing {len(docs)} talks...") processed = process_batch( docs, collection, use_existing_chunks=use_existing_chunks, replace_chromadb=replace_chromadb, ) print_green(f"[{date}] Worker finished: processed {processed} chunks.") return processed def main() -> None: """ Main function to build or update the talk_embeddings table with chunked embeddings. Uses multiprocessing to process each date in parallel, unless --no-multiprocessing is set. """ parser = argparse.ArgumentParser( description="Build or update talk_embeddings with chunked embeddings.", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Generate fresh chunks and add only new ones to ChromaDB: python -m scripts.build_embeddings # Use existing chunks from ArangoDB and replace any duplicates in ChromaDB: python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb # Generate fresh chunks and replace everything in ChromaDB: python -m scripts.build_embeddings --replace-chromadb """ ) parser.add_argument( "--use-existing-chunks", action="store_true", help="Use chunks already stored in ArangoDB documents instead of regenerating them. " "Must be combined with --replace-chromadb.", ) parser.add_argument( "--replace-chromadb", action="store_true", help="Replace existing ChromaDB entries with the same ID. 
" "If not set, skip documents that already exist in ChromaDB.", ) parser.add_argument( "--no-multiprocessing", action="store_true", help="Run sequentially without multiprocessing.", ) args = parser.parse_args() # Validate argument combination # if args.use_existing_chunks and not args.replace_chromadb: # parser.error("--use-existing-chunks requires --replace-chromadb to be set") print_purple("Connecting to ChromaDB...") collection: Collection = chroma_db.get_collection(name=CHROMA_COLLECTION) docs_in_chroma = set(collection.get(include=[])['ids']) print_green(f"ChromaDB {collection.name} has {len(docs_in_chroma)} existing chunks.") # FILTER doc.chunks != null aql = """ FOR doc IN talks FILTER doc.chunks == null return {'datum': doc.datum, 'chroma_ids': doc.chunks[*].chroma_id} """ cursor = arango.db.aql.execute(aql, ttl=3600, count=True, batch_size=1000) all_dates = [] processed_documents = 0 while True: # Processera allt i klientbufferten while not cursor.empty(): doc = cursor.pop() # pop tar från client-side buffer utan att trigga fetch processed_documents += 1 chroma_ids = set(doc.get('chroma_ids', [])) if not chroma_ids.issubset(docs_in_chroma): all_dates.append(doc['datum']) elif 'chunks' not in doc: all_dates.append(doc['datum']) # Om server säger att det finns mer, hämta nästa batch från server if cursor.has_more(): cursor.fetch() # hämtar nästa server-batch och fyller client-buffer print(f"Fetched {processed_documents} documents so far; total reported: {cursor.count()}", end="\r") continue # Inget mer att hämta => sluta break all_dates = list(set(all_dates)) # Unika datum print_green(f"Found {len(all_dates)} unique dates to process.") # Set number of workers (processes) to 2 num_workers = 2 if args.no_multiprocessing: print_purple("Running in sequential (no multiprocessing) mode...") results = [] # Sort dates to process in chronological order all_dates.sort() for date in all_dates: # Use top-level imports in sequential mode result = process_date( date, use_existing_chunks=args.use_existing_chunks, replace_chromadb=args.replace_chromadb, use_local_imports=False ) results.append(result) else: print_purple(f"Starting multiprocessing with {num_workers} workers...") import multiprocessing # Use 'spawn' start method to avoid issues with forked DB/network connections ctx = multiprocessing.get_context("spawn") with ctx.Pool(processes=num_workers) as pool: print_yellow("Processing dates in parallel...") # Use local imports in multiprocessing mode results = pool.starmap( process_date, [ (date, args.use_existing_chunks, args.replace_chromadb, True) for date in all_dates ] ) total_chunks = sum(results) print_green(f"Färdig! Totalt {total_chunks} chunks bearbetade.") if __name__ == "__main__": main() print_green("Done!")