# NOTE(review): the original capture of this file included repository-viewer
# page chrome here ("topics" help text and duplicated "173 lines / 6.0 KiB"
# stats); it was extraction residue, not source code, and has been removed.
import os

import sys

import logging

# Silence the per-request HTTP logs from the ollama/httpx library
logging.getLogger("httpx").setLevel(logging.WARNING)

# Run from the project root and make its modules importable.
# NOTE(review): hard-coded absolute path — assumes this exact deployment
# host/user; confirm before reusing elsewhere.
os.chdir("/home/lasse/riksdagen")

sys.path.append("/home/lasse/riksdagen")

# Project and third-party imports are intentionally placed AFTER the
# sys.path tweak above: `arango_client` and `utils` live in the project
# directory and are not importable until the path is extended.
from arango_client import arango

from ollama import Client as Ollama

from arango.collection import Collection

from concurrent.futures import ThreadPoolExecutor, as_completed

from typing import List, Dict

from time import sleep

from utils import TextChunker
|
|
|
|
|
def make_embeddings(
    texts: List[str],
    *,
    host: str = "192.168.1.12:33405",
    model: str = "qwen3-embedding:latest",
    dimensions: int = 384,
) -> List[List[float]]:
    """
    Generate embeddings for a list of texts using Ollama.

    The connection details were previously hard-coded; they are now
    keyword-only parameters whose defaults preserve the old behavior,
    so existing callers are unaffected.

    Args:
        texts (List[str]): List of text strings to embed.
        host (str): Ollama server address as "host:port".
        model (str): Name of the embedding model to use.
        dimensions (int): Requested size of each embedding vector.

    Returns:
        List[List[float]]: One embedding vector per input text.
    """
    # A fresh client per call keeps the function self-contained; the
    # HTTP connection setup is cheap relative to the embedding request.
    ollama_client = Ollama(host=host)
    response = ollama_client.embed(
        model=model,
        input=texts,
        dimensions=dimensions,
    )
    return response.embeddings
|
|
|
|
|
def process_chunk_batch(chunk_batch: List[Dict]) -> List[Dict]:
    """
    Embed one batch of chunk dicts and attach the vectors in place.

    Args:
        chunk_batch (List[Dict]): Chunk dicts, each carrying a 'text' field.

    Returns:
        List[Dict]: The same list, with an 'embedding' field added to
        every dict.
    """
    # Brief pause so many worker threads don't hammer the embedding
    # server with perfectly synchronized requests.
    sleep(1)
    vectors = make_embeddings([chunk["text"] for chunk in chunk_batch])
    for chunk, vector in zip(chunk_batch, vectors):
        chunk["embedding"] = vector
    return chunk_batch
|
|
|
|
|
def make_arango_embeddings() -> int:
    """
    Chunks and embeds all talks that are not yet represented in the 'chunks' collection.

    For each talk that has no chunks in the collection yet:
      - If the talk document already has a 'chunks' field (legacy path), those are used.
      - Otherwise the speech text is split into chunks using TextChunker.
    Embedding vectors are generated via Ollama and stored in the 'chunks' collection.

    Each chunk document in ArangoDB has:
        _key      : "{talk_key}:{chunk_index}" (unique within the collection)
        text      : the chunk text
        index     : chunk index within the talk
        parent_id : "talks/{talk_key}" (links back to the source talk)
        collection: "talks"
        embedding : the vector (list of floats)

    Returns:
        int: Total number of chunk documents inserted/updated.
    """
    # Create the target collection on first run; reuse it afterwards.
    if not arango.db.has_collection("chunks"):
        chunks_collection: Collection = arango.db.create_collection("chunks")
    else:
        chunks_collection: Collection = arango.db.collection("chunks")

    # Find every talk that has no entry yet in the chunks collection.
    # The inner FOR loop returns [] if no match exists (acts as NOT EXISTS).
    # NOTE(review): without an index on chunks.parent_id this subquery scans
    # the chunks collection per talk — verify an index exists for large data.
    cursor = arango.db.aql.execute(
        """
        FOR p IN talks
            FILTER p.anforandetext != null AND p.anforandetext != ""
            FILTER (
                FOR c IN chunks
                    FILTER c.parent_id == p._id
                    LIMIT 1
                    RETURN 1
            ) == []
            RETURN {
                _key: p._key,
                _id: p._id,
                anforandetext: p.anforandetext,
                chunks: p.chunks
            }
        """,
        batch_size=1000,
        ttl=360,  # generous cursor lifetime: the loop below can be slow
    )

    n = 0
    embed_batch_size = 20  # Number of chunks per Ollama call
    chunk_batches: List[List[Dict]] = []

    for talk in cursor:
        talk_key = talk["_key"]
        parent_id = f"talks/{talk_key}"

        if talk.get("chunks"):
            # Legacy path: chunks were previously generated and stored on the talk document.
            # Strip out the old ChromaDB-specific fields and assign a proper _key.
            _chunks = []
            for chunk in talk["chunks"]:
                idx = chunk.get("index", 0)
                _chunks.append({
                    "_key": f"{talk_key}:{idx}",
                    "text": chunk["text"],
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                })
        else:
            # New path: chunk the speech text directly with TextChunker.
            text = (talk.get("anforandetext") or "").strip()
            text_chunks = TextChunker(chunk_limit=500).chunk(text)
            # Skip empty/whitespace-only chunks so no blank text is embedded.
            _chunks = [
                {
                    "_key": f"{talk_key}:{idx}",
                    "text": content,
                    "index": idx,
                    "parent_id": parent_id,
                    "collection": "talks",
                }
                for idx, content in enumerate(text_chunks)
                if content and content.strip()
            ]

        # Split into batches for embedding
        for i in range(0, len(_chunks), embed_batch_size):
            batch = _chunks[i : i + embed_batch_size]
            if batch:
                chunk_batches.append(batch)

    # Embed all batches in parallel (Ollama calls are I/O-bound, threads are fine)
    total_batches = len(chunk_batches)
    completed_batches = 0
    with ThreadPoolExecutor(max_workers=3) as executor:
        futures = [executor.submit(process_chunk_batch, batch) for batch in chunk_batches]
        processed_chunks: List[Dict] = []
        # as_completed yields batches in finish order, not submit order;
        # order does not matter here since each chunk carries its own _key.
        for future in as_completed(futures):
            result = future.result()
            completed_batches += 1
            processed_chunks.extend(result)
            print(f"Embedding batches: {completed_batches}/{total_batches} | chunks ready to insert: {len(processed_chunks)}", end="\r")
            # Insert in batches of 100 to keep HTTP payloads small
            if len(processed_chunks) >= 100:
                n += len(processed_chunks)
                # overwrite=True makes re-runs idempotent: an existing _key
                # is replaced instead of raising a conflict error.
                chunks_collection.insert_many(processed_chunks, overwrite=True)
                processed_chunks = []
        # Flush whatever is left after the final future completes.
        if processed_chunks:
            n += len(processed_chunks)
            chunks_collection.insert_many(processed_chunks, overwrite=True)

    print(f"\nDone. Inserted/updated {n} chunks in ArangoDB.")
    return n
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the full chunk-and-embed pipeline once.
    make_arango_embeddings()