Compare commits

..

1 Commits

Author SHA1 Message Date
Lasse Studion 7202fe6cd1 Want to add possibility to ask question as user_input 2 years ago
  1. 10
      .gitignore
  2. 80
      arango_things.py
  3. 48
      llama_server.py
  4. 7
      notes.md
  5. 76
      requirements_streamlit.txt
  6. 514
      streamlit_app_talking_ep.py
  7. 8
      streamlit_info.py
  8. 313
      things.py

10
.gitignore vendored

@ -1,11 +1,9 @@
*
!streamlit_info.py
!download_debates.py
!translate_speeches.py
!arango_things.py
!things.py
!arango_things
!things
!streamlit_app_talking_ep.py
!.gitignore
!notes.md
!llama_server.py
!requirements_streamlit.txt
!streamlit_info.py
!notes.md

@ -1,80 +0,0 @@
from arango import ArangoClient, exceptions
import pandas as pd
import yaml
def get_documents(query=False, collection=False, fields=[], filter = '', df=False, index=False, field_names=False, limit=False):
"""
This function retrieves documents from a specified collection or based on a query in ArangoDB.
Parameters:
query (str): If specified this will be the query. Defaults to False.
collection (str): The name of the collection from which to retrieve documents. Defaults to False.
fields (list): The fields of the documents to retrieve. If empty, all fields are retrieved.
filter (str): AQL filter to apply to the retrieval. Defaults to no filter.
df (bool): If True, the result is returned as a pandas DataFrame. Defaults to False.
index (bool): If True and df is True, the DataFrame is indexed. Defaults to False.
field_names (dict): If provided, these field names will replace the original field names in the result.
Returns:
list or DataFrame: The retrieved documents as a list of dictionaries or a DataFrame.
"""
if query:
pass
else:
if fields == []:
return_fields = 'doc'
else:
fields_dict = {}
for field in fields:
fields_dict[field] = field
if field_names:
for k, v in field_names.items():
fields_dict[k] = v
fields_list = [f'{v}: doc.{k}' for k, v in fields_dict.items()]
fields_string = ', '.join(fields_list)
return_fields = f"{{{fields_string}}}"
if filter != None and filter != '':
filter = f'filter {filter}'
else:
filter = ''
if limit:
limit = f'limit {limit}'
else:
limit = ''
query = f'''
for doc
in {collection}
{filter}
{limit}
return {return_fields}
'''
try:
cursor = arango_db.aql.execute(query)
except exceptions.AQLQueryExecuteError:
print('ERROR:\n', query)
exit()
result = [i for i in cursor]
if df:
result = pd.DataFrame(result)
if index:
result.set_index(index, inplace=True)
return result
with open('config.yml', 'r') as f:
config = yaml.safe_load(f)
db = config['arango']['db']
username = config['arango']['username']
pwd = config['arango']['pwd_lasse']
# Initialize the database for ArangoDB.
client = ArangoClient(hosts=config['arango']['hosts'])
arango_db = client.db(db, username=username, password=pwd)

@ -1,48 +0,0 @@
import requests
class LLM():
def __init__(self, system_prompt=None, temperature=0.8, max_new_tokens=1000):
"""
Initializes the LLM class with the given parameters.
Args:
system_prompt (str, optional): The system prompt to use. Defaults to "Be precise and keep to the given information.".
temperature (float, optional): The temperature to use for generating new tokens. Defaults to 0.8.
max_new_tokens (int, optional): The maximum number of new tokens to generate. Defaults to 1000.
"""
self.temperature=temperature
self.max_new_tokens=max_new_tokens
if system_prompt is None:
self.system_prompt="Be precise and keep to the given information."
else:
self.system_prompt=system_prompt
def generate(self, prompt, repeat_penalty=1.2):
"""
Generates new tokens based on the given prompt.
Args:
prompt (str): The prompt to use for generating new tokens.
Returns:
str: The generated tokens.
"""
# Make a POST request to the API endpoint
headers = {"Content-Type": "application/json"}
url = "http://localhost:8080/completion"
json={
"prompt": prompt,
#"system_prompt": self.system_prompt, #TODO https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#change-system-prompt-on-runtime
"temperature": self.temperature,
"n_predict": self.max_new_tokens,
"top_k": 30,
"repeat_penalty": repeat_penalty,
}
response = requests.post(url, headers=headers, json=json)
if not response.ok:
print(response.content)
else:
return response.json()['content']

@ -28,3 +28,10 @@ Debate-id: 2020-12-15#3
Dokument: O-000077/2020 (en inlämnad fråga)
Inlämnad fråga: https://www.europarl.europa.eu/doceo/document/O-9-2020-000078_EN.html
Dokumentärendet: https://www.europarl.europa.eu/doceo/document/A-9-2020-0241_EN.html
4-bit, with Act Order and group size 32g. Gives highest possible inference quality, with maximum VRAM usage.
huggingface-cli download TheBloke/Mistral-7B-OpenOrca-GGUF mistral-7b-openorca.Q5_K_S.gguf --local-dir model_files --local-dir-use-symlinks False
huggingface-cli download TheBloke/Mistral-7B-OpenOrca-GGUF mistral-7b-openorca.Q5_K_S.gguf --local-dir . --local-dir-use-symlinks False

@ -1,76 +0,0 @@
aiohttp==3.8.6
aiosignal==1.3.1
altair==5.1.2
annotated-types==0.6.0
anyio==3.7.1
async-timeout==4.0.3
attrs==23.1.0
blinker==1.6.3
cachetools==5.3.2
certifi==2023.7.22
charset-normalizer==3.3.1
click==8.1.7
contourpy==1.1.1
cycler==0.12.1
dataclasses-json==0.6.1
fonttools==4.43.1
frozenlist==1.4.0
gitdb==4.0.11
GitPython==3.1.40
greenlet==3.0.1
idna==3.4
importlib-metadata==6.8.0
Jinja2==3.1.2
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.19.1
jsonschema-specifications==2023.7.1
kiwisolver==1.4.5
langchain==0.0.325
langsmith==0.0.53
markdown-it-py==3.0.0
MarkupSafe==2.1.3
marshmallow==3.20.1
matplotlib==3.8.0
mdurl==0.1.2
multidict==6.0.4
mypy-extensions==1.0.0
numpy==1.26.1
packaging==23.2
pandas==2.1.2
Pillow==10.1.0
protobuf==4.24.4
pyarrow==13.0.0
pydantic==2.4.2
pydantic_core==2.10.1
pydeck==0.8.1b0
Pygments==2.16.1
PyJWT==2.8.0
pyparsing==3.1.1
python-arango==7.7.0
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
referencing==0.30.2
requests==2.31.0
requests-toolbelt==1.0.0
rich==13.6.0
rpds-py==0.10.6
six==1.16.0
smmap==5.0.1
sniffio==1.3.0
SQLAlchemy==2.0.22
streamlit==1.28.0
tenacity==8.2.3
toml==0.10.2
toolz==0.12.0
tornado==6.3.3
typing-inspect==0.9.0
typing_extensions==4.8.0
tzdata==2023.3
tzlocal==5.2
urllib3==2.0.7
validators==0.22.0
watchdog==3.0.0
yarl==1.9.2
zipp==3.17.0

@ -8,7 +8,6 @@ import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import yaml
from arango_things import arango_db
from streamlit_info import (
party_colors,
@ -23,20 +22,6 @@ from streamlit_info import (
from langchain.schema import Document
from llama_server import LLM
from langchain.vectorstores import Chroma
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chat_models import ChatOpenAI
class Session():
def __init__(self) -> None:
self.mode = ''
self.user_input = ''
self.query = ''
self.df = None
class Params:
"""Class containing parameters for URL.
@ -65,6 +50,7 @@ class Params:
self.persons = self.set_param("persons")
self.from_year = self.set_param("from_year")
self.to_year = self.set_param("to_year")
self.debates = self.set_param("debates")
self.excerpt_page = self.set_param("excerpt_page")
def set_param(self, key):
@ -101,6 +87,7 @@ class Params:
from_year=self.from_year,
to_year=self.to_year,
parties=",".join(self.parties),
debates=",".join(self.debates),
persons=",".join(self.persons),
excerpt_page=self.excerpt_page,
)
@ -189,6 +176,7 @@ def summarize(df_party, user_input):
""",
)
#TODO Include examples in the prompt?
# Example 1:
# Short summary: In 2020, Parliamentarian NN, ({party}) wanted to decrease the budget for the EU parliament. Later, 2022, she wanted to incrase the budget.
@ -200,21 +188,7 @@ def summarize(df_party, user_input):
#TODO Have to understand sysmtem_prompt in llama_server
#system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
from langchain.schema import (
AIMessage,
HumanMessage,
SystemMessage
)
messages = [
SystemMessage(content="You are a journalist reporting from the EU."),
HumanMessage(content=prompt)
]
llm = ChatOpenAI(temperature=0, openai_api_key="sk-xxx", openai_api_base="http://localhost:8081/v1", max_tokens=100)
response=llm(messages)
return response.content
#return llama.generate(prompt=prompt)
return llama.generate(prompt=prompt)
def normalize_party_names(name):
@ -248,12 +222,13 @@ def normalize_party_names(name):
"Vacant": "NA",
"NI": "NA",
"Renew": "Renew"
}
return parties[name]
def make_snippet(row, input_tokens, mode):
def make_snippet(text, input_tokens, token_text_list, token_input_list):
"""
Find the word searched for and give it some context.
@ -267,17 +242,11 @@ def make_snippet(row, input_tokens, mode):
str: A snippet of text containing the input tokens and some context.
"""
# If the input token is "speaker", return the first 300 characters of the text.
if mode == "question":
snippet = row['summary_short']
elif input_tokens == "speaker":
if input_tokens == "speaker":
snippet = str(text[:300])
if len(text) > 300:
snippet += "..."
else:
text = row["Text"]
token_test_list = row["tokens"]
token_input_list = row["input_tokens"]
snippet = []
text_lower = text.lower()
# Calculate snippet length in words.
@ -381,6 +350,28 @@ def build_style_mps(mps):
return style
def fix_party(party):
"""Replace old party codes with new ones."""
party = party.upper().replace("KDS", "KD").replace("FP", "L")
return party
def build_style_debate_types(debates):
"""Build a CSS style for debate type buttons."""
style = "<style> "
for debate in debates:
style += f' span[data-baseweb="tag"][aria-label="{debate}, close by backspace"]{{ background-color: #767676;}} .st-eg {{min-width: 14px;}}' # max-width: 328px;
style += "</style>"
return style
def highlight_cells(party):
if party in party_colors.keys():
color = party_colors[party]
return f"background-color: {color}; font-weight: 'bold'"
@st.cache_data
def options_persons(df):
d = {}
@ -390,7 +381,7 @@ def options_persons(df):
@st.cache_data
def get_data(aql):
def get_data(aql, input_tokens):
"""Get data from SQL database.
Args:
@ -411,19 +402,27 @@ def get_data(aql):
else:
status = "ok"
df = pd.DataFrame([doc for doc in cursor])
df.mep_id = df.mep_id.astype(int)
df["Year"] = df["Date"].apply(lambda x: int(x[:4]))
# df.drop_duplicates(ignore_index=True, inplace=True)
df["Party"] = df["Party"].apply(lambda x: normalize_party_names(x))
df.sort_values(["Date", "number"], axis=0, ascending=True, inplace=True)
df["len_translation"] = df["Text"].apply(lambda x: len(x.split(" ")))
df["len_tokens"] = df["tokens"].apply(lambda x: len(x))
df.sort_values(["Date", "speech_id", "number"], axis=0, inplace=True)
df.drop_duplicates(ignore_index=True, inplace=True)
return df, status
def user_input_to_db(user_input):
"""Writes user input to db for debugging."""
arango_db.collection("streamlit_user_input").insert(
{"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
)
def create_aql_query(input_tokens):
"""Returns a valid sql query."""
# Split input into tokens
@ -455,10 +454,21 @@ def create_aql_query(input_tokens):
str_tokens = str([re.sub(r"[^a-zA-Z0-9\s]", "", token) for token in input_tokens])
select_fields = f"""
{strings['arango']['select_fields']}
{{
"speech_id": doc._key,
"Text": doc.translation,
"number": doc.speech_number,
"debatetype": doc.debate_type,
"Speaker": doc.name,
"Date": doc.date,
"url": doc.url,
"Party": doc.party,
"mep_id": doc.mep_id,
"Summary": doc.summary,
"debate_id": doc.debate_id,
'relevance_score': BM25(doc),
'tokens': TOKENS(doc.translation, 'text_en'),
'input_tokens': TOKENS({str_tokens}, 'text_en')
'input_tokens': TOKENS({str_tokens}, 'text_en')}}
"""
# Add LIMIT and RETURN clause
@ -471,10 +481,17 @@ def create_aql_query(input_tokens):
return aql_query.replace("SEARCH AND", "SEARCH ")
def fix_party(party):
"""Replace old party codes with new ones."""
party = party.upper().replace("KDS", "KD").replace("FP", "L")
return party
def error2db(error, user_input, engine):
"""Write error to DB for debugging."""
doc = (
{
"error": error,
"time": datetime.date(datetime.now()),
"user_input": str(user_input),
},
)
arango_db.collection("errors").inser(doc)
@st.cache_data
@ -554,7 +571,9 @@ def show_excerpts():
].copy()
df_excepts["Excerpt"] = df_excepts.apply(
lambda x: make_snippet(x, input_tokens, mode=mode),
lambda x: make_snippet(
x["Text"], input_tokens, x["tokens"], x["input_tokens"]
),
axis=1,
)
@ -600,16 +619,17 @@ def show_excerpts():
st.markdown(row["Speaker"])
with col2:
try:
snippet = (row["Excerpt"].replace(":", "\:"))
st.markdown(
f""" <span style="background-color:{party_colors_lighten[row['Party']]}; color:black;">{snippet}</span> """,
unsafe_allow_html=True,
)
except AttributeError:
snippet = ''
snippet = (
row["Excerpt"]
.replace(":", "\:")
.replace("<p>", "")
.replace("</p>", "")
)
st.markdown(
f""" <span style="background-color:{party_colors_lighten[row['Party']]}; color:black;">{snippet}</span> """,
unsafe_allow_html=True,
)
with col3:
full_text = st.button("Full text", key=row["speech_id"])
if full_text:
@ -630,34 +650,6 @@ def show_excerpts():
st.markdown(f"📝 [Read the protocol]({row['url']})")
@st.cache_data(show_spinner=False)
def search_with_question(question: str, search_kwargs: dict = {}):
embeddings = HuggingFaceEmbeddings()
vectordb = Chroma(persist_directory='chroma_db', embedding_function=embeddings)
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 20, 'fetch_k': 50}) #https://api.python.langchain.com/en/latest/vectorstores/langchain.vectorstores.chroma.Chroma.html?highlight=as_retriever#langchain.vectorstores.chroma.Chroma.as_retriever
# About MMR https://medium.com/tech-that-works/maximal-marginal-relevance-to-rerank-results-in-unsupervised-keyphrase-extraction-22d95015c7c5
docs = retriever.get_relevant_documents(question)
#docs = vectordb.similarity_search(question, search_kwargs=search_kwargs, )
docs_ids = set([doc.metadata["_id"] for doc in docs])
return get_data(f'''FOR doc IN speeches FILTER doc._id in {list(docs_ids)} RETURN DISTINCT {{{strings['arango']['select_fields']}}}''')
print(r)
exit()
# llm = ChatOpenAI(temperature=0, openai_api_key="sk-xxx", openai_api_base="http://localhost:8081/v1", max_tokens=100)
# retriever_from_llm = MultiQueryRetriever.from_llm(
# retriever=vectordb.as_retriever(search_kwargs=search_kwargs), llm=llm
# )
# #search_kwargs={"filter":{"party":"ID"}}
# unique_docs = retriever_from_llm.get_relevant_documents(query=question)
docs_ids = set([doc.metadata["_id"] for doc in unique_docs])
return get_data(f'''for doc in speeches filter doc._id in {list(docs_ids)} return {{{strings['arango']['select_fields']}}}''')
###* PAGE LOADS *###
# Title and explainer for streamlit
@ -680,11 +672,7 @@ partycodes = list(party_colors.keys()) # List of partycodes
return_limit = 10000
# Initialize LLM model.
#llama = LLM(temperature=0.5)
# Get strings from yaml file.
with open('strings.yml', 'r') as f:
strings = yaml.safe_load(f)
llama = LLM(temperature=0.5)
# Ask for word to search for.
user_input = st.text_input(
@ -692,67 +680,63 @@ user_input = st.text_input(
value=params.q,
placeholder="Search something",
# label_visibility="hidden",
help='You can use asterix (*), minus (-), quotationmarks ("") and OR. \nYou can also try asking a question, like *What is the Parliament doing for the climate?*',
)
help='''You can use asterix (*), minus (-), quotationmarks ("") and OR.
You can also ask a question but make sure to use a question mark!''')
if len(user_input) > 3:
params.q = user_input
try: # To catch errors.
meps_df = get_speakers()
meps_ids = meps_df.mep_id.to_list()
# print(user_input.upper())
# print(llama.generate(prompt=f'''A user wants to search in a database containing debates in the European Parliament and have made the input below. Take that input and write three questions would generate a good result if used for quering a vector database. Answer with a python style list containing the three questions.
# User input: {user_input}
# Questions: '''))
try:
user_input = user_input.replace("'", '"')
input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
# Put user input in session state (first run).
if "user_input" not in st.session_state:
st.session_state["user_input"] = user_input
# Write user input to DB.
arango_db.collection("streamlit_user_input").insert(
{"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
)
user_input_to_db(user_input)
else:
if st.session_state["user_input"] != user_input:
# Write user input to DB.
st.session_state["user_input"] = user_input
user_input_to_db(user_input)
# Reset url parameters.
params.reset(q=user_input)
# Write user input to DB.
arango_db.collection("streamlit_user_input").insert(
{"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
)
params.update()
if user_input.strip()[-1] == '?': # Search documents with AI.
mode = 'question'
with st.spinner(f"Trying to find answers..."):
df, status = search_with_question(user_input)
input_tokens = None
else:
mode = "normal"
# Check if user has searched for a specific politician.
if len(user_input.split(" ")) in [ 2, 3, 4,]: # TODO Better way of telling if name?
list_persons = meps_df["name"].to_list()
if user_input.lower() in list_persons:
sql = search_person(user_input, list_persons)
mode = "speaker"
aql = create_aql_query(input_tokens)
meps_df = get_speakers()
meps_ids = meps_df.mep_id.to_list()
## Fetch data from DB.
df, status = get_data(aql)
if status == "no hits": # If no hits.
st.write("No hits. Try again!")
st.stop()
elif status == "limit":
st.write(limit_warning)
st.stop()
# Check if user has searched for a specific politician.
if len(user_input.split(" ")) in [
2,
3,
4,
]: # TODO Better way of telling if name?
list_persons = meps_df["name"].to_list()
if user_input.lower() in list_persons:
sql = search_person(user_input, list_persons)
search_terms = "speaker"
aql = create_aql_query(input_tokens)
## Fetch data from DB.
df, status = get_data(aql, input_tokens)
if status == "no hits": # If no hits.
st.write("No hits. Try again!")
st.stop()
elif status == "limit":
st.write(limit_warning)
st.stop()
party_talks = pd.DataFrame(df["Party"].value_counts())
@ -760,7 +744,7 @@ if len(user_input) > 3:
if type(party_labels) == "list":
party_labels.sort()
if mode != "speaker":
if user_input != "speaker":
# Let the user select parties to be included.
container_parties = st.container()
with container_parties:
@ -774,17 +758,27 @@ if len(user_input) > 3:
default=party_labels,
)
if params.parties != []:
if mode == "question":
df, status = search_with_question(user_input, search_kwargs={"filter":{"party":{"$in": params.parties}}})
else:
df = df.loc[df["Party"].isin(params.parties)]
df = df.loc[df["Party"].isin(params.parties)]
if len(df) == 0:
st.stop()
# Let the user select type of debate.
container_debate = st.container()
with container_debate:
debates = df["debatetype"].unique().tolist()
debates.sort()
style = build_style_debate_types(debates)
st.markdown(style, unsafe_allow_html=True)
params.debates = st.multiselect(
label="Select type of debate",
options=debates,
default=debates,
)
if params.debates != []:
df = df.loc[df["debatetype"].isin(params.debates)]
if len(df) == 0:
st.stop()
params.update()
# Let the user select a range of years.
@ -808,7 +802,7 @@ if len(user_input) > 3:
params.update()
if mode != "speaker":
if user_input != "speaker":
# Let the user select talkers.
options = options_persons(df)
style_mps = build_style_mps(options) # Make the options the right colors.
@ -866,7 +860,7 @@ if len(user_input) > 3:
del st.session_state["disable_next_page_button"]
st.session_state["hits"] = df.shape[0]
# Show snippets.
##! Show snippets.
expand = st.expander(
"Show excerpts from the speeches.",
expanded=False,
@ -874,7 +868,6 @@ if len(user_input) > 3:
with expand:
if "excerpt_page" not in st.session_state:
st.session_state["excerpt_page"] = 0
excerpts_container = st.container()
show_excerpts()
st.session_state["excerpt_page"] += 1
@ -891,147 +884,138 @@ if len(user_input) > 3:
st.session_state["excerpt_page"] += 1
show_excerpts()
# * Download all data in df.
st.download_button(
"Download the data as CSV",
data=df.to_csv(
index=False,
sep=";",
columns=[
"speech_id",
"Text",
"Party",
"Speaker",
"Date",
"url",
],
).encode("utf-8"),
file_name=f"{user_input}.csv",
mime="text/csv",
)
if mode != 'question':
if mode != "speaker":
# Remove talks from same party within the same session to make the
# statistics more representative.
df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
## Make pie chart.
party_talks = pd.DataFrame(df_["Party"].value_counts())
party_labels = party_talks.index.to_list()
fig, ax1 = plt.subplots()
total = party_talks["count"].sum()
mentions = party_talks["count"] #!
ax1.pie(
mentions,
labels=party_labels,
autopct=lambda p: "{:.0f}".format(p * total / 100),
colors=[party_colors[key] for key in party_labels],
startangle=90,
)
# Remove talks from same party within the same session to make the
# statistics more representative.
df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
if user_input != "speaker":
## Make pie chart.
party_talks = pd.DataFrame(df_["Party"].value_counts())
party_labels = party_talks.index.to_list()
fig, ax1 = plt.subplots()
total = party_talks["count"].sum()
mentions = party_talks["count"] #!
ax1.pie(
mentions,
labels=party_labels,
autopct=lambda p: "{:.0f}".format(p * total / 100),
colors=[party_colors[key] for key in party_labels],
startangle=90,
)
# Make bars per year.
years = set(df["Year"].tolist())
# Make bars per year.
years = set(df["Year"].tolist())
df_years = pd.DataFrame(columns=["Party", "Year"])
df_years = pd.DataFrame(columns=["Party", "Year"])
for year, df_i in df.groupby("Year"):
dff = pd.DataFrame(data=df_i["Party"].value_counts())
dff["Year"] = str(year)
df_years = pd.concat([df_years, dff])
df_years["party_code"] = df_years.index
# df_years = df_years.groupby(["party_code", "Year"]).size()
df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
# df_years.drop(["Party"], axis=1, inplace=True)
for year, df_i in df.groupby("Year"):
dff = pd.DataFrame(data=df_i["Party"].value_counts())
dff["Year"] = str(year)
df_years = pd.concat([df_years, dff])
df_years["party_code"] = df_years.index
# df_years = df_years.groupby(["party_code", "Year"]).size()
df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
# df_years.drop(["Party"], axis=1, inplace=True)
df_years.rename(
columns={"Party": "Mentions", "party_code": "Party"}, inplace=True
)
chart = (
alt.Chart(df_years)
.mark_bar()
.encode(
x="Year",
y="count",
color=alt.Color("color", scale=None),
tooltip=["Party", "Mentions"],
)
df_years.rename(
columns={"Party": "Mentions", "party_code": "Party"}, inplace=True
)
chart = (
alt.Chart(df_years)
.mark_bar()
.encode(
x="Year",
y="count",
color=alt.Color("color", scale=None),
tooltip=["Party", "Mentions"],
)
)
if user_input == "speaker":
st.altair_chart(chart, use_container_width=True)
if user_input == "speaker":
else:
# Put the charts in a table.
fig1, fig2 = st.columns(2)
with fig1:
st.pyplot(fig)
with fig2:
st.altair_chart(chart, use_container_width=True)
else:
# Put the charts in a table.
fig1, fig2 = st.columns(2)
with fig1:
st.pyplot(fig)
with fig2:
st.altair_chart(chart, use_container_width=True)
# Create an empty container
container = st.empty()
# Create an empty container
container = st.empty()
# * Make a summary of the standpoints of the parties.
# A general note about the summaries. #TODO Make better.
st.markdown(summary_note)
# * Make a summary of the standpoints of the parties.
summarize_please = st.button("Summarize please!", )
if summarize_please:
# A general note about the summaries. #TODO Make better.
st.markdown(summary_note)
# Count the number of speeches by party.
party_counts = df["Party"].value_counts()
# Sort the dataframe by party.
df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
by=["Party"]
)
# Create an empty dictionary to store the summaries.
summaries = {}
# Loop through each party and their respective speeches.
for party, df_party in df_sorted.groupby("Party"):
# Create an empty list to store the texts of each speech.
with st.spinner(f"Summarizing speeches by **{party}** parliamentarians..."):
summary = summarize(df_party, user_input)
st.markdown(
f"##### <span style='color:{party_colors[party]}'>{parties[party]} ({party})</span>",
unsafe_allow_html=True,
)
st.write(summary)
# Count the number of speeches by party.
party_counts = df["Party"].value_counts()
# Sort the dataframe by party.
df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
by=["Party"]
)
# Create an empty dictionary to store the summaries.
summaries = {}
# Loop through each party and their respective speeches.
for party, df_party in df_sorted.groupby("Party"):
# Create an empty list to store the texts of each speech.
with st.spinner(f"Summarizing speeches by **{party}** parliamentarians..."):
summary = summarize(df_party, user_input)
st.markdown(
f"##### <span style='color:{party_colors[party]}'>{parties[party]} ({party})</span>",
unsafe_allow_html=True,
)
st.write(summary)
# A general note about the summaries. #TODO Make better.
# A general note about the summaries. #TODO Make better.
feedback_column, download_column = st.columns([3, 1])
with feedback_column:
# * Get feedback.
st.empty()
feedback_container = st.empty()
# * Get feedback.
st.empty()
feedback_container = st.empty()
with feedback_container.container():
feedback = st.text_area(
'Feedback', placeholder='Feel free to write suggestions for functions and improvements here!', label_visibility='collapsed'
)
send = st.button("Send")
if len(feedback) > 2 and send:
doc = {
"feedback": feedback,
"params": st.experimental_get_query_params(),
"where": ("streamlit", title),
"timestamp": str(datetime.now()),
"date": str(datetime.date(datetime.now())),
}
arango_db.collection("feedback").insert(doc)
feedback_container.write("*Thanks*")
params.update()
# st.markdown("##")
with download_column:
# * Download all data in df.
st.download_button(
"Download data as CSV-file",
data=df.to_csv(
index=False,
sep=";",
columns=[
"speech_id",
"Text",
"Party",
"Speaker",
"Date",
"url",
],
).encode("utf-8"),
file_name=f"{user_input}.csv",
mime="text/csv",
with feedback_container.container():
feedback = st.text_area(
"*Feel free to write suggestions for functions and improvements here!*"
)
send = st.button("Send")
if len(feedback) > 2 and send:
doc = {
"feedback": feedback,
"params": st.experimental_get_query_params(),
"where": ("streamlit", title),
"timestamp": str(datetime.now()),
"date": str(datetime.date(datetime.now())),
}
arango_db.collection("feedback").insert(doc)
feedback_container.write("*Thanks*")
params.update()
# st.markdown("##")
except Exception as e:
if (

@ -69,7 +69,7 @@ It's a summary of the ten most relevant speeches from each party based on the se
Please make sure to check the original text before you use the summary in any way.
"""
explainer = """This is a database of what members of the European Parliamen have said in various debates in the parliament since 2019.
The data comes from the EU have been translated when not in English.
The data comes partly from the EU.
- Start by typing one or more keywords below. You can use asterix (*), minus(-), quotation marks (""), OR and year\:yyyy-yyyy. The search
`energy crisis* basic power OR nuclear power "fossil-free energy sources" -wind power year:2019-2022` is looking for quotes like\:
- mentions "energy crisis" (incl. e.g. "energy crisis*")
@ -77,9 +77,9 @@ The data comes from the EU have been translated when not in English.
- mentions the *exact phrase* "fossil-free energy sources"
- *does* not mention "wind power"
- found during the years 2019-2022
- You can also ask a specific quesion, like `What have parliamentarians said about the energy crisis?` Remember to put a question mark at the end of the question.
- When you have received your results, you can filter on which years you are interested in.
- Under "Excerpt" you can choose to see the entire speech in text, and under the text there are links to official protocol.
- When you have received your results, you can then click away matches or change which years and debate types you are interested in.
- Under "Longer excerpt" you can choose to see the entire speech in text, and under the text there are links to the Riksdag's Web TV and downloadable audio (in the cases
where the debate has been broadcast).
Please tell us how you would like to use the data and about things that don't work. [Email me](mailto:lasse@edfast.se) or [write to me on Twitter](https://twitter.com/lasseedfast).
My name is [Lasse Edfast and I'm a journalist](https://lasseedfast.se) based in Sweden.

@ -1,313 +0,0 @@
import nltk
import tiktoken
import re
def normalize_party_names(name):
"""
Normalizes party names to the format used in the database.
Parameters:
name (str): The party name to be normalized.
Returns:
str: The normalized party name.
"""
parties = {
"EPP": "EPP",
"PPE": "EPP",
"RE": "Renew",
"S-D": "S&D",
"S&D": "S&D",
"ID": "ID",
"ECR": "ECR",
"GUE/NGL": "GUE/NGL",
"The Left": "GUE/NGL",
"Greens/EFA": "Greens/EFA",
"G/EFA": "Greens/EFA",
"Verts/ALE": "Greens/EFA",
"NA": "NA",
"NULL": "NA",
None: "NA",
"-": "NA",
"Vacant": "NA",
"NI": "NA",
"Renew": "Renew"
}
return parties[name]
def count_tokens(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(string))
return num_tokens
def whitespace_remover(text):
return re.sub(r'^[\w\s]', '', text)
def fix_date(string, pattern):
"""
Args:
string (str): The input string containing the date.
pattern (str): The pattern to be used for parsing the date from the string. It should contain 'y', 'm', and 'd' to represent the year, month, and day respectively.
Returns:
str: The formatted date string in the 'yyyy-mm-dd' format.
Example:
>>> fix_date('20211231', 'yyyymmdd')
'2021-12-31'
"""
y = re.search(r'y+', pattern)
m = re.search(r'm+', pattern)
d = re.search(r'd+', pattern)
year = string[y.span()[0]:y.span()[1]]
month = string[m.span()[0]:m.span()[1]]
day = string[d.span()[0]:d.span()[1]]
return (f'{year}-{month}-{day}')
def fix_doc_name(string, pattern):
"""
Returns a dictionary with keys 'year', 'number', and 'letters', and their corresponding values from the string.
Args:
string (str): The string from which to extract the year, number, and letters.
pattern (str): The pattern which to search for in the string.
Returns:
dict: A dictionary with keys 'year', 'number', and 'letters' and their corresponding values from the string.
Example:
>>> fix_doc_name('COM/2021/570', 'lll/yyyy/nnn')
'{'year': 2023, 'number': 570, 'letters': COM}'
"""
# Find positions for y, n l.
y = re.search(r'y+', pattern)
n = re.search(r'n+', pattern)
l = re.search(r'l+', pattern)
# Extract the year, number and letters based on the positions.
year = string[y.span()[0]:y.span()[1]]
number = string[n.span()[0]:n.span()[1]]
letters = string[l.span()[0]:l.span()[1]]
return {'year': year, 'number': number, 'letters': letters}
def text_splitter(text: str, max_tokens=2000):
"""
Splits a given text into chunks of sentences where each chunk has a number of tokens less than or equal to the max_tokens.
The function first calculates the total number of tokens in the input text. If this number is greater than max_tokens,
it calculates the maximum number of tokens per chunk and splits the text into sentences. Then it iterates over the sentences,
adding them to the current chunk until the number of tokens in the current chunk reaches the maximum limit. When this limit
is reached, the current chunk is added to the list of chunks and a new chunk is started. This process continues until all
sentences have been processed. If the total number of tokens in the text is less than or equal to max_tokens, the function
returns the whole text as a single chunk.
Parameters:
text (str): The input text to be split into chunks.
max_tokens (int): The maximum number of tokens allowed in each chunk.
Returns:
chunks (list of str): A list of text chunks where each chunk has a number of tokens less than or equal to max_tokens.
"""
try:
tokens_in_text = count_tokens(text)
except:
tokens_in_text = len(text)/3
if tokens_in_text > max_tokens:
# Calculate maximal number of tokens in chunks to make the chunks even.
max_tokens_per_chunk = int(tokens_in_text/int(tokens_in_text / max_tokens))
# Split the text into sentences.
sentences = nltk.sent_tokenize(text)
# Initialize an empty list to hold chunks and a string to hold the current chunk.
chunks = []
current_chunk = ''
# Iterate over the sentences.
for sentence in sentences:
# If adding the next sentence doesn't exceed the max tokens limit, add the sentence to the current chunk.
if count_tokens(current_chunk + ' ' + sentence) <= max_tokens_per_chunk:
current_chunk += ' ' + sentence
else:
# If it does, add the current chunk to the chunks list and start a new chunk with the current sentence.
chunks.append(current_chunk)
current_chunk = sentence
# Add the last chunk to the chunks list.
if current_chunk:
chunks.append(current_chunk)
else:
chunks = [text]
return chunks
parliamentary_term_now = 9 #* Update this every term.
model_mistral = "mistral-openorca"
eu_country_codes = {
"Belgium": "BE",
"Greece": "EL",
"Lithuania": "LT",
"Portugal": "PT",
"Bulgaria": "BG",
"Spain": "ES",
"Luxembourg": "LU",
"Romania": "RO",
"Czechia": "CZ",
"France": "FR",
"Hungary": "HU",
"Slovenia": "SI",
"Denmark": "DK",
"Croatia": "HR",
"Malta": "MT",
"Slovakia": "SK",
"Germany": "DE",
"Italy": "IT",
"Netherlands": "NL",
"Finland": "FI",
"Estonia": "EE",
"Cyprus": "CY",
"Austria": "AT",
"Sweden": "SE",
"Ireland": "IE",
"Latvia": "LV",
"Poland": "PL",
}
country_flags = {
"United Kingdom": "🇬🇧",
"Sweden": "🇸🇪",
"Spain": "🇪🇸",
"Slovenia": "🇸🇮",
"Slovakia": "🇸🇰",
"Romania": "🇷🇴",
"Portugal": "🇵🇹",
"Poland": "🇵🇱",
"Netherlands": "🇳🇱",
"Malta": "🇲🇹",
"Luxembourg": "🇱🇺",
"Lithuania": "🇱🇹",
"Latvia": "🇱🇻",
"Italy": "🇮🇹",
"Ireland": "🇮🇪",
"Hungary": "🇭🇺",
"Greece": "🇬🇷",
"Germany": "🇩🇪",
"France": "🇫🇷",
"Finland": "🇫🇮",
"Estonia": "🇪🇪",
"Denmark": "🇩🇰",
"Czechia": "🇨🇿",
"Cyprus": "🇨🇾",
"Croatia": "🇭🇷",
"Bulgaria": "🇧🇬",
"Belgium": "🇧🇪",
"Austria": "🇦🇹",
}
policy_areas = [
"Agriculture",
"Business",
"Industry",
"Climate",
"Culture",
"Customs",
"Development",
"Education",
"Employment",
"Social Affairs",
"Energy",
"Environment",
"FoodSafety",
"SecurityPolicy",
"Health",
"Democracy",
"Humanitarian Aid",
"Justice",
"Research And Innovation",
"Market",
"Taxation",
"Trade",
"Transport",
]
# From https://eur-lex.europa.eu/browse/summaries.html
policy_areas = ['Agriculture', ' Audiovisual and media', ' Budget', ' Competition', ' Consumers', ' Culture', ' Customs', ' Development', ' Digital single market', ' Economic and monetary affairs', ' Education, training, youth, sport', ' Employment and social policy', ' Energy', ' Enlargement', ' Enterprise', ' Environment and climate change', ' External relations', ' External trade', ' Food safety', ' Foreign and security policy', ' Fraud and corruption', ' Humanitarian Aid and Civil Protection', ' Human rights', ' Institutional affairs', ' Internal market', ' Justice, freedom and security', ' Oceans and fisheries', ' Public health', ' Regional policy', ' Research and innovation', ' Taxation', ' Transport']
countries = [
"Romania",
"Latvia",
"Slovenia",
"Denmark",
"Spain",
"Italy",
"Hungary",
"United Kingdom",
"Netherlands",
"Czechia",
"Finland",
"Belgium",
"Germany",
"France",
"Slovakia",
"Poland",
"Ireland",
"Malta",
"Cyprus",
"Luxembourg",
"Greece",
"Austria",
"Sweden",
"Portugal",
"Lithuania",
"Croatia",
"Bulgaria",
"Estonia",
]
parties = ["Renew", "S-D", "PPE", "Verts/ALE", "ECR", "NI", "The Left", "ID", "GUE/NGL"]
party_colors = {
"EPP": "#3399FF",
"S-D": "#F0001C",
"Renew": "gold",
"ID": "#0E408A",
"G/EFA": "#57B45F",
"ECR": "#196CA8",
"GUE/NGL": "#B71C1C",
'The Left': "#B71C1C", # Same as GUE/NGL
"NI": "white",
"Vacant": "white",
"PPE": "#3399FF", # Same as EPP
"NULL": 'white',
'Verts/ALE':"#57B45F", # Same as G/EFA
None: 'white'
}
def insert_in_db(query):
con = sqlite3.connect(path_db)
con.row_factory = sqlite3.Row
cursor = con.cursor()
query = query
cursor.execute(query)
con.commit()
con.close()
Loading…
Cancel
Save