#!/usr/bin/env python3
"""
Convert stored embeddings to plain Python lists for an existing Chroma collection.

Usage:
    # Dry run (inspect first 50 ids)
    python scripts/convert_embeddings_to_lists.py --collection talks --limit 50 --dry-run

    # Full run (no dry run)
    python scripts/convert_embeddings_to_lists.py --collection talks

Notes:
- Run from your project root (same env you use to access chroma_db).
- Back up chromadb_data before running.
"""

import argparse
import json
import os
import sys
import time
from pathlib import Path

# Use the same bootstrapping as the rest of the project so the same chroma
# client and embedding function are loaded. Adjust the paths if necessary.
os.chdir("/home/lasse/riksdagen")
sys.path.append("/home/lasse/riksdagen")

import numpy as np
from _chromadb.chroma_client import chroma_db

CHECKPOINT_DIR = Path("var/chroma_repair")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)


def normalize_embedding(emb):
    """
    Convert a single embedding to a plain Python list[float].
    Accepts numpy arrays, array-likes, lists.
    """
    # numpy ndarray
    if isinstance(emb, np.ndarray):
        return emb.tolist()
    # Some array-likes (pandas/other) may have tolist()
    if hasattr(emb, "tolist") and not isinstance(emb, list):
        try:
            return emb.tolist()
        except Exception:
            pass
    # If it's already a list of numbers, convert elements to float
    if isinstance(emb, list):
        return [float(x) for x in emb]
    # Last resort: try iterating
    try:
        return [float(x) for x in emb]
    except Exception:
        raise ValueError(f"Cannot normalize embedding of type: {type(emb)}")
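
# Illustrative behavior (not part of the script's logic; values chosen to be
# exactly representable in float32 so the output reads cleanly):
#   >>> normalize_embedding(np.array([0.5, 0.25], dtype=np.float32))
#   [0.5, 0.25]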


def chunked_iter(iterable, n):
    """Yield successive lists of up to n items from iterable."""
    it = iter(iterable)
    while True:
        chunk = []
        try:
            for _ in range(n):
                chunk.append(next(it))
        except StopIteration:
            pass
        if not chunk:
            break
        yield chunk
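
# chunked_iter is a generic batching helper; main() below slices the id list
# directly, but a streaming variant could use it instead, e.g.:
#   for batch_ids in chunked_iter(ids, args.batch):
#       ...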


def load_checkpoint(name):
    path = CHECKPOINT_DIR / f"{name}.json"
    if path.exists():
        # json.load expects a file object, not a Path, so read the text first
        return json.loads(path.read_text())
    return {"last_index": 0, "processed_ids": []}


def save_checkpoint(name, data):
    path = CHECKPOINT_DIR / f"{name}.json"
    with open(path, "w") as f:
        json.dump(data, f)
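
# Checkpoint files land in var/chroma_repair/<name>.json and look like
#   {"last_index": 12000, "processed_ids": []}
# Only last_index drives the resume; processed_ids is reserved and unused here.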


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--collection", required=True, help="Chroma collection name (e.g. talks)")
    parser.add_argument("--batch", type=int, default=1000, help="Batch size for update (default 1000)")
    parser.add_argument("--dry-run", action="store_true", help="Dry run: don't write updates, just report")
    parser.add_argument("--limit", type=int, default=None, help="Limit total number of ids to process (for testing)")
    parser.add_argument("--checkpoint-name", default=None, help="Name for checkpoint file (defaults to collection name)")
    args = parser.parse_args()

    coll_name = args.collection
    checkpoint_name = args.checkpoint_name or coll_name

    print(f"Connecting to Chroma collection '{coll_name}'...")
    col = chroma_db.get_collection(coll_name)

    # Fetch the full list of ids. For ~600k ids this is fine to hold in memory;
    # a paged col.get(limit=..., offset=...) loop would work for larger sets.
    all_info = col.get(include=[])  # include=[] fetches ids only
    ids = list(all_info.get("ids", []))
    total_ids = len(ids)
    if args.limit:
        ids = ids[: args.limit]
    total_process = len(ids)

    print(f"Found {total_ids} ids in collection; will process {total_process} ids (limit={args.limit})")

    # Load checkpoint and resume where the last run stopped
    ck = load_checkpoint(checkpoint_name)
    start_index = ck.get("last_index", 0)
    print(f"Resuming at index {start_index}")

    # Iterate in batches starting from last_index
    processed = 0
    for i in range(start_index, total_process, args.batch):
        batch_ids = ids[i : i + args.batch]
        print(f"\nProcessing batch {i}..{i + len(batch_ids) - 1} (count={len(batch_ids)})")

        # Fetch full info for this batch. Only embeddings are needed for the
        # repair; documents/metadatas are included for optional verification.
        try:
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])
        except Exception as e:
            print("Error fetching batch:", e)
            # One retry after a short sleep
            time.sleep(2)
            items = col.get(ids=batch_ids, include=["embeddings", "documents", "metadatas"])

        batch_embeddings = items.get("embeddings", [])
        # items["ids"] should come back in the requested order; warn if not
        ids_from_get = items.get("ids", batch_ids)
        if len(ids_from_get) != len(batch_ids):
            print("Warning: length mismatch between requested ids and returned ids")

        # Normalize embeddings
        normalized_embeddings = []
        failed = False
        for idx, emb in enumerate(batch_embeddings):
            try:
                norm = normalize_embedding(emb)
            except Exception as e:
                print(f"Failed to normalize embedding for id {ids_from_get[idx]}: {e}")
                failed = True
                break
            normalized_embeddings.append(norm)

        if failed:
            print("Aborting at this batch due to failures. Adjust batch size and retry.")
            break

        if args.dry_run:
            # Dry run: just print stats and continue
            sample_i = min(3, len(normalized_embeddings))
            print("Sample normalized embedding lengths:", [len(normalized_embeddings[k]) for k in range(sample_i)])
            print("Sample values (first 6 floats):", [normalized_embeddings[k][:6] for k in range(sample_i)])
        else:
            # Update the collection in place (update overwrites embeddings for existing ids)
            try:
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)
            except Exception as e:
                print("Update failed, retrying once after short sleep:", e)
                time.sleep(2)
                col.update(ids=ids_from_get, embeddings=normalized_embeddings)

            print(f"Updated {len(normalized_embeddings)} embeddings in collection '{coll_name}'")

        # Checkpoint progress; skip during dry runs so a later real run
        # does not resume past ids that were never actually updated
        if not args.dry_run:
            ck["last_index"] = i + len(batch_ids)
            save_checkpoint(checkpoint_name, ck)
        processed += len(batch_ids)

    print(f"\nDone. Processed {processed} ids. Checkpoint saved to {CHECKPOINT_DIR / (checkpoint_name + '.json')}")
    print("Reminder: run a few queries to validate search quality.")


if __name__ == "__main__":
    main()
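
# Post-run validation sketch (assumes the collection was created with an
# embedding function so query_texts works; the query text is illustrative):
#   col = chroma_db.get_collection("talks")
#   res = col.query(query_texts=["budget"], n_results=3)
#   print(res["ids"], res["distances"])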