"""
Skapar eller uppdaterar talk_embeddings-tabellen med chunkade embeddings.
Kör: python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb
"""
import os
import numpy as np
import sys
import argparse
from pathlib import Path
from typing import Iterable, Mapping, List, Any
from chromadb import Collection
from time import sleep
from random import random
import backoff
# Set /home/lasse/riksdagen as working directory
os.chdir("/home/lasse/riksdagen")
# Add the project root to Python path to locate local modules
sys.path.append("/home/lasse/riksdagen")
from arango_client import arango
from utils import TextChunker
from _chromadb.chroma_client import chroma_db
from colorprinter import print_green, print_red, print_yellow, print_purple, print_blue
CHROMA_COLLECTION = "talks"
CHROMA_PATH = "/home/lasse/riksdagen/chromadb_data"
def get_existing_chunk_ids(collection, _ids: List[str]) -> set:
    """
    Fetch all existing chunk IDs for the given talk IDs in one query.
    Returns a set of chunk_id strings for fast O(1) lookup.
    """
    try:
        # Get all documents where metadata._id is in our batch
        # ChromaDB might return an error if the collection is empty, so handle that
        result = collection.get(
            where={"_id": {"$in": _ids}},
            include=[],  # We only need IDs, not documents/embeddings
        )
        return set(result.get("ids", []))
    except Exception as e:
        print_red(f"Warning: Could not fetch existing IDs: {e}")
        return set()
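def _example_get_existing_chunk_ids(collection: Collection) -> None:
    # Illustration only (hypothetical talk IDs, never called by the script): look up
    # which chunks of two talks are already embedded so they can be skipped or replaced.
    existing = get_existing_chunk_ids(collection, ["talks/1234", "talks/5678"])
    if "talks/1234:0" in existing:
        print_blue("Chunk talks/1234:0 is already in ChromaDB")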
def get_all_dates() -> List[str]:
    """
    Return a sorted list of unique 'datum' values found in the 'talks' collection.
    This uses AQL COLLECT to produce unique dates so we iterate date-by-date later.
    """
    query = """
    FOR doc IN talks
        COLLECT datum = doc.datum
        RETURN datum
    """
    from arango_client import arango
    cursor = arango.db.aql.execute(query)
    return list(cursor)
def embedding_exists(collection, chunk_id: str) -> bool:
    """
    Check if an embedding with the given chunk_id already exists in the collection.
    """
    res = collection.get(ids=[chunk_id])
    return bool(res and res.get("ids"))
def assign_debate_ids(
    docs: List[Mapping[str, Any]], date: str
) -> List[Mapping[str, Any]]:
    """
    Assigns a debate id to each talk in a list of talks for a given date.
    A new debate starts when replik == False.
    The debate id is a string: "{date}:{debate_index}".
    Returns a new list of docs with a 'debate' field added.
    """
    debate_index = 0
    current_debate_id = f"{date}:{debate_index}"
    updated_docs = []
    for doc in docs:
        # If this talk is not a reply, start a new debate
        if not doc.get("replik", False):
            debate_index += 1
            current_debate_id = f"{date}:{debate_index}"
        # Add the debate id to the doc
        doc_with_debate = dict(doc)
        doc_with_debate["debate"] = current_debate_id
        updated_docs.append(doc_with_debate)
    return updated_docs
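def _example_assign_debate_ids() -> None:
    # Illustration only (hypothetical minimal docs, never called by the script): with
    # three talks where only the second is a reply, the first two share debate
    # "2023-01-18:1" and the third starts a new debate "2023-01-18:2".
    docs = [{"replik": False}, {"replik": True}, {"replik": False}]
    tagged = assign_debate_ids(docs, "2023-01-18")
    assert [d["debate"] for d in tagged] == [
        "2023-01-18:1",
        "2023-01-18:1",
        "2023-01-18:2",
    ]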
def process_batch(
    arango_docs: Iterable[Mapping[str, object]],
    collection: Collection,
    use_existing_chunks: bool = False,
    replace_chromadb: bool = False,
) -> int:
    """
    Processes a batch of ArangoDB documents (talks), chunks their text, and adds embeddings.
    Each talk is assigned a debate id.
    Ensures that only unique chunk IDs are added to ChromaDB to avoid DuplicateIDError.

    Args:
        arango_docs: Iterable of ArangoDB documents (talks).
        collection: ChromaDB collection object.
        use_existing_chunks: If True, use chunks stored in ArangoDB docs (if available).
            If False, always regenerate chunks from text.
        replace_chromadb: If True, replace existing ChromaDB entries with the same ID.
            If False, skip documents that already exist in ChromaDB.
            Note: Must be True when use_existing_chunks is True.

    Returns:
        int: Number of chunks generated and added.
    """
    arango_docs = list(arango_docs)  # Materialize first so len() works and we can iterate twice
    print_blue(f"Starting process_batch for {len(arango_docs)} docs...")
    # Extract date from the first doc (all docs in batch have the same date)
    date = arango_docs[0].get("datum") if arango_docs else None
    if date:
        # Assign debate ids to all docs in this batch
        arango_docs = assign_debate_ids(arango_docs, date)
    # Fetch all existing chunk IDs for this batch at once: used to skip duplicates
    # when not replacing, and to know which entries to delete first when replacing.
    _ids = [row.get("_id") for row in arango_docs]
    existing_ids = get_existing_chunk_ids(collection, _ids)
    ids: List[str] = []
    documents: List[str] = []
    metadatas: List[Mapping[str, Any]] = []
    chunks_generated = 0
    updated_docs = []
    delete_ids = []
    for doc in arango_docs:  # All talks for this date
        _id = doc.get("_id")
        text_body = (doc.get("anforandetext") or "").strip()
        if not text_body:
            print_yellow(f"Skipping empty talk {_id}")
            continue
        # Decide whether to use existing chunks or regenerate
        arango_chunks: list[dict] = []
        should_update_arango = False
        if use_existing_chunks and doc.get("chunks"):
            # Use existing chunks from ArangoDB
            arango_chunks = doc["chunks"]
        else:
            # Generate new chunks from text
            text_chunks = TextChunker(chunk_limit=500).chunk(text_body)
            for chunk_index, content in enumerate(text_chunks):
                if not content or not content.strip():
                    print_yellow(
                        f"Skipping empty chunk for talk {_id} index {chunk_index}"
                    )
                    continue
                chunk_id = f"{_id}:{chunk_index}"
                chunk_doc = {
                    "text": content,
                    "index": chunk_index,
                    "chroma_id": chunk_id,
                    "chroma_collection": CHROMA_COLLECTION,
                    "debate": doc.get("debate"),
                }
                arango_chunks.append(chunk_doc)
            # Mark that we need to update ArangoDB with new chunks
            should_update_arango = True
        # Process each chunk for ChromaDB
        for chunk in arango_chunks:
            chunk_id = chunk.get("chroma_id")
            content = chunk.get("text")
            # Skip if already exists and we're not replacing
            if not replace_chromadb and chunk_id in existing_ids:
                continue
            # If replacing and exists, mark for deletion first
            if replace_chromadb and chunk_id in existing_ids:
                delete_ids.append(chunk_id)
            ids.append(chunk_id)
            documents.append(content)
            metadatas.append(
                {
                    "_id": _id,
                    "chunk_index": chunk.get("index"),
                    "debate": doc.get("debate"),
                }
            )
            chunks_generated += 1
        # Only update ArangoDB if we generated new chunks
        if should_update_arango:
            updated_docs.append(
                {
                    "_key": doc["_key"],
                    "chunks": arango_chunks,
                    "debate": doc.get("debate"),
                }
            )
    # Add to chroma 200 at a time to avoid using too much memory
    print_green(f"Adding {chunks_generated} chunks to ChromaDB {collection.name}...")
    for i in range(0, len(ids), 200):
        batch_ids = ids[i : i + 200]
        batch_docs = documents[i : i + 200]
        batch_metas = metadatas[i : i + 200]
        # Delete old entries if replacing
        if delete_ids:
            batch_delete = delete_ids[:200]
            delete_ids = delete_ids[200:]
            collection.delete(ids=batch_delete)
        if batch_ids:
            collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metas,
            )
        else:
            print_red("No new chunks to add in this batch.")

    # Helper function to update ArangoDB in smaller batches to avoid HTTP 413 errors
    def update_arango_in_batches(docs: list[dict], batch_size: int = 100) -> None:
        """
        Update ArangoDB in batches to avoid 'Payload Too Large' errors.

        Args:
            docs: List of documents to update.
            batch_size: Number of documents per batch.
        """
        for i in range(0, len(docs), batch_size):
            batch = docs[i : i + batch_size]
            arango.db.collection("talks").update_many(batch, merge=False)

    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def safe_update():
        if updated_docs:
            # Call the batch update helper instead of a single large update
            update_arango_in_batches(updated_docs, batch_size=100)

    safe_update()
    return chunks_generated
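def _example_process_batch(collection: Collection) -> None:
    # Illustration only (hypothetical document, never called by the script): a minimal
    # direct call to process_batch; real runs go through process_date() below, which
    # pulls full talk documents from ArangoDB. Note that actually running this would
    # write the example chunk into ChromaDB and attempt an ArangoDB update.
    docs = [
        {
            "_id": "talks/1",
            "_key": "1",
            "datum": "2023-01-18",
            "anforandetext": "Herr talman! ...",
            "replik": False,
            "chunks": None,
        }
    ]
    process_batch(docs, collection, use_existing_chunks=False, replace_chromadb=True)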
def process_date(
    date: str,
    use_existing_chunks: bool,
    replace_chromadb: bool,
    use_local_imports: bool = False,
) -> int:
    """
    Worker function to process all talks for a given date.
    If use_local_imports is True, re-import arango and chroma_db locally (for multiprocessing).
    If False, use the top-level imports (for sequential mode).

    Args:
        date (str): The date to process.
        use_existing_chunks (bool): Whether to use existing chunks from ArangoDB.
        replace_chromadb (bool): Whether to replace existing ChromaDB entries.
        use_local_imports (bool): If True, re-import client objects for multiprocessing.

    Returns:
        int: Number of chunks processed for this date.
    """
    import time

    if use_local_imports:
        # For multiprocessing: re-import client objects in each worker process.
        print_yellow(f"[{date}] Worker starting: re-importing DB clients...")
        from arango_client import arango as local_arango
        from _chromadb.chroma_client import chroma_db as local_chroma_db

        db = local_arango
        chroma = local_chroma_db
        time.sleep(random() * 2)  # Stagger start times slightly
    else:
        # Use top-level imports for sequential mode
        db = arango
        chroma = chroma_db
    print_blue(f"[{date}] Processing date...")
    collection: Collection = chroma.get_collection(name=CHROMA_COLLECTION)
    query = """
    FOR doc IN talks
        FILTER doc.datum == @date
        SORT doc.anforande_nummer
        RETURN {
            _id: doc._id,
            _key: doc._key,
            anforandetext: doc.anforandetext,
            anforande_nummer: doc.anforande_nummer,
            replik: doc.replik,
            datum: doc.datum,
            chunks: doc.chunks
        }
    """
    cursor = db.db.aql.execute(query, bind_vars={"date": date})
    docs = list(cursor)
    print_purple(f"[{date}] Processing {len(docs)} talks...")
    processed = process_batch(
        docs,
        collection,
        use_existing_chunks=use_existing_chunks,
        replace_chromadb=replace_chromadb,
    )
    print_green(f"[{date}] Worker finished: processed {processed} chunks.")
    return processed
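def _example_process_date() -> int:
    # Illustration only (hypothetical date, never called by the script): process a
    # single day sequentially, bypassing the multiprocessing pool; handy when
    # debugging chunking for one sitting.
    return process_date(
        "2023-01-18",
        use_existing_chunks=False,
        replace_chromadb=False,
        use_local_imports=False,
    )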
def main() -> None:
    """
    Main function to build or update the talk_embeddings table with chunked embeddings.
    Uses multiprocessing to process each date in parallel, unless --no-multiprocessing is set.
    """
    parser = argparse.ArgumentParser(
        description="Build or update talk_embeddings with chunked embeddings.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate fresh chunks and add only new ones to ChromaDB:
  python -m scripts.build_embeddings

  # Use existing chunks from ArangoDB and replace any duplicates in ChromaDB:
  python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb

  # Generate fresh chunks and replace everything in ChromaDB:
  python -m scripts.build_embeddings --replace-chromadb
""",
    )
    parser.add_argument(
        "--use-existing-chunks",
        action="store_true",
        help="Use chunks already stored in ArangoDB documents instead of regenerating them. "
        "Must be combined with --replace-chromadb.",
    )
    parser.add_argument(
        "--replace-chromadb",
        action="store_true",
        help="Replace existing ChromaDB entries with the same ID. "
        "If not set, skip documents that already exist in ChromaDB.",
    )
    parser.add_argument(
        "--no-multiprocessing",
        action="store_true",
        help="Run sequentially without multiprocessing.",
    )
    args = parser.parse_args()
    # Validate argument combination
    # if args.use_existing_chunks and not args.replace_chromadb:
    #     parser.error("--use-existing-chunks requires --replace-chromadb to be set")
    print_purple("Connecting to ChromaDB...")
    collection: Collection = chroma_db.get_collection(name=CHROMA_COLLECTION)
    docs_in_chroma = set(collection.get(include=[])['ids'])
    print_green(f"ChromaDB {collection.name} has {len(docs_in_chroma)} existing chunks.")
    # FILTER doc.chunks != null
    aql = """
    FOR doc IN talks
        FILTER doc.chunks == null
        RETURN {'datum': doc.datum, 'chroma_ids': doc.chunks[*].chroma_id}
    """
    cursor = arango.db.aql.execute(aql, ttl=3600, count=True, batch_size=1000)
    all_dates = []
    processed_documents = 0
    while True:
        # Process everything currently in the client-side buffer
        while not cursor.empty():
            doc = cursor.pop()  # pop reads from the client-side buffer without triggering a fetch
            processed_documents += 1
            chroma_ids = set(doc.get('chroma_ids') or [])  # chroma_ids can be null when chunks is null
            if not chroma_ids.issubset(docs_in_chroma):
                all_dates.append(doc['datum'])
            elif 'chunks' not in doc:
                all_dates.append(doc['datum'])
        # If the server reports more results, fetch the next batch
        if cursor.has_more():
            cursor.fetch()  # fetches the next server batch and refills the client buffer
            print(f"Fetched {processed_documents} documents so far; total reported: {cursor.count()}", end="\r")
            continue
        # Nothing more to fetch => stop
        break
    all_dates = list(set(all_dates))  # Unique dates
    print_green(f"Found {len(all_dates)} unique dates to process.")
    # Set number of workers (processes) to 2
    num_workers = 2
    if args.no_multiprocessing:
        print_purple("Running in sequential (no multiprocessing) mode...")
        results = []
        # Sort dates to process in chronological order
        all_dates.sort()
        for date in all_dates:
            # Use top-level imports in sequential mode
            result = process_date(
                date,
                use_existing_chunks=args.use_existing_chunks,
                replace_chromadb=args.replace_chromadb,
                use_local_imports=False,
            )
            results.append(result)
    else:
        print_purple(f"Starting multiprocessing with {num_workers} workers...")
        import multiprocessing

        # Use 'spawn' start method to avoid issues with forked DB/network connections
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(processes=num_workers) as pool:
            print_yellow("Processing dates in parallel...")
            # Use local imports in multiprocessing mode
            results = pool.starmap(
                process_date,
                [
                    (date, args.use_existing_chunks, args.replace_chromadb, True)
                    for date in all_dates
                ],
            )
    total_chunks = sum(results)
    print_green(f"Finished! {total_chunks} chunks processed in total.")
if __name__ == "__main__":
    main()
    print_green("Done!")