import difflib
import re
import traceback
from datetime import datetime
import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from arango_things import arango_db
from things import normalize_party_names
from streamlit_info import (
party_colors,
explainer,
limit_warning,
party_colors_lighten,
css,
results_limit,
parties,
summary_note,
)
from langchain.schema import Document
class Params:
"""Class containing parameters for URL.
Attributes:
params (dict): A dictionary containing the parameters.
q (str): The search query.
parties (list): A list of political parties.
persons (list): A list of persons.
from_year (int): The start year for the search.
to_year (int): The end year for the search.
debates (list): A list of debates.
excerpt_page (int): The page number for the excerpt.
"""
def __init__(self, params):
"""Initializes the Params class.
Args:
params (dict): A dictionary containing the parameters.
"""
self.params = params
# Set parameters.
self.q = self.set_param("q")
self.parties = self.set_param("parties")
self.persons = self.set_param("persons")
self.from_year = self.set_param("from_year")
self.to_year = self.set_param("to_year")
self.debates = self.set_param("debates")
self.excerpt_page = self.set_param("excerpt_page")
def set_param(self, key):
"""Returns the value of a parameter if it exists.
Args:
key (str): The key of the parameter.
Returns:
The value of the parameter if it exists, otherwise a default value.
"""
if key in self.params:
if key in ["parties", "persons", "debates"]:
value = self.params[key][0].split(",")
else:
value = self.params[key][0]
else:
value = []
if key == "q":
value = ""
elif key == "from_year":
value = 2019 # Catch all.
elif key == "to_year":
value = 2023 # Catch all.
elif key == "excerpt_page":
value = 0
return value
def update(self):
"""Updates the URL parameters."""
st.experimental_set_query_params(
q=self.q,
from_year=self.from_year,
to_year=self.to_year,
parties=",".join(self.parties),
debates=",".join(self.debates),
persons=",".join(self.persons),
excerpt_page=self.excerpt_page,
)
    def update_param(self, key, value):
        """Updates a single parameter.
        Args:
        key (str): The key of the parameter to update.
        value (str or int or list): The new value of the parameter.
        """
        self.params[key] = value
        setattr(self, key, value)  # update() serializes the attributes, not self.params.
        self.update()
    def reset(self, q=False):
        """Resets all parameters to their default values.
        Args:
        q (str): The new search query (optional).
        """
        self.params = {}
        # Reset the attributes too; update() serializes these.
        self.q = ""
        self.parties = []
        self.persons = []
        self.debates = []
        self.from_year = 2019
        self.to_year = 2023
        self.excerpt_page = 0
        if q:
            self.q = q
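# A minimal usage sketch of Params (hedged; mirrors how it is constructed
# further down in this file):
#
#     params = Params(st.experimental_get_query_params())
#     params.update_param("from_year", 2020)  # pushes the change to the URL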
@st.cache_data(show_spinner=False)
def summarize(df_party, user_input):
"""
Summarizes the standpoints of a political party in the EU parliament based on a user search.
Args:
df_party (pandas.DataFrame): A DataFrame containing speeches by parliamentarians from the political party in the EU parliament.
user_input (str): The user's search query.
Returns:
str: A short summary of the standpoints of the political party in the EU parliament, in relation to the user search.
"""
    texts = []
    documents = []
    # `party` is interpolated into the prompt below; every row in df_party
    # belongs to the same party, so take it from the first row.
    party = df_party["Party"].iloc[0]
    for _, row in df_party.iterrows():
        documents.append(
            Document(
                page_content=f"**{row['Speaker']} {row['Date']}**\n{row['Summary']}",
                metadata=row.to_dict(),
            )
        )
    # The documents are prepared for citation-style QA, see:
    # https://python.langchain.com/docs/use_cases/question_answering/qa_citations
    # Keep the 10 most relevant speeches (relevance defined by BM25 in the arango query).
if df_party.shape[0] > 10:
df_party = df_party.sort_values(by="relevance_score", ascending=False)[:10]
if df_party.shape[0] > 5:
        # Sample up to 5 rows, with weights proportional to the year counts.
sample_size = 5
weights = df_party["Year"].value_counts(normalize=True)
df_party = df_party.groupby("Year", group_keys=False).apply(
lambda x: x.sample(n=int(sample_size * weights[x.name]))
)
# Sort the dataframe by date.
df_party.sort_values(by="Date", ascending=False, inplace=True)
    # Loop through the speeches and add each one's text to the list.
    for _, row in df_party.sort_values(by="Date", ascending=False)[:10].iterrows():
        texts.append(f"**{row['Speaker']} {row['Date']}**\n{row['Summary']}")
    # The list of texts is interpolated directly into the prompt below.
    # Create a prompt asking the model to summarize the party's standpoints.
prompt = re.sub(
r"\W\W\W+",
"\n",
f"""
A user have made this search in a database: <{user_input}>
That resulted in the lists below, consisting of speeches by parliamentarians from the {party} party in the EU parliament:
{texts}
Please make a very short summary of the standpoints of the {party} party in the EU parliament, in relation to the user search. Make note of the speaker and the dates for the speeches.
Example 1:
In 2020, Parliamentarian NN, ({party}) wanted to decrease the budget for the EU parliament. Later, 2022, she wanted to incrase the budget.
Example 2:
Parliamentarians from {party} have been very active in the debate about the EU budget. They have been both for and against the budget.
Short summary:
""",
)
    # Generate and return a summary of the party's standpoints.
    # `ollama` is the project's local LLM helper (assumed importable in this
    # module's environment; it is not the pip `ollama` package).
    system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
    return ollama(
        prompt=prompt, system_prompt=system_prompt, temperature=0.5, print_tokens=False
    )
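# A minimal sketch of the input summarize() expects (hedged; the column names
# mirror the fields returned by create_aql_query below):
#
#     df_party = pd.DataFrame([{
#         "Speaker": "NN", "Date": "2021-03-09", "Year": 2021, "Party": "S",
#         "Summary": "...", "relevance_score": 1.2,
#     }])
#     summary = summarize(df_party, "EU budget")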
def make_snippet(text, input_tokens, token_text_list, token_input_list):
"""
Find the word searched for and give it some context.
Args:
text (str): The text to search for the input tokens.
input_tokens (str): The input tokens to search for in the text.
token_text_list (list): A list of tokens in the text.
token_input_list (list): A list of input tokens.
Returns:
str: A snippet of text containing the input tokens and some context.
"""
# If the input token is "speaker", return the first 300 characters of the text.
if input_tokens == "speaker":
snippet = str(text[:300])
if len(text) > 300:
snippet += "..."
else:
snippet = []
text_lower = text.lower()
# Calculate snippet length in words.
snippet_length = 40 * int(8 / len(input_tokens)) # * Change to another value?
# Loop through each input token.
for token in input_tokens:
# Remove any non-alphanumeric characters and convert to lowercase.
token = re.sub(r"[^a-zA-Z0-9\s]", "", token).lower()
if token in text_lower:
                # Find the position of the token in the text.
                position = text_lower.find(token)
                # Clamp to the start of the text, then back up to the previous
                # space so the snippet starts on a word boundary.
                position_start = max(position - snippet_length, 0)
                while position_start > 0:
                    position_start -= 1
                    if text_lower[position_start] == " ":
                        break
                position_end = position + int(snippet_length / 2)
                # Find the end of the snippet by looking for the next space character.
                while 0 < position_end < len(text_lower) - 1:
                    position_end += 1
                    if text_lower[position_end] == " ":
                        break
                text_before = text[position_start:position]
                if position_end > len(text):
                    position_end = len(text)
                # Skip the token itself and the space right after it.
                text_after = text[position + len(token) + 1 : position_end]
                word_context = f"{text_before} <b>{token}</b> {text_after}"
snippet.append(word_context)
if len(snippet) == 2:
break
if snippet == [""]:
# If the input token is not found in the text, find the closest match in the token list.
text_list = text.split(" ")
for token in token_input_list:
token = "".join(token)
if token in token_text_list:
position = token_text_list.index(token)
w = 20
# Get a list of words around the token, and find the closest match to the token in that list.
list_from_text = text_list[position - w : position + w]
closest_match = difflib.get_close_matches(
token, list_from_text, n=1
)
# If no match is found, expand the list of words and try again.
if not closest_match:
w = 50
list_from_text = text_list[position - w : position + w]
closest_match = difflib.get_close_matches(
token, list_from_text, n=1
)
if closest_match:
word = closest_match[0]
# Calculate the difference between the position of the token and the position of the closest match.
diff_token_to_word = list_from_text.index(word) - w
position = position + diff_token_to_word
else:
word = text_list[position]
# Get the context of the word, including the words before and after it.
position_start = position - snippet_length
if position_start < 0:
position_start = 0
position_end = position + int(snippet_length / 2)
text_before = " ".join(text_list[position_start:position])
if position_end > len(text_list):
position_end = len(text_list)
text_after = " ".join(text_list[position + 1 : position_end])
word_context = f"{text_before} <b>{word}</b> {text_after}"
snippet.append(word_context)
snippet = "|".join(snippet)
snippet = f"...{snippet}..."
return snippet
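# A rough illustration of make_snippet (hedged): for an ordinary search the
# matched token is wrapped in <b> tags with surrounding context, e.g.
#
#     make_snippet("The budget debate was long.", ["budget"],
#                  ["the", "budget", "debate", "was", "long"], [["budget"]])
#     # -> roughly "...The <b>budget</b> debate was long...."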
def build_style_parties(parties):
"""Build a CSS styl for party names buttons."""
style = "<style> "
for party in parties:
style += f' span[data-baseweb="tag"][aria-label="{party}, close by backspace"]{{ background-color: {party_colors[party]}}} .st-eg {{min-width: 14px;}} ' # max-width: 328px;
style += "</style>"
return style
def build_style_mps(mps):
"""Build a CSS style for party names buttons."""
style = "<style> "
for party in mps:
party = fix_party(party)
try:
style += f' span[data-baseweb="tag"][aria-label="{party}, close by backspace"]{{ background-color: {party_colors[party]};}} .st-eg {{min-width: 14px;}} ' # max-width: 328px;
except KeyError:
style += f' span[data-baseweb="tag"][aria-label="{party}, close by backspace"]{{ background-color: {party_colors["NA"]};}} .st-eg {{min-width: 14px;}} '
style += "</style>"
return style
def fix_party(party):
"""Replace old party codes with new ones."""
party = party.upper().replace("KDS", "KD").replace("FP", "L")
return party
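# Quick examples: fix_party("fp") -> "L"; fix_party("kds") -> "KD".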
def build_style_debate_types(debates):
"""Build a CSS style for debate type buttons."""
style = "<style> "
for debate in debates:
style += f' span[data-baseweb="tag"][aria-label="{debate}, close by backspace"]{{ background-color: #767676;}} .st-eg {{min-width: 14px;}}' # max-width: 328px;
style += "</style>"
return style
def highlight_cells(party):
    """Return a CSS rule coloring a cell by party (for a pandas Styler)."""
    if party in party_colors:
        color = party_colors[party]
        return f"background-color: {color}; font-weight: bold"
    return ""
@st.cache_data
def options_persons(df):
    """Return "Speaker - number of speeches" labels for the person filter."""
    d = {}
    for speaker, group in df.groupby("Speaker"):
        d[speaker] = group.shape[0]
    return [f"{key} - {value}" for key, value in d.items()]
@st.cache_data
def get_data(aql, input_tokens):
"""Get data from SQL database.
Args:
sql (str): A SQL query string.
Returns:
DataFrame: Dataframe with some adjustments to the data fetched from the DB.
"""
    cursor = arango_db.aql.execute(aql, count=True)
    if len(cursor) == 0:
        df = None
        status = "no hits"
    else:
        if len(cursor) == results_limit:
            status = "limit"
        else:
            status = "ok"
        df = pd.DataFrame([doc for doc in cursor])
df.mep_id = df.mep_id.astype(int)
df["Year"] = df["Date"].apply(lambda x: int(x[:4]))
# df.drop_duplicates(ignore_index=True, inplace=True)
df["Party"] = df["Party"].apply(lambda x: normalize_party_names(x))
df.sort_values(["Date", "number"], axis=0, ascending=True, inplace=True)
df["len_translation"] = df["Text"].apply(lambda x: len(x.split(" ")))
df["len_tokens"] = df["tokens"].apply(lambda x: len(x))
df.sort_values(["Date", "speech_id", "number"], axis=0, inplace=True)
return df, status
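# get_data returns (df, status) where status is "no hits", "limit" or "ok".
# A typical call (hedged) mirrors the page code below:
#
#     df, status = get_data(create_aql_query(["budget"]), ["budget"])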
def user_input_to_db(user_input):
"""Writes user input to db for debugging."""
arango_db.collection("streamlit_user_input").insert(
{"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
)
def create_aql_query(input_tokens):
    """Returns a valid AQL query for the speeches view."""
    # The input is already tokenized by the caller; build the AQL query.
    aql_query = "FOR doc IN speeches_view\nSEARCH "
# Add SEARCH PHRASE clause
search_phrase = next((token for token in input_tokens if " " in token), None)
if search_phrase:
aql_query += f'PHRASE(doc.translation, {search_phrase}, "text_en")\n'
input_tokens.remove(search_phrase)
# Add ANALYZER clause for each token
for token in input_tokens:
if token.startswith("-"):
if "*" in token:
aql_query += (
f'AND NOT LIKE(doc.translation, "{token[1:].replace("*", "%")}")\n'
)
else:
aql_query += f'AND NOT ANALYZER(TOKENS("{token[1:]}", "text_en") ANY == doc.translation, "text_en")\n'
elif "*" in token:
aql_query += f'AND LIKE(doc.translation, "{token.replace("*", "%")}")\n'
else:
aql_query += f'AND ANALYZER(TOKENS("{token}", "text_en") ANY == doc.translation, "text_en")\n'
# Add fields for relevance and tokens.
str_tokens = str([re.sub(r"[^a-zA-Z0-9\s]", "", token) for token in input_tokens])
select_fields = f"""
{{
"speech_id": doc._key,
"Text": doc.translation,
"number": doc.speech_number,
"debatetype": doc.debate_type,
"Speaker": doc.name,
"Date": doc.date,
"url": doc.url,
"Party": doc.party,
"mep_id": doc.mep_id,
"Summary": doc.summary,
"debate_id": doc.debate_id,
    "relevance_score": BM25(doc),
    "tokens": TOKENS(doc.translation, "text_en"),
    "input_tokens": TOKENS({str_tokens}, "text_en")}}
"""
    # Add FILTER, SORT, LIMIT and RETURN clauses; SORT must come before LIMIT
    # so that the limit keeps the most relevant speeches.
    aql_query += f"""
    FILTER char_length(doc.summary) > 10
    SORT BM25(doc) DESC
    LIMIT {results_limit}
    RETURN {select_fields}
    """
return aql_query.replace("SEARCH AND", "SEARCH ")
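# Example of what create_aql_query produces (hedged; RETURN fields abbreviated).
# For input_tokens = ['"climate change"', "budget*", "-coal"] the query starts:
#
#     FOR doc IN speeches_view
#     SEARCH PHRASE(doc.translation, "climate change", "text_en")
#     AND LIKE(doc.translation, "budget%")
#     AND NOT ANALYZER(TOKENS("coal", "text_en") ANY == doc.translation, "text_en")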
def error2db(error, user_input, engine):
    """Write error to DB for debugging."""
    doc = {
        "error": error,
        "time": str(datetime.date(datetime.now())),
        "user_input": str(user_input),
    }
    arango_db.collection("errors").insert(doc)
@st.cache_data
def get_speakers():
"""Get all meps."""
df = pd.DataFrame(
[
i
for i in arango_db.aql.execute(
"for doc in meps return {'name': doc.name, 'mep_id': doc._key, mep_url: doc.mep_URI}"
)
]
)
df.mep_id = df.mep_id.astype(int)
return df
def search_person(user_input, list_persons):
    """Returns an AQL query for searching everything a specific speaker has said.
    Args:
    user_input (str): The string resulting from user input.
    list_persons (list): The names of all known MEPs.
    Returns:
    str: An AQL query string.
    """
# List all alternatives.
options = [i for i in list_persons if i.lower() == user_input.lower()]
options = [f"Yes, search for {i}" for i in options]
no_option = f"No, I want to search for what have been said about {user_input}."
options += [no_option, "Choose an alternative"]
preselected_option = len(options) - 1
# Let the user select a person or no_alternative.
speaker = st.selectbox(
":red[Do you want to search for what a specific parliamentarian has said?]",
options,
index=preselected_option,
)
if speaker == "Choose an alternative":
st.stop()
    if speaker == no_option:
        # Return the "normal" token query if the user declined the person search.
        aql = create_aql_query(user_input.split(" "))
else:
speaker = speaker.replace("Yes, search for ", "")
aql = f"FOR doc IN meps FILTER doc.name == '{speaker}' RETURN doc"
return aql
def show_excerpts():
"""
Displays excerpts from a dataframe in a Streamlit app.
The function uses the global variables `df`, `input_tokens`, `party_colors_lighten`, and `excerpts_container`.
It also uses the session state variables `excerpt_page`, `df_excerpts`, `disable_next_page_button`, and `text_next_page_button`.
The function displays a specified number of excerpts per page, and allows the user to navigate between pages.
Each excerpt includes the speaker, date, party, and a snippet of text.
The user can click a button to view the full text of an excerpt.
Returns:
None
"""
results_per_page = 10 # Number of excerpts to display per page
with excerpts_container:
n = results_per_page * int(st.session_state["excerpt_page"])
# Make snippets from the text field (short and long).
        df_excerpts = df.iloc[
            int(st.session_state["excerpt_page"])
            * results_per_page : int(st.session_state["excerpt_page"])
            * results_per_page
            + results_per_page
        ].copy()
        df_excerpts["Excerpt"] = df_excerpts.apply(
            lambda x: make_snippet(
                x["Text"], input_tokens, x["tokens"], x["input_tokens"]
            ),
            axis=1,
        )
        if "df_excerpts" in st.session_state:
            df_excerpts = pd.concat([st.session_state["df_excerpts"], df_excerpts])
        st.session_state["df_excerpts"] = df_excerpts
        if len(df_excerpts) == len(df):
            st.session_state["disable_next_page_button"] = True
            st.session_state["text_next_page_button"] = "No more results"
        new_debate = True
        debate_id = None
        for _, row in df_excerpts.iterrows():
n += 1
# Find out if it's a new debate.
if row["debate_id"] == debate_id:
new_debate = False
else:
new_debate = True
debate_id = row["debate_id"]
# TODO Clean names from titles etc.
# Write to table.
if new_debate:
# st.write("---", unsafe_allow_html=True)
st.markdown(
f""" <span style="font-weight: bold;">{row['Date']}</span> """,
unsafe_allow_html=True,
)
col1, col2, col3 = st.columns([2, 7, 2])
with col1:
if row["mep_id"] in meps_ids:
st.markdown(
f"""<a href="https://www.europarl.europa.eu/meps/en/{row["mep_id"]}" target="_blank">{row['Speaker']}</a>""",
unsafe_allow_html=True,
)
else:
st.markdown(row["Speaker"])
with col2:
                snippet = (
                    row["Excerpt"]
                    .replace(":", "\\:")
                    .replace("<p>", "")
                    .replace("</p>", "")
                )
st.markdown(
f""" <span style="background-color:{party_colors_lighten[row['Party']]}; color:black;">{snippet}</span> """,
unsafe_allow_html=True,
)
with col3:
full_text = st.button("Full text", key=row["speech_id"])
if full_text:
with st.sidebar:
url_person = (
f'https://www.europarl.europa.eu/meps/en/{row["mep_id"]}'
)
st.markdown(
f""" <span class="{row['Party']}" style="font-weight: bold;">[ {row['Speaker']} ]({url_person})</span> """,
unsafe_allow_html=True,
)
st.markdown(
f""" <span style="font-style: italic;">{row["Date"]}</span> """,
unsafe_allow_html=True,
)
st.write(row["Text"].replace(":", "\:"), unsafe_allow_html=True)
st.markdown(f"📝 [Read the protocol]({row['url']})")
###* PAGE LOADS *###
# Title and explainer for streamlit
st.set_page_config(
page_title="EP Debates",
# page_icon="favicon.png",
initial_sidebar_state="auto",
)
title = "What are they talking about in the EU?"
st.title(title)
st.markdown(css, unsafe_allow_html=True)
# Get params from url.
params = Params(st.experimental_get_query_params())
# The official colors of the parties
partycodes = list(party_colors.keys()) # List of partycodes
# Max hits returned by db.
return_limit = 10000
# Ask for word to search for.
user_input = st.text_input(
" ",
value=params.q,
placeholder="Search something",
# label_visibility="hidden",
    help='You can use asterisk (*), minus (-), quotation marks ("") and OR.',
)
if len(user_input) > 3:
params.q = user_input
# print(user_input.upper())
# print(ollama(prompt=f'''A user wants to search in a database containing debates in the European Parliament and have made the input below. Take that input and write three questions would generate a good result if used for quering a vector database. Answer with a python style list containing the three questions.
# User input: {user_input}
# Questions: '''))
try:  #! Keep this try/except active in production.
    user_input = user_input.replace("'", '"')
    # Split the input into tokens, keeping quoted phrases together.
    input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
# Put user input in session state (first run).
if "user_input" not in st.session_state:
st.session_state["user_input"] = user_input
user_input_to_db(user_input)
else:
if st.session_state["user_input"] != user_input:
# Write user input to DB.
st.session_state["user_input"] = user_input
user_input_to_db(user_input)
# Reset url parameters.
params.reset(q=user_input)
params.update()
meps_df = get_speakers()
meps_ids = meps_df.mep_id.to_list()
    # Check if user has searched for a specific politician.
    # TODO Better way of telling whether the input is a name?
    aql = None
    if len(user_input.split(" ")) in [2, 3, 4]:
        list_persons = meps_df["name"].to_list()
        if user_input.lower() in [p.lower() for p in list_persons]:
            # Person search: use the query from search_person and mark the
            # tokens as a speaker search (make_snippet checks for this).
            aql = search_person(user_input, list_persons)
            input_tokens = "speaker"
    if aql is None:
        aql = create_aql_query(input_tokens)
## Fetch data from DB.
df, status = get_data(aql, input_tokens)
if status == "no hits": # If no hits.
st.write("No hits. Try again!")
st.stop()
elif status == "limit":
st.write(limit_warning)
st.stop()
party_talks = pd.DataFrame(df["Party"].value_counts())
party_labels = party_talks.index.to_list() # List with active parties.
    if isinstance(party_labels, list):
        party_labels.sort()
if user_input != "speaker":
# Let the user select parties to be included.
container_parties = st.container()
with container_parties:
            style_parties = build_style_parties(
                party_labels
            )  # Make the options the right colors.
            st.markdown(style_parties, unsafe_allow_html=True)
params.parties = st.multiselect(
label="Filter on parties.",
options=party_labels,
default=party_labels,
)
if params.parties != []:
df = df.loc[df["Party"].isin(params.parties)]
if len(df) == 0:
st.stop()
# Let the user select type of debate.
container_debate = st.container()
with container_debate:
debates = df["debatetype"].unique().tolist()
debates.sort()
style = build_style_debate_types(debates)
st.markdown(style, unsafe_allow_html=True)
params.debates = st.multiselect(
label="Select type of debate",
options=debates,
default=debates,
)
if params.debates != []:
df = df.loc[df["debatetype"].isin(params.debates)]
if len(df) == 0:
st.stop()
params.update()
    # Let the user select a range of years.
    from_year = int(params.from_year)
    to_year = int(params.to_year)
    df_ = df.loc[df["Year"].isin(range(from_year, to_year + 1))]  # TODO Ugly.
    years = list(range(int(df["Year"].min()), int(df["Year"].max()) + 1))
    if len(years) > 1:
        params.from_year, params.to_year = st.select_slider(
            "Select years",
            years,
            value=(years[0], years[-1]),
        )
df = df.loc[
df["Year"].isin(list(range(params.from_year, params.to_year + 1)))
]
elif len(years) == 1:
df = df.loc[df["Year"] == years[0]]
params.update()
if user_input != "speaker":
# Let the user select talkers.
options = options_persons(df)
style_mps = build_style_mps(options) # Make the options the right colors.
st.markdown(style_mps, unsafe_allow_html=True)
col1_persons, col2_persons = st.columns([5, 2])
# Sort alternatives in column to the right.
with col2_persons:
sort = st.selectbox(
"Sort result by", options=["Alphabetical order", "Most speeches"]
)
if sort == "Most speeches":
options = sorted(
options,
key=lambda x: [int(i) for i in x.split() if i.isdigit()][-1],
reverse=True,
)
else:
options.sort()
# Present options in column to the left.
with col1_persons:
expand_persons = st.container()
with expand_persons:
params.persons = st.multiselect(
label="Filter on parlamentarians",
options=options,
default=[],
)
# Filter df.
if params.persons != []:
params.persons = [i[: i.find(" - ")] for i in params.persons]
df = df.loc[df["Speaker"].isin(params.persons)]
params.update()
# Give df an index.
df.index = range(1, df.shape[0] + 1)
##* Start render. *##
st.markdown("---") # Draw line after filtering.
st.write(f"**Hits: {df.shape[0]}**")
if "hits" not in st.session_state:
st.session_state["hits"] = df.shape[0]
else:
if st.session_state["hits"] != df.shape[0]:
del st.session_state["df_excerpts"]
del st.session_state["excerpt_page"]
del st.session_state["text_next_page_button"]
del st.session_state["disable_next_page_button"]
st.session_state["hits"] = df.shape[0]
##! Show snippets.
expand = st.expander(
"Show excerpts from the speeches.",
expanded=False,
)
with expand:
if "excerpt_page" not in st.session_state:
st.session_state["excerpt_page"] = 0
excerpts_container = st.container()
show_excerpts()
st.session_state["excerpt_page"] += 1
if "disable_next_page_button" not in st.session_state:
st.session_state["disable_next_page_button"] = False
if "text_next_page_button" not in st.session_state:
st.session_state["text_next_page_button"] = "10 more please"
if st.button(
st.session_state["text_next_page_button"],
key=f"next_page_{st.session_state['excerpt_page']}",
disabled=st.session_state["disable_next_page_button"],
):
st.session_state["excerpt_page"] += 1
show_excerpts()
# * Download all data in df.
st.download_button(
"Download the data as CSV",
data=df.to_csv(
index=False,
sep=";",
columns=[
"speech_id",
"Text",
"Party",
"Speaker",
"Date",
"url",
],
).encode("utf-8"),
file_name=f"{user_input}.csv",
mime="text/csv",
)
# Remove talks from same party within the same session to make the
# statistics more representative.
df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
if user_input != "speaker":
## Make pie chart.
party_talks = pd.DataFrame(df_["Party"].value_counts())
party_labels = party_talks.index.to_list()
fig, ax1 = plt.subplots()
total = party_talks["Party"].sum()
mentions = party_talks["Party"] #!
ax1.pie(
mentions,
labels=party_labels,
autopct=lambda p: "{:.0f}".format(p * total / 100),
colors=[party_colors[key] for key in party_labels],
startangle=90,
)
# Make bars per year.
years = set(df["Year"].tolist())
df_years = pd.DataFrame(columns=["Party", "Year"])
for year, df_i in df.groupby("Year"):
dff = pd.DataFrame(data=df_i["Party"].value_counts())
dff["Year"] = str(year)
df_years = pd.concat([df_years, dff])
df_years["party_code"] = df_years.index
# df_years = df_years.groupby(["party_code", "Year"]).size()
df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
# df_years.drop(["Party"], axis=1, inplace=True)
df_years.rename(
columns={"Party": "Mentions", "party_code": "Party"}, inplace=True
)
chart = (
alt.Chart(df_years)
.mark_bar()
.encode(
x="Year",
y="Mentions",
color=alt.Color("color", scale=None),
tooltip=["Party", "Mentions"],
)
)
if user_input == "speaker":
st.altair_chart(chart, use_container_width=True)
else:
# Put the charts in a table.
fig1, fig2 = st.columns(2)
with fig1:
st.pyplot(fig)
with fig2:
st.altair_chart(chart, use_container_width=True)
# Create an empty container
container = st.empty()
# * Make a summary of the standpoints of the parties.
# A general note about the summaries. #TODO Make better.
st.markdown(summary_note)
# Count the number of speeches by party.
party_counts = df["Party"].value_counts()
# Sort the dataframe by party.
df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
by=["Party"]
)
# Create an empty dictionary to store the summaries.
summaries = {}
# Loop through each party and their respective speeches.
for party, df_party in df_sorted.groupby("Party"):
# Create an empty list to store the texts of each speech.
with st.spinner(f"Summarizing speeches by **{party}** parliamentarians..."):
summary = summarize(df_party, user_input)
st.markdown(
f"##### <span style='color:{party_colors[party]}'>{parties[party]} ({party})</span>",
unsafe_allow_html=True,
)
st.write(summary)
# A general note about the summaries. #TODO Make better.
# * Get feedback.
st.empty()
feedback_container = st.empty()
with feedback_container.container():
feedback = st.text_area(
"*Feel free to write suggestions for functions and improvements here!*"
)
send = st.button("Send")
if len(feedback) > 2 and send:
doc = {
"feedback": feedback,
"params": st.experimental_get_query_params(),
"where": ("streamlit", title),
"timestamp": str(datetime.now()),
"date": str(datetime.date(datetime.now())),
}
arango_db.collection("feedback").insert(doc)
feedback_container.write("*Thanks*")
params.update()
# st.markdown("##")
except Exception as e:
    if (
        type(e).__name__ == "StopException"
    ):  # Raised by st.stop(); let it pass silently.
        pass
else:
print(traceback.format_exc())
st.markdown(
":red[Something has gone wrong, I'm trying to fix it as soon as possible. Feel free to try searching for something else.]"
)
arango_db.collection("errors").insert(
{
"timestamp": str(datetime.now()),
"error": traceback.format_exc(),
"params": st.experimental_get_query_params(),
"where": ("streamlit", title),
}
)
expand_explainer = st.expander(
"*What is this? Where does the data come from? How do I do?*"
)
with expand_explainer:
st.markdown(explainer)