import difflib
import re
import traceback
from datetime import datetime

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from arango_things import arango_db
from things import normalize_party_names

# NOTE: ollama() is called in summarize() below but was never imported in the
# original file; the import location here is an assumption.
from things import ollama
from streamlit_info import (
    party_colors,
    explainer,
    limit_warning,
    party_colors_lighten,
    css,
    results_limit,
    parties,
    summary_note,
)
from langchain.schema import Document


class Params:
    """Class containing the URL parameters.

    Attributes:
        params (dict): A dictionary containing the parameters.
        q (str): The search query.
        parties (list): A list of political parties.
        persons (list): A list of persons.
        from_year (int): The start year for the search.
        to_year (int): The end year for the search.
        debates (list): A list of debates.
        excerpt_page (int): The page number for the excerpts.
    """

    def __init__(self, params):
        """Initializes the Params class.

        Args:
            params (dict): A dictionary containing the parameters.
        """
        self.params = params
        # Set parameters.
        self.q = self.set_param("q")
        self.parties = self.set_param("parties")
        self.persons = self.set_param("persons")
        self.from_year = self.set_param("from_year")
        self.to_year = self.set_param("to_year")
        self.debates = self.set_param("debates")
        self.excerpt_page = self.set_param("excerpt_page")

    def set_param(self, key):
        """Returns the value of a parameter if it exists.

        Args:
            key (str): The key of the parameter.

        Returns:
            The value of the parameter if it exists, otherwise a default value.
        """
        if key in self.params:
            if key in ["parties", "persons", "debates"]:
                value = self.params[key][0].split(",")
            else:
                value = self.params[key][0]
        else:
            value = []
            if key == "q":
                value = ""
            elif key == "from_year":
                value = 2019  # Catch-all.
            elif key == "to_year":
                value = 2023  # Catch-all.
            elif key == "excerpt_page":
                value = 0
        return value

    def update(self):
        """Updates the URL parameters."""
        st.experimental_set_query_params(
            q=self.q,
            from_year=self.from_year,
            to_year=self.to_year,
            parties=",".join(self.parties),
            debates=",".join(self.debates),
            persons=",".join(self.persons),
            excerpt_page=self.excerpt_page,
        )

    def update_param(self, key, value):
        """Updates a single parameter.

        Args:
            key (str): The key of the parameter to update.
            value (str or int or list): The new value of the parameter.
        """
        self.params[key] = value
        self.update()

    def reset(self, q=False):
        """Resets all parameters to their default values.

        Args:
            q (str): The new search query (optional).
        """
        for key in self.params:
            self.params[key] = []
        if q:
            self.q = q


@st.cache_data(show_spinner=False)
def summarize(df_party, user_input):
    """Summarizes the standpoints of a political party in the EU parliament,
    based on a user search.

    Args:
        df_party (pandas.DataFrame): A DataFrame containing speeches by
            parliamentarians from the political party in the EU parliament.
        user_input (str): The user's search query.

    Returns:
        str: A short summary of the standpoints of the political party in the
            EU parliament, in relation to the user search.
    """
    # The original relied on a global `party`; derive it from the data instead.
    party = df_party["Party"].iloc[0]
    texts = []
    documents = []
    for _, row in df_party.iterrows():
        documents.append(
            Document(
                page_content=f"**{row['Speaker']} {row['Date']}**\n{row['Summary']}",
                metadata=row.to_dict(),
            )
        )
    # https://python.langchain.com/docs/use_cases/question_answering/qa_citations
    # Keep the 10 most relevant speeches (relevance defined by BM25 in the arango query).
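    # NOTE: `documents` follows the citation-style QA pattern from the LangChain
    # docs linked above, but it is not handed to a retriever anywhere yet. A
    # hypothetical wiring (an assumption, not part of the current flow) could be:
    #     from langchain.retrievers import BM25Retriever
    #     retriever = BM25Retriever.from_documents(documents)
    #     relevant = retriever.get_relevant_documents(user_input)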
    if df_party.shape[0] > 10:
        df_party = df_party.sort_values(by="relevance_score", ascending=False)[:10]
    if df_party.shape[0] > 5:
        # Sample 5 rows, with weights proportional to the year counts.
        sample_size = 5
        weights = df_party["Year"].value_counts(normalize=True)
        df_party = df_party.groupby("Year", group_keys=False).apply(
            lambda x: x.sample(n=int(sample_size * weights[x.name]))
        )
    # Sort the dataframe by date.
    df_party.sort_values(by="Date", ascending=False, inplace=True)
    # Loop through the speeches and add each text to the list.
    for _, row in df_party.sort_values(by="Date", ascending=False)[:10].iterrows():
        texts.append(f"**{row['Speaker']} {row['Date']}**\n{row['Summary']}")
    # Create a prompt asking for a summary of the party's standpoints.
    prompt = re.sub(
        r"\W\W\W+",
        "\n",
        f"""
        A user has made this search in a database: <{user_input}>
        That resulted in the list below, consisting of speeches by parliamentarians from the {party} party in the EU parliament:
        {texts}
        Please make a very short summary of the standpoints of the {party} party in the EU parliament, in relation to the user search. Make note of the speakers and the dates of the speeches.
        Example 1: In 2020, Parliamentarian NN ({party}) wanted to decrease the budget for the EU parliament. Later, in 2022, she wanted to increase the budget.
        Example 2: Parliamentarians from {party} have been very active in the debate about the EU budget. They have been both for and against the budget.
        Short summary:
        """,
    )
    # Generate and return a summary of the party's standpoints.
    system_prompt = (
        "You are a journalist and have been asked to write a short summary of "
        "the standpoints of the party in the EU parliament."
    )
    return ollama(
        prompt=prompt, system_prompt=system_prompt, temperature=0.5, print_tokens=False
    )


def make_snippet(text, input_tokens, token_text_list, token_input_list):
    """Find the words searched for and give them some context.

    Args:
        text (str): The text to search for the input tokens.
        input_tokens (list or str): The input tokens to search for in the text,
            or the string "speaker" for person searches.
        token_text_list (list): A list of tokens in the text.
        token_input_list (list): A list of input tokens.

    Returns:
        str: A snippet of text containing the input tokens and some context.
    """
    # If the search is a speaker search, return the first 300 characters of the text.
    if input_tokens == "speaker":
        snippet = str(text[:300])
        if len(text) > 300:
            snippet += "..."
    else:
        snippet = []
        text_lower = text.lower()
        # Calculate the snippet length in words.
        snippet_length = 40 * int(8 / len(input_tokens))  # * Change to another value?
        # Loop through each input token.
        for token in input_tokens:
            # Remove any non-alphanumeric characters and convert to lowercase.
            token = re.sub(r"[^a-zA-Z0-9\s]", "", token).lower()
            if token in text_lower:
                # Find the position of the token in the text.
                position = text_lower.find(token)
                position_start = position - snippet_length
                # Find the start of the snippet by looking for the previous space character.
                while position_start > 0:
                    position_start -= 1
                    if text_lower[position_start] == " ":
                        break
                position_end = position + int(snippet_length / 2)
                # Find the end of the snippet by looking for the next space character.
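                # (The forward scan below mirrors the backward scan above, so
                # the snippet starts and ends on word boundaries.)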
                while 0 < position_end < len(text_lower) - 1:
                    position_end += 1
                    if text_lower[position_end] == " ":
                        break
                text_before = text[position_start:position]
                if position_end > len(text):
                    position_end = len(text)
                text_after = text[position + len(token) + 1 : position_end]
                word_context = f"{text_before} {token} {text_after}"
                snippet.append(word_context)
                if len(snippet) == 2:
                    break
        if not snippet:
            # If the input tokens are not found in the text, find the closest
            # match in the token list.
            text_list = text.split(" ")
            for token in token_input_list:
                token = "".join(token)
                if token in token_text_list:
                    position = token_text_list.index(token)
                    w = 20
                    # Get a list of words around the token, and find the closest
                    # match to the token in that list.
                    list_from_text = text_list[position - w : position + w]
                    closest_match = difflib.get_close_matches(
                        token, list_from_text, n=1
                    )
                    # If no match is found, expand the list of words and try again.
                    if not closest_match:
                        w = 50
                        list_from_text = text_list[position - w : position + w]
                        closest_match = difflib.get_close_matches(
                            token, list_from_text, n=1
                        )
                    if closest_match:
                        word = closest_match[0]
                        # Calculate the difference between the position of the
                        # token and the position of the closest match.
                        diff_token_to_word = list_from_text.index(word) - w
                        position = position + diff_token_to_word
                    else:
                        word = text_list[position]
                    # Get the context of the word, including the words before
                    # and after it.
                    position_start = position - snippet_length
                    if position_start < 0:
                        position_start = 0
                    position_end = position + int(snippet_length / 2)
                    text_before = " ".join(text_list[position_start:position])
                    if position_end > len(text_list):
                        position_end = len(text_list)
                    text_after = " ".join(text_list[position + 1 : position_end])
                    word_context = f"{text_before} {word} {text_after}"
                    snippet.append(word_context)
        snippet = "|".join(snippet)
        snippet = f"...{snippet}..."
    return snippet


def build_style_parties(parties):
    """Build a CSS style for party name buttons."""
    style = ""  # TODO: Currently a stub; returns an empty style.
    return style


def build_style_mps(mps):
    """Build a CSS style for MP name buttons."""
    style = ""  # TODO: Currently a stub; returns an empty style.
    return style


def fix_party(party):
    """Replace old party codes with new ones."""
    party = party.upper().replace("KDS", "KD").replace("FP", "L")
    return party


def build_style_debate_types(debates):
    """Build a CSS style for debate type buttons."""
    style = ""  # TODO: Currently a stub; returns an empty style.
    return style


def highlight_cells(party):
    """Return a CSS rule coloring a cell with the party's color."""
    if party in party_colors.keys():
        color = party_colors[party]
        return f"background-color: {color}; font-weight: 'bold'"


@st.cache_data
def options_persons(df):
    """Return "<speaker> - <number of speeches>" options for the multiselect."""
    d = {}
    for i in df.groupby("Speaker"):
        d[i[0]] = i[1].shape[0]
    return [f"{key} - {value}" for key, value in d.items()]


@st.cache_data
def get_data(aql, input_tokens):
    """Get data from the ArangoDB database.

    Args:
        aql (str): An AQL query string.
        input_tokens (list or str): The tokenized user input.

    Returns:
        DataFrame: Dataframe with some adjustments to the data fetched from the DB.
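        str: A status string: "ok", "limit", or "no hits".

    Example (hypothetical call, for illustration only):
        aql = create_aql_query(["climate"])
        df, status = get_data(aql, ["climate"])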
""" cursor = arango_db.aql.execute(aql, count=True) if cursor.__len__() == 0: df = None status = "no hits" else: if cursor.__len__() == results_limit: status = "limit" else: status = "ok" df = pd.DataFrame([doc for doc in cursor]) df.mep_id = df.mep_id.astype(int) df["Year"] = df["Date"].apply(lambda x: int(x[:4])) # df.drop_duplicates(ignore_index=True, inplace=True) df["Party"] = df["Party"].apply(lambda x: normalize_party_names(x)) df.sort_values(["Date", "number"], axis=0, ascending=True, inplace=True) df["len_translation"] = df["Text"].apply(lambda x: len(x.split(" "))) df["len_tokens"] = df["tokens"].apply(lambda x: len(x)) df.sort_values(["Date", "speech_id", "number"], axis=0, inplace=True) return df, status def user_input_to_db(user_input): """Writes user input to db for debugging.""" arango_db.collection("streamlit_user_input").insert( {"timestamp": datetime.timestamp(datetime.now()), "input": user_input} ) def create_aql_query(input_tokens): """Returns a valid sql query.""" # Split input into tokens # Build AQL query aql_query = "FOR doc IN speeches_view\nSEARCH " # Add SEARCH PHRASE clause search_phrase = next((token for token in input_tokens if " " in token), None) if search_phrase: aql_query += f'PHRASE(doc.translation, {search_phrase}, "text_en")\n' input_tokens.remove(search_phrase) # Add ANALYZER clause for each token for token in input_tokens: if token.startswith("-"): if "*" in token: aql_query += ( f'AND NOT LIKE(doc.translation, "{token[1:].replace("*", "%")}")\n' ) else: aql_query += f'AND NOT ANALYZER(TOKENS("{token[1:]}", "text_en") ANY == doc.translation, "text_en")\n' elif "*" in token: aql_query += f'AND LIKE(doc.translation, "{token.replace("*", "%")}")\n' else: aql_query += f'AND ANALYZER(TOKENS("{token}", "text_en") ANY == doc.translation, "text_en")\n' # Add fields for relevance and tokens. str_tokens = str([re.sub(r"[^a-zA-Z0-9\s]", "", token) for token in input_tokens]) select_fields = f""" {{ "speech_id": doc._key, "Text": doc.translation, "number": doc.speech_number, "debatetype": doc.debate_type, "Speaker": doc.name, "Date": doc.date, "url": doc.url, "Party": doc.party, "mep_id": doc.mep_id, "Summary": doc.summary, "debate_id": doc.debate_id, 'relevance_score': BM25(doc), 'tokens': TOKENS(doc.translation, 'text_en'), 'input_tokens': TOKENS({str_tokens}, 'text_en')}} """ # Add LIMIT and RETURN clause aql_query += f""" FILTER char_length(doc.summary) > 10 LIMIT {results_limit} SORT BM25(doc) DESC RETURN {select_fields} """ return aql_query.replace("SEARCH AND", "SEARCH ") def error2db(error, user_input, engine): """Write error to DB for debugging.""" doc = ( { "error": error, "time": datetime.date(datetime.now()), "user_input": str(user_input), }, ) arango_db.collection("errors").inser(doc) @st.cache_data def get_speakers(): """Get all meps.""" df = pd.DataFrame( [ i for i in arango_db.aql.execute( "for doc in meps return {'name': doc.name, 'mep_id': doc._key, mep_url: doc.mep_URI}" ) ] ) df.mep_id = df.mep_id.astype(int) return df def search_person(user_input, list_persons): """Returns SQL query made for searching everything a defined speaker has said. Args: user_input (str): The string resulting from user input (input()). Returns: list: List of search terms. """ # List all alternatives. options = [i for i in list_persons if i.lower() == user_input.lower()] options = [f"Yes, search for {i}" for i in options] no_option = f"No, I want to search for what have been said about {user_input}." 
    options += [no_option, "Choose an alternative"]
    preselected_option = len(options) - 1
    # Let the user select a person, or the no_option alternative.
    speaker = st.selectbox(
        ":red[Do you want to search for what a specific parliamentarian has said?]",
        options,
        index=preselected_option,
    )
    if speaker == "Choose an alternative":
        st.stop()
    if speaker == no_option:
        # Return a "normal" query if no_option was chosen. (The original passed
        # the raw string, but create_aql_query expects a list of tokens.)
        aql = create_aql_query(user_input.split())
    else:
        speaker = speaker.replace("Yes, search for ", "")
        aql = f"FOR doc IN meps FILTER doc.name == '{speaker}' RETURN doc"
    return aql


def show_excerpts():
    """Displays excerpts from a dataframe in a Streamlit app.

    The function uses the global variables `df`, `input_tokens`,
    `party_colors_lighten`, and `excerpts_container`. It also uses the session
    state variables `excerpt_page`, `df_excerpts`, `disable_next_page_button`,
    and `text_next_page_button`.

    The function displays a specified number of excerpts per page and allows
    the user to navigate between pages. Each excerpt includes the speaker,
    date, party, and a snippet of text. The user can click a button to view
    the full text of an excerpt.

    Returns:
        None
    """
    results_per_page = 10  # Number of excerpts to display per page.
    with excerpts_container:
        page = int(st.session_state["excerpt_page"])
        n = results_per_page * page
        # Make snippets from the text field (short and long).
        df_excerpts = df.iloc[
            page * results_per_page : page * results_per_page + results_per_page
        ].copy()
        df_excerpts["Excerpt"] = df_excerpts.apply(
            lambda x: make_snippet(
                x["Text"], input_tokens, x["tokens"], x["input_tokens"]
            ),
            axis=1,
        )
        if "df_excerpts" in st.session_state:
            df_excerpts = pd.concat([st.session_state["df_excerpts"], df_excerpts])
        st.session_state["df_excerpts"] = df_excerpts
        if len(df_excerpts) == len(df):
            st.session_state["disable_next_page_button"] = True
            st.session_state["text_next_page_button"] = "No more results"
        new_debate = True
        debate_id = None
        for _, row in df_excerpts.iterrows():
            n += 1
            # Find out if it's a new debate.
            if row["debate_id"] == debate_id:
                new_debate = False
            else:
                new_debate = True
                debate_id = row["debate_id"]
            # TODO Clean names from titles etc.
            # Write to the table.
            if new_debate:
                # st.write("---", unsafe_allow_html=True)
                st.markdown(
                    f"""
                    {row['Date']}
                    """,
                    unsafe_allow_html=True,
                )
            col1, col2, col3 = st.columns([2, 7, 2])
            with col1:
                if row["mep_id"] in meps_ids:
                    st.markdown(
                        f"""{row['Speaker']}""",
                        unsafe_allow_html=True,
                    )
                else:
                    st.markdown(row["Speaker"])
            with col2:
                snippet = row["Excerpt"].replace(":", "\\:").replace("\n\n", "")
                st.markdown(
                    f"""
                    {snippet}
                    """,
                    unsafe_allow_html=True,
                )
            with col3:
                full_text = st.button("Full text", key=row["speech_id"])
                if full_text:
                    with st.sidebar:
                        url_person = (
                            f'https://www.europarl.europa.eu/meps/en/{row["mep_id"]}'
                        )
                        st.markdown(
                            f"""
                            [{row['Speaker']}]({url_person})
                            """,
                            unsafe_allow_html=True,
                        )
                        st.markdown(
                            f"""
                            {row["Date"]}
                            """,
                            unsafe_allow_html=True,
                        )
                        st.write(
                            row["Text"].replace(":", "\\:"), unsafe_allow_html=True
                        )
                        st.markdown(f"📝 [Read the protocol]({row['url']})")


###* PAGE LOADS *###
# Title and explainer for Streamlit.
st.set_page_config(
    page_title="EP Debates",
    # page_icon="favicon.png",
    initial_sidebar_state="auto",
)
title = "What are they talking about in the EU?"
st.title(title)
st.markdown(css, unsafe_allow_html=True)
# Get params from the URL.
params = Params(st.experimental_get_query_params())
# params = Params({})  # NOTE: Debug override in the original that discarded the URL params.
# The official colors of the parties.
partycodes = list(party_colors.keys())  # List of party codes.
# Max hits returned by the DB.
return_limit = 10000
# Ask for a word to search for.
user_input = st.text_input(
    " ",
    value=params.q,
    placeholder="Search something",
    # label_visibility="hidden",
    help='You can use asterisk (*), minus (-), quotation marks ("") and OR.',
)
if len(user_input) > 3:
    params.q = user_input
    # print(user_input.upper())
    # print(ollama(prompt=f'''A user wants to search in a database containing debates in the European Parliament and has made the input below. Take that input and write three questions that would generate a good result if used for querying a vector database. Answer with a Python-style list containing the three questions.
    # User input: {user_input}
    # Questions: '''))
    try:  #! When in production, uncomment this.
        user_input = user_input.replace("'", '"')
        input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
        # Put user input in session state (first run).
        if "user_input" not in st.session_state:
            st.session_state["user_input"] = user_input
            user_input_to_db(user_input)
        else:
            if st.session_state["user_input"] != user_input:
                # Write user input to the DB.
                st.session_state["user_input"] = user_input
                user_input_to_db(user_input)
                # Reset URL parameters.
                params.reset(q=user_input)
                params.update()
        meps_df = get_speakers()
        meps_ids = meps_df.mep_id.to_list()
        # Check if the user has searched for a specific politician.
        if len(user_input.split(" ")) in [2, 3, 4]:  # TODO Better way of telling if it's a name?
            list_persons = meps_df["name"].to_list()
            # TODO: This comparison never matches, since list_persons holds
            # mixed-case names; the person-search path is effectively dead and
            # its query is overwritten by create_aql_query() below anyway.
            if user_input.lower() in list_persons:
                sql = search_person(user_input, list_persons)
                search_terms = "speaker"
        aql = create_aql_query(input_tokens)
        ## Fetch data from the DB.
        df, status = get_data(aql, input_tokens)
        if status == "no hits":
            st.write("No hits. Try again!")
            st.stop()
        elif status == "limit":
            st.write(limit_warning)
            st.stop()
        party_talks = pd.DataFrame(df["Party"].value_counts())
        party_labels = party_talks.index.to_list()  # List of active parties.
        # (The original compared type(party_labels) to the string "list",
        # which is never true.)
        if isinstance(party_labels, list):
            party_labels.sort()
        if user_input != "speaker":
            # Let the user select parties to be included.
            container_parties = st.container()
            with container_parties:
                style_parties = build_style_parties(
                    party_labels
                )  # Make the options the right colors.
                st.markdown(style_parties, unsafe_allow_html=True)
                params.parties = st.multiselect(
                    label="Filter on parties.",
                    options=party_labels,
                    default=party_labels,
                )
                if params.parties != []:
                    df = df.loc[df["Party"].isin(params.parties)]
                    if len(df) == 0:
                        st.stop()
        # Let the user select the type of debate.
        container_debate = st.container()
        with container_debate:
            debates = df["debatetype"].unique().tolist()
            debates.sort()
            style = build_style_debate_types(debates)
            st.markdown(style, unsafe_allow_html=True)
            params.debates = st.multiselect(
                label="Select type of debate",
                options=debates,
                default=debates,
            )
            if params.debates != []:
                df = df.loc[df["debatetype"].isin(params.debates)]
                if len(df) == 0:
                    st.stop()
            params.update()
        # Let the user select a range of years.
        from_year = int(params.from_year)
        to_year = int(params.to_year)
        df_ = df.loc[
            df["Year"].isin(list(range(from_year, to_year)))
        ]  # TODO Ugly, and currently unused.
        years = list(range(int(df["Year"].min()), int(df["Year"].max()) + 1))
        if len(years) > 1:
            params.from_year, params.to_year = st.select_slider(
                "Select years",
                list(range(int(df["Year"].min()), int(df["Year"].max()) + 1)),
                value=(years[0], years[-1]),
            )
            df = df.loc[
                df["Year"].isin(list(range(params.from_year, params.to_year + 1)))
            ]
        elif len(years) == 1:
            df = df.loc[df["Year"] == years[0]]
        params.update()
        if user_input != "speaker":
            # Let the user select speakers.
            options = options_persons(df)
            style_mps = build_style_mps(options)  # Make the options the right colors.
            st.markdown(style_mps, unsafe_allow_html=True)
            col1_persons, col2_persons = st.columns([5, 2])
            # Sort the alternatives in the column to the right.
            with col2_persons:
                sort = st.selectbox(
                    "Sort result by", options=["Alphabetical order", "Most speeches"]
                )
                if sort == "Most speeches":
                    options = sorted(
                        options,
                        key=lambda x: [int(i) for i in x.split() if i.isdigit()][-1],
                        reverse=True,
                    )
                else:
                    options.sort()
            # Present the options in the column to the left.
            with col1_persons:
                expand_persons = st.container()
                with expand_persons:
                    params.persons = st.multiselect(
                        label="Filter on parliamentarians",
                        options=options,
                        default=[],
                    )
            # Filter the dataframe.
            if params.persons != []:
                params.persons = [i[: i.find(" - ")] for i in params.persons]
                df = df.loc[df["Speaker"].isin(params.persons)]
        params.update()
        # Give the dataframe a 1-based index.
        df.index = range(1, df.shape[0] + 1)

        ##* Start render. *##
        st.markdown("---")  # Draw a line after filtering.
        st.write(f"**Hits: {df.shape[0]}**")
        if "hits" not in st.session_state:
            st.session_state["hits"] = df.shape[0]
        else:
            if st.session_state["hits"] != df.shape[0]:
                del st.session_state["df_excerpts"]
                del st.session_state["excerpt_page"]
                del st.session_state["text_next_page_button"]
                del st.session_state["disable_next_page_button"]
                st.session_state["hits"] = df.shape[0]

        ##! Show snippets.
        expand = st.expander(
            "Show excerpts from the speeches.",
            expanded=False,
        )
        with expand:
            if "excerpt_page" not in st.session_state:
                st.session_state["excerpt_page"] = 0
            excerpts_container = st.container()
            show_excerpts()
            st.session_state["excerpt_page"] += 1
            if "disable_next_page_button" not in st.session_state:
                st.session_state["disable_next_page_button"] = False
            if "text_next_page_button" not in st.session_state:
                st.session_state["text_next_page_button"] = "10 more please"
            if st.button(
                st.session_state["text_next_page_button"],
                key=f"next_page_{st.session_state['excerpt_page']}",
                disabled=st.session_state["disable_next_page_button"],
            ):
                st.session_state["excerpt_page"] += 1
                show_excerpts()

        # * Download all data in df.
        st.download_button(
            "Download the data as CSV",
            data=df.to_csv(
                index=False,
                sep=";",
                columns=[
                    "speech_id",
                    "Text",
                    "Party",
                    "Speaker",
                    "Date",
                    "url",
                ],
            ).encode("utf-8"),
            file_name=f"{user_input}.csv",
            mime="text/csv",
        )

        # Remove talks from the same party within the same session to make the
        # statistics more representative.
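        # NOTE: Because speech_id is part of the key below, this only drops
        # exact duplicate rows; whether it fully collapses repeated party turns
        # within one session depends on how speech_id is assigned.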
        df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
        if user_input != "speaker":
            ## Make a pie chart.
            party_talks = pd.DataFrame(df_["Party"].value_counts())
            party_labels = party_talks.index.to_list()
            fig, ax1 = plt.subplots()
            total = party_talks["Party"].sum()
            mentions = party_talks["Party"]  #!
            ax1.pie(
                mentions,
                labels=party_labels,
                autopct=lambda p: "{:.0f}".format(p * total / 100),
                colors=[party_colors[key] for key in party_labels],
                startangle=90,
            )
        # Make bars per year.
        years = set(df["Year"].tolist())
        df_years = pd.DataFrame(columns=["Party", "Year"])
        for year, df_i in df.groupby("Year"):
            dff = pd.DataFrame(data=df_i["Party"].value_counts())
            dff["Year"] = str(year)
            df_years = pd.concat([df_years, dff])
        df_years["party_code"] = df_years.index
        # df_years = df_years.groupby(["party_code", "Year"]).size()
        df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
        # df_years.drop(["Party"], axis=1, inplace=True)
        df_years.rename(
            columns={"Party": "Mentions", "party_code": "Party"}, inplace=True
        )
        chart = (
            alt.Chart(df_years)
            .mark_bar()
            .encode(
                x="Year",
                y="Mentions",
                color=alt.Color("color", scale=None),
                tooltip=["Party", "Mentions"],
            )
        )
        if user_input == "speaker":
            st.altair_chart(chart, use_container_width=True)
        else:
            # Put the charts side by side.
            fig1, fig2 = st.columns(2)
            with fig1:
                st.pyplot(fig)
            with fig2:
                st.altair_chart(chart, use_container_width=True)
        # Create an empty container.
        container = st.empty()

        # * Make a summary of the standpoints of the parties.
        # A general note about the summaries. TODO: Make better.
        st.markdown(summary_note)
        # Count the number of speeches by party.
        party_counts = df["Party"].value_counts()
        # Sort the dataframe by party.
        df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
            by=["Party"]
        )
        # Create an empty dictionary to store the summaries.
        summaries = {}  # NOTE: Currently unused.
        # Loop through each party and its respective speeches.
        for party, df_party in df_sorted.groupby("Party"):
            with st.spinner(
                f"Summarizing speeches by **{party}** parliamentarians..."
            ):
                summary = summarize(df_party, user_input)
            st.markdown(
                f"##### {parties[party]} ({party})",
                unsafe_allow_html=True,
            )
            st.write(summary)

        # * Get feedback.
        st.empty()
        feedback_container = st.empty()
        with feedback_container.container():
            feedback = st.text_area(
                "*Feel free to write suggestions for functions and improvements here!*"
            )
            send = st.button("Send")
        if len(feedback) > 2 and send:
            doc = {
                "feedback": feedback,
                "params": st.experimental_get_query_params(),
                "where": ("streamlit", title),
                "timestamp": str(datetime.now()),
                "date": str(datetime.date(datetime.now())),
            }
            arango_db.collection("feedback").insert(doc)
            feedback_container.write("*Thanks*")
        params.update()
        # st.markdown("##")
    except Exception as e:
        # st.stop() raises a StopException; let it pass silently. (The original
        # compared the exception instance to a string, which is never true.)
        if type(e).__name__ == "StopException":
            pass
        else:
            print(traceback.format_exc())
            st.markdown(
                ":red[Something has gone wrong, I'm trying to fix it as soon as "
                "possible. Feel free to try searching for something else.]"
            )
            arango_db.collection("errors").insert(
                {
                    "timestamp": str(datetime.now()),
                    "error": traceback.format_exc(),
                    "params": st.experimental_get_query_params(),
                    "where": ("streamlit", title),
                }
            )

expand_explainer = st.expander(
    "*What is this? Where does the data come from? How do I use it?*"
)
with expand_explainer:
    st.markdown(explainer)
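# To run the app locally (assuming this file is saved as app.py):
#     streamlit run app.py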