"""
Creates or updates the talk_embeddings table with chunked embeddings.

Run: python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb
"""

import os
import sys
import argparse
from typing import Iterable, Mapping, List, Any
from random import random

import backoff
from chromadb import Collection

# Set /home/lasse/riksdagen as working directory
os.chdir("/home/lasse/riksdagen")
# Add the project root to Python path to locate local modules
sys.path.append("/home/lasse/riksdagen")

from arango_client import arango
from utils import TextChunker
from _chromadb.chroma_client import chroma_db
from colorprinter import print_green, print_red, print_yellow, print_purple, print_blue

CHROMA_COLLECTION = "talks"
CHROMA_PATH = "/home/lasse/riksdagen/chromadb_data"


def get_existing_chunk_ids(collection, _ids: List[str]) -> set:
    """
    Fetch all existing chunk IDs for the given talk IDs in one query.
    Returns a set of chunk_id strings for fast O(1) lookup.
    """
    try:
        # Get all documents whose metadata "_id" is in our batch.
        # ChromaDB might return an error if the collection is empty, so handle that.
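        # Note: "_id" in the where-filter refers to the metadata field that stores the
        # ArangoDB document id, so a single call covers every talk in the batch.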
        result = collection.get(
            where={"_id": {"$in": _ids}},
            include=[],  # We only need IDs, not documents/embeddings
        )
        return set(result.get("ids", []))
    except Exception as e:
        print_red(f"Warning: Could not fetch existing IDs: {e}")
        return set()


def get_all_dates() -> List[str]:
    """
    Return a sorted list of unique 'datum' values found in the 'talks' collection.

    This uses AQL COLLECT to produce unique dates so we can iterate date-by-date later.
    """
    query = """
    FOR doc IN talks
        COLLECT datum = doc.datum
        RETURN datum
    """
    cursor = arango.db.aql.execute(query)
    return list(cursor)


def embedding_exists(collection, chunk_id: str) -> bool:
    """
    Check if an embedding with the given chunk_id already exists in the collection.
    """
    res = collection.get(ids=[chunk_id])
    return bool(res and res.get("ids"))


def assign_debate_ids(
    docs: List[Mapping[str, Any]], date: str
) -> List[Mapping[str, Any]]:
    """
    Assigns a debate id to each talk in a list of talks for a given date.
    A new debate starts when replik == False.
    The debate id is a string: "{date}:{debate_index}".
    Returns a new list of docs with a 'debate' field added.
    """
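    # Example (hypothetical data): for date "2024-01-15", four talks with replik flags
    # [False, True, True, False] are assigned the debate ids
    # ["2024-01-15:1", "2024-01-15:1", "2024-01-15:1", "2024-01-15:2"].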
    debate_index = 0
    current_debate_id = f"{date}:{debate_index}"
    updated_docs = []
    for doc in docs:
        # If this talk is not a reply, start a new debate
        if not doc.get("replik", False):
            debate_index += 1
            current_debate_id = f"{date}:{debate_index}"
        # Add the debate id to the doc
        doc_with_debate = dict(doc)
        doc_with_debate["debate"] = current_debate_id
        updated_docs.append(doc_with_debate)

    return updated_docs


def process_batch(
    arango_docs: Iterable[Mapping[str, object]],
    collection: Collection,
    use_existing_chunks: bool = False,
    replace_chromadb: bool = False,
) -> int:
    """
    Processes a batch of ArangoDB documents (talks), chunks their text, and adds embeddings.
    Each talk is assigned a debate id.
    Ensures that only unique chunk IDs are added to ChromaDB to avoid DuplicateIDError.

    Args:
        arango_docs: Iterable of ArangoDB documents (talks).
        collection: ChromaDB collection object.
        use_existing_chunks: If True, use chunks stored in ArangoDB docs (if available).
            If False, always regenerate chunks from text.
        replace_chromadb: If True, replace existing ChromaDB entries with the same ID.
            If False, skip documents that already exist in ChromaDB.
            Note: Must be True when use_existing_chunks is True.

    Returns:
        int: Number of chunks generated and added.
    """
    arango_docs = list(arango_docs)  # Materialize first so len() works and we can iterate twice
    print_blue(f"Starting process_batch for {len(arango_docs)} docs...")

    # Extract date from the first doc (all docs in a batch have the same date)
    date = arango_docs[0].get("datum") if arango_docs else None
    if date:
        # Assign debate ids to all docs in this batch
        arango_docs = assign_debate_ids(arango_docs, date)

    # Fetch all existing chunk IDs for this batch at once (unless we're replacing everything)
    existing_ids = set()
    if not replace_chromadb:
        _ids = [row.get("_id") for row in arango_docs]
        existing_ids = get_existing_chunk_ids(collection, _ids)

    ids: List[str] = []
    documents: List[str] = []
    metadatas: List[Mapping[str, Any]] = []

    chunks_generated = 0
    updated_docs = []
    delete_ids = []

    for doc in arango_docs:  # All talks for this date
        _id = doc.get("_id")
        text_body = (doc.get("anforandetext") or "").strip()
        if not text_body:
            print_yellow(f"Skipping empty talk {_id}")
            continue

        # Decide whether to use existing chunks or regenerate
        arango_chunks: list[dict] = []
        should_update_arango = False

        if use_existing_chunks and doc.get("chunks"):
            # Use existing chunks from ArangoDB
            arango_chunks = doc["chunks"]
        else:
            # Generate new chunks from text
            text_chunks = TextChunker(chunk_limit=500).chunk(text_body)
            for chunk_index, content in enumerate(text_chunks):
                if not content or not content.strip():
                    print_yellow(
                        f"Skipping empty chunk for talk {_id} index {chunk_index}"
                    )
                    continue
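                # Chunk ids follow the pattern "<ArangoDB _id>:<chunk index>", so re-running
                # the script produces the same ids and duplicates can be detected or replaced.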
                chunk_id = f"{_id}:{chunk_index}"
                chunk_doc = {
                    "text": content,
                    "index": chunk_index,
                    "chroma_id": chunk_id,
                    "chroma_collection": CHROMA_COLLECTION,
                    "debate": doc.get("debate"),
                }
                arango_chunks.append(chunk_doc)

            # Mark that we need to update ArangoDB with new chunks
            should_update_arango = True

        # Process each chunk for ChromaDB
        for chunk in arango_chunks:
            chunk_id = chunk.get("chroma_id")
            content = chunk.get("text")

            # Skip if it already exists and we're not replacing
            if not replace_chromadb and chunk_id in existing_ids:
                continue

            # If replacing, mark the id for deletion first
            # (deleting a non-existent id is a no-op, so no lookup is needed here)
            if replace_chromadb:
                delete_ids.append(chunk_id)

            ids.append(chunk_id)
            documents.append(content)
            metadatas.append(
                {
                    "_id": _id,
                    "chunk_index": chunk.get("index"),
                    "debate": doc.get("debate"),
                }
            )
            chunks_generated += 1

        # Only update ArangoDB if we generated new chunks
        if should_update_arango:
            updated_docs.append(
                {
                    "_key": doc["_key"],
                    "chunks": arango_chunks,
                    "debate": doc.get("debate"),
                }
            )

    # Add to ChromaDB 200 chunks at a time to avoid using too much memory
    print_green(f"Adding {chunks_generated} chunks to ChromaDB {collection.name}...")
    for i in range(0, len(ids), 200):
        batch_ids = ids[i : i + 200]
        batch_docs = documents[i : i + 200]
        batch_metas = metadatas[i : i + 200]

        # Delete old entries if replacing
        if delete_ids:
            batch_delete = delete_ids[:200]
            delete_ids = delete_ids[200:]
            collection.delete(ids=batch_delete)

        if batch_ids:
            collection.add(
                ids=batch_ids,
                documents=batch_docs,
                metadatas=batch_metas,
            )
        else:
            print_red("No new chunks to add in this batch.")

    # Helper function to update ArangoDB in smaller batches to avoid HTTP 413 errors
    def update_arango_in_batches(docs: list[dict], batch_size: int = 100) -> None:
        """
        Update ArangoDB in batches to avoid 'Payload Too Large' errors.

        Args:
            docs: List of documents to update.
            batch_size: Number of documents per batch.
        """
        for i in range(0, len(docs), batch_size):
            batch = docs[i : i + batch_size]
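            # merge=False makes the new 'chunks' value overwrite the stored field rather
            # than being merged into any existing sub-documents.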
            arango.db.collection("talks").update_many(batch, merge=False)

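    # Retry the update with exponential backoff (backoff.expo) so transient ArangoDB or
    # network errors do not abort the whole batch; gives up after five attempts.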
    @backoff.on_exception(backoff.expo, Exception, max_tries=5)
    def safe_update():
        if updated_docs:
            # Call the batch update helper instead of a single large update
            update_arango_in_batches(updated_docs, batch_size=100)

    safe_update()
    return chunks_generated


def process_date(
    date: str,
    use_existing_chunks: bool,
    replace_chromadb: bool,
    use_local_imports: bool = False,
) -> int:
    """
    Worker function to process all talks for a given date.
    If use_local_imports is True, re-import arango and chroma_db locally (for multiprocessing).
    If False, uses the top-level imports (for sequential mode).
    Returns the number of chunks processed for this date.

    Args:
        date (str): The date to process.
        use_existing_chunks (bool): Whether to use existing chunks from ArangoDB.
        replace_chromadb (bool): Whether to replace existing ChromaDB entries.
        use_local_imports (bool): If True, re-import client objects for multiprocessing.

    Returns:
        int: Number of chunks processed for this date.
    """
    import time

    if use_local_imports:
        # For multiprocessing: re-import client objects in each worker process.
        print_yellow(f"[{date}] Worker starting: re-importing DB clients...")
        from arango_client import arango as local_arango
        from _chromadb.chroma_client import chroma_db as local_chroma_db

        db = local_arango
        chroma = local_chroma_db
        time.sleep(random() * 2)  # Stagger start times slightly
    else:
        # Use top-level imports for sequential mode
        db = arango
        chroma = chroma_db

    print_blue(f"[{date}] Processing date...")
    collection: Collection = chroma.get_collection(name=CHROMA_COLLECTION)

    query = """
    FOR doc IN talks
        FILTER doc.datum == @date
        SORT doc.anforande_nummer
        RETURN {
            _id: doc._id,
            _key: doc._key,
            anforandetext: doc.anforandetext,
            anforande_nummer: doc.anforande_nummer,
            replik: doc.replik,
            datum: doc.datum,
            chunks: doc.chunks
        }
    """
    cursor = db.db.aql.execute(query, bind_vars={"date": date})
    docs = list(cursor)
    print_purple(f"[{date}] Processing {len(docs)} talks...")
    processed = process_batch(
        docs,
        collection,
        use_existing_chunks=use_existing_chunks,
        replace_chromadb=replace_chromadb,
    )
    print_green(f"[{date}] Worker finished: processed {processed} chunks.")
    return processed


def main() -> None:
    """
    Main function to build or update the talk_embeddings table with chunked embeddings.
    Uses multiprocessing to process each date in parallel, unless --no-multiprocessing is set.
    """

    parser = argparse.ArgumentParser(
        description="Build or update talk_embeddings with chunked embeddings.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate fresh chunks and add only new ones to ChromaDB:
  python -m scripts.build_embeddings

  # Use existing chunks from ArangoDB and replace any duplicates in ChromaDB:
  python -m scripts.build_embeddings --use-existing-chunks --replace-chromadb

  # Generate fresh chunks and replace everything in ChromaDB:
  python -m scripts.build_embeddings --replace-chromadb
        """,
    )
    parser.add_argument(
        "--use-existing-chunks",
        action="store_true",
        help="Use chunks already stored in ArangoDB documents instead of regenerating them. "
        "Must be combined with --replace-chromadb.",
    )
    parser.add_argument(
        "--replace-chromadb",
        action="store_true",
        help="Replace existing ChromaDB entries with the same ID. "
        "If not set, skip documents that already exist in ChromaDB.",
    )
    parser.add_argument(
        "--no-multiprocessing",
        action="store_true",
        help="Run sequentially without multiprocessing.",
    )
    args = parser.parse_args()

    # Validate argument combination
    # if args.use_existing_chunks and not args.replace_chromadb:
    #     parser.error("--use-existing-chunks requires --replace-chromadb to be set")

    print_purple("Connecting to ChromaDB...")
    collection: Collection = chroma_db.get_collection(name=CHROMA_COLLECTION)
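    # include=[] returns only the ids (no documents or embeddings), which keeps this
    # lookup cheap even for a large collection.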
    docs_in_chroma = set(collection.get(include=[])['ids'])
    print_green(f"ChromaDB {collection.name} has {len(docs_in_chroma)} existing chunks.")

    # FILTER doc.chunks != null
    aql = """
    FOR doc IN talks
        FILTER doc.chunks == null
        RETURN {'datum': doc.datum, 'chroma_ids': doc.chunks[*].chroma_id}
    """

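    # Stream results with a server-side cursor: ttl=3600 keeps the cursor alive for the
    # whole scan, count=True exposes the total for progress reporting, and batch_size=1000
    # sets how many documents each server round trip returns.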
    cursor = arango.db.aql.execute(aql, ttl=3600, count=True, batch_size=1000)
    all_dates = []
    processed_documents = 0
    while True:
        # Process everything currently in the client-side buffer
        while not cursor.empty():
            doc = cursor.pop()  # pop reads from the client-side buffer without triggering a fetch
            processed_documents += 1
            chroma_ids = set(doc.get('chroma_ids') or [])
            if not chroma_ids.issubset(docs_in_chroma):
                all_dates.append(doc['datum'])
            elif 'chunks' not in doc:
                all_dates.append(doc['datum'])

        # If the server reports more results, fetch the next batch
        if cursor.has_more():
            cursor.fetch()  # fetches the next server batch and refills the client buffer
            print(f"Fetched {processed_documents} documents so far; total reported: {cursor.count()}", end="\r")
            continue

        # Nothing more to fetch => stop
        break

    all_dates = list(set(all_dates))  # Unique dates
    print_green(f"Found {len(all_dates)} unique dates to process.")

    # Number of worker processes
    num_workers = 2

    if args.no_multiprocessing:
        print_purple("Running in sequential (no multiprocessing) mode...")
        results = []
        # Sort dates to process in chronological order
        all_dates.sort()
        for date in all_dates:
            # Use top-level imports in sequential mode
            result = process_date(
                date,
                use_existing_chunks=args.use_existing_chunks,
                replace_chromadb=args.replace_chromadb,
                use_local_imports=False,
            )
            results.append(result)
    else:
        print_purple(f"Starting multiprocessing with {num_workers} workers...")
        import multiprocessing

        # Use 'spawn' start method to avoid issues with forked DB/network connections
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(processes=num_workers) as pool:
            print_yellow("Processing dates in parallel...")
            # Use local imports in multiprocessing mode
            results = pool.starmap(
                process_date,
                [
                    (date, args.use_existing_chunks, args.replace_chromadb, True)
                    for date in all_dates
                ],
            )

    total_chunks = sum(results)
    print_green(f"Done! {total_chunks} chunks processed in total.")


if __name__ == "__main__":
    main()
    print_green("Done!")