You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
232 lines
8.6 KiB
232 lines
8.6 KiB
import os |
|
import numpy as np |
|
from time import sleep |
|
import sys |
|
from pathlib import Path |
|
from typing import Iterable, Mapping, List, Any |
|
from chromadb import Collection |
|
import backoff |
|
from _llm import LLM |
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
|
# Set /home/lasse/riksdagen as working directory |
|
os.chdir("/home/lasse/riksdagen") |
|
# Add the project root to Python path to locate local modules |
|
sys.path.append("/home/lasse/riksdagen") |
|
|
|
from arango_client import arango |
|
from colorprinter import * |
|
|
|
|
|
from scripts.build_embeddings import assign_debate_ids |
|
|
|
|
|
def make_debate_ids(): |
|
q = """ |
|
FOR d IN talks |
|
FILTER d.debate == null |
|
RETURN DISTINCT d.datum |
|
""" |
|
|
|
all_dates = list(arango.db.aql.execute(q)) |
|
all_dates.sort() |
|
print(f"Found {len(all_dates)} unique dates with talks missing debate ids") |
|
updates_docs = [] |
|
for date in all_dates: |
|
talks = arango.db.aql.execute( |
|
f""" |
|
FOR d IN talks |
|
FILTER d.datum == @date |
|
return {{"_key": d._key, "datum": d.datum, "replik": d.replik}} |
|
""", |
|
bind_vars={"date": date}, |
|
) |
|
print(date, len(updates_docs), end="\r") |
|
|
|
docs = assign_debate_ids(list(talks), date) |
|
for doc in docs: |
|
updates_docs.append({"_key": doc["_key"], "debate": doc["debate"]}) |
|
if len(updates_docs) > 1000: |
|
arango.db.collection("talks").update_many( |
|
updates_docs, |
|
raise_on_document_error=True, |
|
return_new=False, |
|
silent=True, |
|
) |
|
updates_docs = [] |
|
arango.db.collection("talks").update_many(updates_docs) |
|
|
|
|
|
def summarize_talk(talk: str, llm) -> tuple[str, LLM]: |
|
llm = LLM() |
|
talare = talk['talare'] |
|
party = talk['parti'] |
|
text = talk['anforandetext'] |
|
if talk['replik']: |
|
prompt = f""" |
|
Nedan är ett tal från en debatt i Sveriges riksdag. Sammanfatta talet kort och koncist på svenska, fokusera på de viktigaste argumenten och sakförhållandena som framförs. |
|
Talet är ett svar på ett föregående tal. När du sammanfattar, se till att den går att förstå även utan att ha läst det föregående talet. Inkludera däremot INTE information eller argument från det föregående talet. |
|
Talaren är {talare} från {party}. Börja gärna sammanfattningen med "Namn (Parti) ..." för att tydligt ange vem som talar. |
|
--- |
|
{text} |
|
--- |
|
|
|
Svara _enbart_ med sammanfattningen, inga andra kommentarer eller förklaringar. |
|
""" |
|
|
|
else: |
|
prompt = f""" |
|
Nedan är ett tal från en debatt i Sveriges riksdag. Sammanfatta talet kort och koncist på svenska, fokusera på de viktigaste argumenten och sakförhållandena som framförs. |
|
Talaren är {talare} från {party}. Börja gärna sammanfattningen med "Namn (Parti) ..." för att tydligt ange vem som talar. |
|
--- |
|
{talk} |
|
--- |
|
|
|
Svara **enbart** med sammanfattningen, inga andra kommentarer eller förklaringar. |
|
""" |
|
|
|
response = llm.generate(query=prompt) |
|
return response.content.strip(), llm |
|
|
|
|
|
def process_debate_date(date: str, system_message: str) -> None: |
|
""" |
|
Processes all debates for a given date: summarizes each talk and the debate. |
|
Args: |
|
date (str): The date to process. |
|
system_message (str): The system message for the LLM. |
|
Returns: |
|
None |
|
""" |
|
# Fetch all debates for this date |
|
debates = list(arango.db.aql.execute( |
|
""" |
|
FOR doc IN talks |
|
FILTER doc.datum == @date && doc.summary == null |
|
SORT doc.debate ASC |
|
RETURN DISTINCT doc.debate |
|
""", |
|
ttl=300, |
|
bind_vars={"date": date}, |
|
)) |
|
for debate in debates: |
|
llm = LLM(model="vllm", temperature=0.1, system_message=system_message) |
|
# Fetch the talks in this debate |
|
talks = arango.db.aql.execute( |
|
""" |
|
FOR doc IN talks OPTIONS { indexHint: "debates_index", } |
|
FILTER doc.debate == @debate |
|
SORT doc.anforande_nummer ASC |
|
return { |
|
"_id": doc._id, |
|
"anforandetext": doc.anforandetext, |
|
"anforande_nummer": doc.anforande_nummer, |
|
"datum": doc.datum, |
|
"replik": doc.replik, |
|
"talare": doc.talare, |
|
"parti": doc.parti, |
|
} |
|
""", |
|
ttl=300, |
|
bind_vars={"debate": debate}, |
|
) |
|
talks = list(talks) |
|
print(f"Processing debate {debate} with {len(talks)} talks") |
|
updates = [] |
|
summaries = [] |
|
for talk in talks: |
|
print_blue(talk['_id']) |
|
if "summary" in talk and talk["summary"]: |
|
print( |
|
f" Talk {talk['anforande_nummer']} already has summary, skipping" |
|
) |
|
continue |
|
summary, _ = summarize_talk(talk, llm) |
|
summaries.append(f"{talk['talare']} ({talk['parti']}):\n{summary}") |
|
print(f" Talk {talk['anforande_nummer']} summary: {summary}") |
|
updates.append({"_id": talk["_id"], "summary": summary}) |
|
arango.db.collection('talks').update_many(updates) |
|
updates = [] |
|
if len(talks) == 1: |
|
print_yellow( |
|
f"Debate {debate} has only one talk, skipping debate summary" |
|
) |
|
continue |
|
summaries_string = "\n---\n".join(summaries) |
|
prompt = f""" |
|
Tack! Nu ska du sammanfatta hela debatten baserat på de enskilda sammanfattningarna av varje tal nedan. |
|
Fokusera på de viktigaste argumenten och sakförhållandena som framförs i debatten. |
|
Sammanfattningen ska vara koncis och informativ, och skriven på svenska. |
|
|
|
Här är sammanfattningarna av de enskilda talen i debatten: |
|
''' |
|
{summaries_string} |
|
''' |
|
Svara så att det framgår vad debatten handlade om och vilka de viktigaste argumenten var, samt vilka ståndpunkter de olika partierna hade. |
|
Svara i löpande text utan någon avanderad formatering. Exempel: |
|
|
|
''' |
|
Debatten handlade om ... |
|
De viktigaste argumenten som framfördes var ... |
|
Partierna hade följande ståndpunkter: ... |
|
S: ... |
|
M: ... |
|
... |
|
''' |
|
|
|
Svara **enbart** med sammanfattningen, inga andra kommentarer eller förklaringar. |
|
""" |
|
debate_summary = llm.generate(query=prompt).content.strip() |
|
print_green(f"Debate summary:\n{debate_summary}") |
|
arango.db.collection("debates").insert( |
|
{ |
|
"_key": debate, |
|
"debate": debate, |
|
"summary": debate_summary, |
|
"num_talks": len(talks), |
|
"talk_summaries": summaries, |
|
"talk_ids": [talk["_id"] for talk in talks], |
|
"datum": talks[0]["datum"] if talks else None, |
|
}, |
|
overwrite=True, |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
system_message = """Din uppgift är att sammanfatta debatter i Sveriges riksdag. |
|
Du kommer först att få enskilda tal som du ska sammanfatta var för sig, efter det ska du sammanfatta hela debatten. |
|
Sammanfattningarna ska vara på svenska och vara koncisa och informativa. |
|
Det är viktigt att du förstår vad som är kärnan i varje tal och debatt, fokusera därför på de argument och sakförhållanden som framförs. |
|
""" |
|
# Get all unique dates with talks missing summary |
|
while True: |
|
all_dates: list[str] = list(arango.db.aql.execute( |
|
""" |
|
FOR doc IN talks |
|
FILTER doc.summary == null |
|
RETURN DISTINCT doc.datum |
|
""", |
|
ttl=300, |
|
)) |
|
all_dates.sort() |
|
if len(all_dates) == 0: |
|
print_green("All talks have summaries, sleeping for 15 minutes") |
|
sleep(60*60*24) # Sleep for a day |
|
continue |
|
print(f"Found {len(all_dates)} unique dates to process.") |
|
# Use ProcessPoolExecutor to process each date in parallel |
|
with ProcessPoolExecutor(max_workers=4) as executor: |
|
errors = 0 |
|
futures = {executor.submit(process_debate_date, date, system_message): date for date in all_dates} |
|
for future in as_completed(futures): |
|
date = futures[future] |
|
if errors > 20: |
|
sleep(60*10) |
|
try: |
|
future.result() |
|
print_green(f"Finished processing date {date}") |
|
errors = 0 |
|
except Exception as exc: |
|
errors += 1 |
|
print_red(f"Error processing date {date}: {exc}") |
|
sleep(60*15) |