From 8948116afdd08f1aa861f8fe8d60cf2104f98933 Mon Sep 17 00:00:00 2001
From: lasseedfast <>
Date: Sun, 5 Nov 2023 19:08:09 +0100
Subject: [PATCH]

---
 arango_things.py            |  14 +-
 streamlit_app_talking_ep.py | 511 +++++++++++++++++++-----------------
 streamlit_info.py           |   8 +-
 3 files changed, 281 insertions(+), 252 deletions(-)

diff --git a/arango_things.py b/arango_things.py
index a043ae7..0515d81 100644
--- a/arango_things.py
+++ b/arango_things.py
@@ -2,7 +2,7 @@ from arango import ArangoClient, exceptions
 import pandas as pd
 import yaml
 
-def get_documents(query=False, collection=False, fields=[], filter = '', df=False, index=False, field_names=False):
+def get_documents(query=False, collection=False, fields=[], filter='', df=False, index=False, field_names=False, limit=False):
     """
     This function retrieves documents from a specified collection or based on a query in ArangoDB.
 
@@ -36,10 +36,22 @@ def get_documents(query=False, collection=False, fields=[], filter = '', df=Fals
         fields_list = [f'{v}: doc.{k}' for k, v in fields_dict.items()]
         fields_string = ', '.join(fields_list)
         return_fields = f"{{{fields_string}}}"
+
+        # Interpolate optional filter/limit clauses into the AQL below.
+        if filter:
+            filter = f'filter {filter}'
+        else:
+            filter = ''
+
+        if limit:
+            limit = f'limit {limit}'
+        else:
+            limit = ''
+
         query = f'''
         for doc in {collection}
         {filter}
+        {limit}
         return {return_fields}
         '''
     try:
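Note: get_documents() splices filter and limit into the AQL string with f-strings, which is fine for trusted callers but brittle for user-supplied values. If that ever matters, python-arango supports bind variables; a minimal sketch under that assumption (the db handle and collection name are illustrative, not taken from this patch):

    # Equivalent query via ArangoDB bind variables instead of interpolation.
    # '@@collection' binds a collection name, '@limit' binds a plain value.
    query = '''
    for doc in @@collection
        limit @limit
        return doc
    '''
    cursor = db.aql.execute(
        query,
        bind_vars={'@collection': 'speeches', 'limit': 20},
    )
    docs = list(cursor)
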
diff --git a/streamlit_app_talking_ep.py b/streamlit_app_talking_ep.py
index f61b62d..e930fed 100644
--- a/streamlit_app_talking_ep.py
+++ b/streamlit_app_talking_ep.py
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 import streamlit as st
+import yaml
 from arango_things import arango_db
 from streamlit_info import (
     party_colors,
@@ -22,6 +23,20 @@ from streamlit_info import (
 from langchain.schema import Document
 from llama_server import LLM
 
+from langchain.vectorstores import Chroma
+from langchain.retrievers.multi_query import MultiQueryRetriever
+from langchain.embeddings import HuggingFaceEmbeddings
+from langchain.chat_models import ChatOpenAI
+
+
+class Session:
+    def __init__(self) -> None:
+        self.mode = ''
+        self.user_input = ''
+        self.query = ''
+        self.df = None
+
 class Params:
     """Class containing parameters for URL.
 
@@ -50,7 +65,6 @@ class Params:
         self.persons = self.set_param("persons")
         self.from_year = self.set_param("from_year")
         self.to_year = self.set_param("to_year")
-        self.debates = self.set_param("debates")
         self.excerpt_page = self.set_param("excerpt_page")
 
     def set_param(self, key):
@@ -87,7 +101,6 @@ class Params:
             from_year=self.from_year,
             to_year=self.to_year,
             parties=",".join(self.parties),
-            debates=",".join(self.debates),
             persons=",".join(self.persons),
             excerpt_page=self.excerpt_page,
         )
@@ -187,7 +200,21 @@ def summarize(df_party, user_input):
     #TODO Have to understand system_prompt in llama_server
     #system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
-    return llama.generate(prompt=prompt)
+    from langchain.schema import HumanMessage, SystemMessage
+
+    messages = [
+        SystemMessage(content="You are a journalist reporting from the EU."),
+        HumanMessage(content=prompt),
+    ]
+    # Placeholder key: the local OpenAI-compatible endpoint does not check it.
+    llm = ChatOpenAI(
+        temperature=0,
+        openai_api_key="sk-xxx",
+        openai_api_base="http://localhost:8081/v1",
+        max_tokens=100,
+    )
+    response = llm(messages)
+    return response.content
+
+    # return llama.generate(prompt=prompt)
 
 
 def normalize_party_names(name):
@@ -221,13 +248,12 @@ def normalize_party_names(name):
         "Vacant": "NA",
         "NI": "NA",
         "Renew": "Renew"
-
     }
 
     return parties[name]
 
 
-def make_snippet(text, input_tokens, token_text_list, token_input_list):
+def make_snippet(row, input_tokens, mode):
     """
     Find the word searched for and give it some context.
 
@@ -241,11 +267,17 @@ def make_snippet(text, input_tokens, token_text_list, token_input_list):
         str: A snippet of text containing the input tokens and some context.
     """
-    # If the input token is "speaker", return the first 300 characters of the text.
-    if input_tokens == "speaker":
+    # In question mode, show the pre-computed short summary; in speaker mode,
+    # return the first 300 characters of the text.
+    if mode == "question":
+        snippet = row['summary_short']
+
+    elif mode == "speaker":
+        text = row["Text"]
         snippet = str(text[:300])
         if len(text) > 300:
             snippet += "..."
     else:
+        text = row["Text"]
+        token_text_list = row["tokens"]
+        token_input_list = row["input_tokens"]
         snippet = []
         text_lower = text.lower()
         # Calculate snippet length in words.
@@ -349,28 +381,6 @@ def build_style_mps(mps):
     return style
 
 
-def fix_party(party):
-    """Replace old party codes with new ones."""
-    party = party.upper().replace("KDS", "KD").replace("FP", "L")
-    return party
-
-
-def build_style_debate_types(debates):
-    """Build a CSS style for debate type buttons."""
-    style = ""
-    return style
-
-
-def highlight_cells(party):
-    if party in party_colors.keys():
-        color = party_colors[party]
-        return f"background-color: {color}; font-weight: 'bold'"
-
-
 @st.cache_data
 def options_persons(df):
     d = {}
@@ -380,7 +390,7 @@
 
 
 @st.cache_data
-def get_data(aql, input_tokens):
-    """Get data from SQL database.
+def get_data(aql):
+    """Get data from the ArangoDB database.
 
     Args:
 
@@ -401,27 +411,19 @@
     else:
         status = "ok"
         df = pd.DataFrame([doc for doc in cursor])
+        df.mep_id = df.mep_id.astype(int)
         df["Year"] = df["Date"].apply(lambda x: int(x[:4]))
         # df.drop_duplicates(ignore_index=True, inplace=True)
         df["Party"] = df["Party"].apply(lambda x: normalize_party_names(x))
         df.sort_values(["Date", "number"], axis=0, ascending=True, inplace=True)
-
-        df["len_translation"] = df["Text"].apply(lambda x: len(x.split(" ")))
-        df["len_tokens"] = df["tokens"].apply(lambda x: len(x))
         df.sort_values(["Date", "speech_id", "number"], axis=0, inplace=True)
+        # Dedupe on the unique key; a full-row dedupe would fail on the
+        # unhashable list values in the tokens columns.
+        df.drop_duplicates(subset="speech_id", ignore_index=True, inplace=True)
 
     return df, status
 
 
-def user_input_to_db(user_input):
-    """Writes user input to db for debugging."""
-    arango_db.collection("streamlit_user_input").insert(
-        {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
-    )
-
-
 def create_aql_query(input_tokens):
-    """Returns a valid sql query."""
+    """Returns a valid AQL query."""
     # Split input into tokens
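Note: create_aql_query() below and search_with_question() both interpolate strings['arango']['select_fields'], loaded from a strings.yml that this patch does not include. Judging from the field list the patch removes (plus the summary_short field that question mode reads), the loaded value plausibly has the following shape; the names are a reconstruction, not authoritative:

    # Plausible result of yaml.safe_load(open('strings.yml')) (assumed).
    # select_fields is brace-less on purpose: the callers wrap it in { ... }.
    strings = {
        "arango": {
            "select_fields": (
                "'speech_id': doc._key, 'Text': doc.translation, "
                "'number': doc.speech_number, 'Speaker': doc.name, "
                "'Date': doc.date, 'url': doc.url, 'Party': doc.party, "
                "'mep_id': doc.mep_id, 'Summary': doc.summary, "
                "'summary_short': doc.summary_short, 'debate_id': doc.debate_id"
            )
        }
    }
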
@@ -453,21 +455,10 @@
     str_tokens = str([re.sub(r"[^a-zA-Z0-9\s]", "", token) for token in input_tokens])
 
     select_fields = f"""
-    {{
-        "speech_id": doc._key,
-        "Text": doc.translation,
-        "number": doc.speech_number,
-        "debatetype": doc.debate_type,
-        "Speaker": doc.name,
-        "Date": doc.date,
-        "url": doc.url,
-        "Party": doc.party,
-        "mep_id": doc.mep_id,
-        "Summary": doc.summary,
-        "debate_id": doc.debate_id,
+    {{{strings['arango']['select_fields']},
     'relevance_score': BM25(doc),
     'tokens': TOKENS(doc.translation, 'text_en'),
     'input_tokens': TOKENS({str_tokens}, 'text_en')}}
     """
 
     # Add LIMIT and RETURN clause
@@ -480,17 +471,10 @@
     return aql_query.replace("SEARCH AND", "SEARCH ")
 
 
-def error2db(error, user_input, engine):
-    """Write error to DB for debugging."""
-    doc = (
-        {
-            "error": error,
-            "time": datetime.date(datetime.now()),
-            "user_input": str(user_input),
-        },
-    )
-    arango_db.collection("errors").inser(doc)
+def fix_party(party):
+    """Replace old party codes with new ones."""
+    party = party.upper().replace("KDS", "KD").replace("FP", "L")
+    return party
 
 
 @st.cache_data
@@ -570,9 +554,7 @@ def show_excerpts():
     ].copy()
 
     df_excepts["Excerpt"] = df_excepts.apply(
-        lambda x: make_snippet(
-            x["Text"], input_tokens, x["tokens"], x["input_tokens"]
-        ),
+        lambda x: make_snippet(x, input_tokens, mode=mode),
         axis=1,
     )
@@ -618,17 +600,16 @@ def show_excerpts():
             st.markdown(row["Speaker"])
 
         with col2:
-            snippet = (
-                row["Excerpt"]
-                .replace(":", "\:")
-                .replace("\n", "")
-                .replace("\r", "")
-            )
+            try:
+                # Escape colons so Streamlit does not render :emoji: codes.
+                snippet = row["Excerpt"].replace(":", "\\:")
+
+                st.markdown(
+                    f""" {snippet} """,
+                    unsafe_allow_html=True,
+                )
+            except AttributeError:
+                # No excerpt available for this row.
+                snippet = ''
 
-            st.markdown(
-                f""" {snippet} """,
-                unsafe_allow_html=True,
-            )
         with col3:
             full_text = st.button("Full text", key=row["speech_id"])
             if full_text:
@@ -649,6 +630,34 @@
             st.markdown(f"📝 [Read the protocol]({row['url']})")
 
 
+@st.cache_data(show_spinner=False)
+def search_with_question(question: str, search_kwargs=None):
+    """Vector search: retrieve the speeches most relevant to a question."""
+    search_kwargs = search_kwargs or {}
+    embeddings = HuggingFaceEmbeddings()
+    vectordb = Chroma(persist_directory='chroma_db', embedding_function=embeddings)
+    # MMR retrieval; see
+    # https://api.python.langchain.com/en/latest/vectorstores/langchain.vectorstores.chroma.Chroma.html?highlight=as_retriever#langchain.vectorstores.chroma.Chroma.as_retriever
+    # About MMR: https://medium.com/tech-that-works/maximal-marginal-relevance-to-rerank-results-in-unsupervised-keyphrase-extraction-22d95015c7c5
+    # Merge caller-supplied kwargs (e.g. a party filter) into the defaults.
+    retriever = vectordb.as_retriever(
+        search_type="mmr",
+        search_kwargs={'k': 20, 'fetch_k': 50, **search_kwargs},
+    )
+    docs = retriever.get_relevant_documents(question)
+    # docs = vectordb.similarity_search(question, search_kwargs=search_kwargs)
+
+    docs_ids = set([doc.metadata["_id"] for doc in docs])
+    return get_data(f'''FOR doc IN speeches FILTER doc._id IN {list(docs_ids)} RETURN DISTINCT {{{strings['arango']['select_fields']}}}''')
+
+    # Alternative considered: let an LLM expand the question into several queries.
+    # llm = ChatOpenAI(temperature=0, openai_api_key="sk-xxx", openai_api_base="http://localhost:8081/v1", max_tokens=100)
+    # retriever_from_llm = MultiQueryRetriever.from_llm(
+    #     retriever=vectordb.as_retriever(search_kwargs=search_kwargs), llm=llm
+    # )
+    # # search_kwargs={"filter": {"party": "ID"}}
+    # unique_docs = retriever_from_llm.get_relevant_documents(query=question)
+    # docs_ids = set([doc.metadata["_id"] for doc in unique_docs])
+    # return get_data(f'''for doc in speeches filter doc._id in {list(docs_ids)} return {{{strings['arango']['select_fields']}}}''')
 
 
 ###* PAGE LOADS *###
 
 # Title and explainer for streamlit
@@ -671,7 +680,11 @@
 partycodes = list(party_colors.keys())  # List of partycodes
 return_limit = 10000
 
 # Initialize LLM model.
-llama = LLM(temperature=0.5)
+# llama = LLM(temperature=0.5)
+
+# Get strings from yaml file.
+with open('strings.yml', 'r') as f:
+    strings = yaml.safe_load(f)
 
 # Ask for word to search for.
 user_input = st.text_input(
@@ -679,63 +692,67 @@
     value=params.q,
     placeholder="Search something",
     # label_visibility="hidden",
-    help='You can use asterix (*), minus (-), quotationmarks ("") and OR.',
+    help='You can use asterisk (*), minus (-), quotation marks ("") and OR. \nYou can also try asking a question, like *What is the Parliament doing for the climate?*',
 )
 
 if len(user_input) > 3:
     params.q = user_input
-    # print(user_input.upper())
-    # print(llama.generate(prompt=f'''A user wants to search in a database containing debates in the European Parliament and have made the input below. Take that input and write three questions would generate a good result if used for quering a vector database. Answer with a python style list containing the three questions.
-
-    # User input: {user_input}
-
-    # Questions: '''))
-
-    try:
+    try:  # To catch errors.
+
+        meps_df = get_speakers()
+        meps_ids = meps_df.mep_id.to_list()
+
+        user_input = user_input.replace("'", '"')
         input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
 
         # Put user input in session state (first run).
         if "user_input" not in st.session_state:
             st.session_state["user_input"] = user_input
-            user_input_to_db(user_input)
+            # Write user input to DB.
+            arango_db.collection("streamlit_user_input").insert(
+                {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
+            )
         else:
             if st.session_state["user_input"] != user_input:
-                # Write user input to DB.
                 st.session_state["user_input"] = user_input
-                user_input_to_db(user_input)
                 # Reset url parameters.
                 params.reset(q=user_input)
+                # Write user input to DB.
+                arango_db.collection("streamlit_user_input").insert(
+                    {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
+                )
                 params.update()
 
+        # A trailing question mark switches to AI question mode.
+        if user_input.strip().endswith('?'):
+            mode = 'question'
+            with st.spinner("Trying to find answers..."):
+                df, status = search_with_question(user_input)
+            input_tokens = None
+        else:
+            mode = "normal"
+            # Check if user has searched for a specific politician.
+            if len(user_input.split(" ")) in [2, 3, 4]:  # TODO Better way of telling if name?
+                list_persons = meps_df["name"].to_list()
+                if user_input.lower() in list_persons:
+                    sql = search_person(user_input, list_persons)
+                    mode = "speaker"
+
+            aql = create_aql_query(input_tokens)
+
+            ## Fetch data from DB.
+            df, status = get_data(aql)
 
-        meps_df = get_speakers()
-        meps_ids = meps_df.mep_id.to_list()
-
-        # Check if user has searched for a specific politician.
-        if len(user_input.split(" ")) in [
-            2,
-            3,
-            4,
-        ]:  # TODO Better way of telling if name?
-            list_persons = meps_df["name"].to_list()
-            if user_input.lower() in list_persons:
-                sql = search_person(user_input, list_persons)
-                search_terms = "speaker"
-
-        aql = create_aql_query(input_tokens)
-
-        ## Fetch data from DB.
-        df, status = get_data(aql, input_tokens)
-
-        if status == "no hits":  # If no hits.
-            st.write("No hits. Try again!")
-            st.stop()
-        elif status == "limit":
-            st.write(limit_warning)
-            st.stop()
+        # Both branches return (df, status), so check the status here.
+        if status == "no hits":  # If no hits.
+            st.write("No hits. Try again!")
+            st.stop()
+        elif status == "limit":
+            st.write(limit_warning)
+            st.stop()
 
         party_talks = pd.DataFrame(df["Party"].value_counts())
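Note: for orientation, the AQL assembled by create_aql_query() for a plain two-word search looks roughly like the sketch below. The view name and the exact SEARCH clause are assumptions, since that part of the function lies outside the hunks above; the RETURN object follows the select_fields construction:

    # Hypothetical output of create_aql_query('climate crisis') (shape only).
    example_aql = """
    FOR doc IN speeches_view
        SEARCH ANALYZER(doc.translation IN TOKENS('climate crisis', 'text_en'), 'text_en')
        SORT BM25(doc) DESC
        LIMIT 10000
        RETURN {
            'speech_id': doc._key, 'Text': doc.translation,
            'relevance_score': BM25(doc),
            'tokens': TOKENS(doc.translation, 'text_en'),
            'input_tokens': TOKENS(['climate', 'crisis'], 'text_en')
        }
    """
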
@@ -743,7 +760,7 @@
-        if type(party_labels) == "list":
+        if isinstance(party_labels, list):
             party_labels.sort()
 
-        if user_input != "speaker":
+        if mode != "speaker":
             # Let the user select parties to be included.
             container_parties = st.container()
             with container_parties:
@@ -757,27 +774,17 @@
                 default=party_labels,
             )
             if params.parties != []:
-                df = df.loc[df["Party"].isin(params.parties)]
+                if mode == "question":
+                    # Re-run the vector search with a Chroma metadata filter.
+                    df, status = search_with_question(
+                        user_input,
+                        search_kwargs={"filter": {"party": {"$in": params.parties}}},
+                    )
+                else:
+                    df = df.loc[df["Party"].isin(params.parties)]
                 if len(df) == 0:
                     st.stop()
 
-        # Let the user select type of debate.
-        container_debate = st.container()
-        with container_debate:
-            debates = df["debatetype"].unique().tolist()
-            debates.sort()
-
-            style = build_style_debate_types(debates)
-            st.markdown(style, unsafe_allow_html=True)
-            params.debates = st.multiselect(
-                label="Select type of debate",
-                options=debates,
-                default=debates,
-            )
-            if params.debates != []:
-                df = df.loc[df["debatetype"].isin(params.debates)]
-                if len(df) == 0:
-                    st.stop()
+        params.update()
 
         # Let the user select a range of years.
@@ -801,7 +808,7 @@
         params.update()
 
-        if user_input != "speaker":
+        if mode != "speaker":
             # Let the user select talkers.
             options = options_persons(df)
             style_mps = build_style_mps(options)  # Make the options the right colors.
@@ -859,7 +866,7 @@
             del st.session_state["disable_next_page_button"]
         st.session_state["hits"] = df.shape[0]
 
-        ##! Show snippets.
+        # Show snippets.
         expand = st.expander(
             "Show excerpts from the speeches.",
             expanded=False,
@@ -867,6 +874,7 @@
         with expand:
             if "excerpt_page" not in st.session_state:
                 st.session_state["excerpt_page"] = 0
+            excerpts_container = st.container()
             show_excerpts()
             st.session_state["excerpt_page"] += 1
@@ -883,138 +891,147 @@
                 st.session_state["excerpt_page"] += 1
                 show_excerpts()
 
-        # * Download all data in df.
-        st.download_button(
-            "Download the data as CSV",
-            data=df.to_csv(
-                index=False,
-                sep=";",
-                columns=[
-                    "speech_id",
-                    "Text",
-                    "Party",
-                    "Speaker",
-                    "Date",
-                    "url",
-                ],
-            ).encode("utf-8"),
-            file_name=f"{user_input}.csv",
-            mime="text/csv",
-        )
-        # Remove talks from same party within the same session to make the
-        # statistics more representative.
-
-        df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
-
-        if user_input != "speaker":
-            ## Make pie chart.
-
-            party_talks = pd.DataFrame(df_["Party"].value_counts())
-            party_labels = party_talks.index.to_list()
-            fig, ax1 = plt.subplots()
-            total = party_talks["count"].sum()
-            mentions = party_talks["count"]
-            ax1.pie(
-                mentions,
-                labels=party_labels,
-                autopct=lambda p: "{:.0f}".format(p * total / 100),
-                colors=[party_colors[key] for key in party_labels],
-                startangle=90,
-            )
-
-        # Make bars per year.
-        years = set(df["Year"].tolist())
-
-        df_years = pd.DataFrame(columns=["Party", "Year"])
-
-        for year, df_i in df.groupby("Year"):
-            dff = pd.DataFrame(data=df_i["Party"].value_counts())
-            dff["Year"] = str(year)
-            df_years = pd.concat([df_years, dff])
-        df_years["party_code"] = df_years.index
-        # df_years = df_years.groupby(["party_code", "Year"]).size()
-        df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
-        # df_years.drop(["Party"], axis=1, inplace=True)
-
-        df_years.rename(
-            columns={"Party": "Mentions", "party_code": "Party"}, inplace=True
-        )
-        chart = (
-            alt.Chart(df_years)
-            .mark_bar()
-            .encode(
-                x="Year",
-                y="count",
-                color=alt.Color("color", scale=None),
-                tooltip=["Party", "Mentions"],
-            )
-        )
-
-        if user_input == "speaker":
-            st.altair_chart(chart, use_container_width=True)
-
-        else:
-            # Put the charts in a table.
-            fig1, fig2 = st.columns(2)
-            with fig1:
-                st.pyplot(fig)
-            with fig2:
-                st.altair_chart(chart, use_container_width=True)
-
-        # Create an empty container
-        container = st.empty()
+        if mode != 'question':
+            if mode != "speaker":
+                # Remove talks from same party within the same session to make
+                # the statistics more representative.
+                df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
+
+                ## Make pie chart.
+                party_talks = pd.DataFrame(df_["Party"].value_counts())
+                party_labels = party_talks.index.to_list()
+                fig, ax1 = plt.subplots()
+                total = party_talks["count"].sum()
+                mentions = party_talks["count"]
+                ax1.pie(
+                    mentions,
+                    labels=party_labels,
+                    autopct=lambda p: "{:.0f}".format(p * total / 100),
+                    colors=[party_colors[key] for key in party_labels],
+                    startangle=90,
+                )
+
+            # Make bars per year.
+            years = set(df["Year"].tolist())
+
+            # Start empty; columns come from the yearly value_counts frames.
+            df_years = pd.DataFrame()
+
+            for year, df_i in df.groupby("Year"):
+                dff = pd.DataFrame(data=df_i["Party"].value_counts())
+                dff["Year"] = str(year)
+                df_years = pd.concat([df_years, dff])
+            df_years["party_code"] = df_years.index
+            df_years["color"] = df_years["party_code"].apply(lambda x: party_colors[x])
+
+            # pandas >= 2.0 names the value_counts() column "count", so that
+            # is the column to rename before it is charted below.
+            df_years.rename(
+                columns={"count": "Mentions", "party_code": "Party"}, inplace=True
+            )
+            chart = (
+                alt.Chart(df_years)
+                .mark_bar()
+                .encode(
+                    x="Year",
+                    y="Mentions",
+                    color=alt.Color("color", scale=None),
+                    tooltip=["Party", "Mentions"],
+                )
+            )
+
+            if mode == "speaker":
+                st.altair_chart(chart, use_container_width=True)
+            else:
+                # Put the charts side by side in two columns.
+                fig1, fig2 = st.columns(2)
+                with fig1:
+                    st.pyplot(fig)
+                with fig2:
+                    st.altair_chart(chart, use_container_width=True)
+
+            # Create an empty container
+            container = st.empty()
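Note: the rename above assumes pandas 2.0 or later, where Series.value_counts() names its result column "count"; on 1.x the column instead keeps the original name ("Party" here), which is what the pre-patch rename targeted. A quick way to check which convention is in play:

    import pandas as pd

    counts = pd.DataFrame({"Party": ["EPP", "EPP", "S&D"]})["Party"].value_counts()
    # pandas >= 2.0 prints Index(['count'], ...); 1.x prints Index(['Party'], ...)
    print(pd.DataFrame(counts).columns)
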
-        # * Make a summary of the standpoints of the parties.
-
-        # A general note about the summaries. #TODO Make better.
-        st.markdown(summary_note)
-        # Count the number of speeches by party.
-        party_counts = df["Party"].value_counts()
-        # Sort the dataframe by party.
-        df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
-            by=["Party"]
-        )
-        # Create an empty dictionary to store the summaries.
-        summaries = {}
-        # Loop through each party and their respective speeches.
-        for party, df_party in df_sorted.groupby("Party"):
-            # Create an empty list to store the texts of each speech.
-
-            with st.spinner(f"Summarizing speeches by **{party}** parliamentarians..."):
-                summary = summarize(df_party, user_input)
-                st.markdown(
-                    f"##### {parties[party]} ({party})",
-                    unsafe_allow_html=True,
-                )
-                st.write(summary)
-
-        # A general note about the summaries. #TODO Make better.
-
-        # * Get feedback.
-        st.empty()
-        feedback_container = st.empty()
+            # * Make a summary of the standpoints of the parties.
+            summarize_please = st.button("Summarize please!")
+            if summarize_please:
+                # A general note about the summaries. #TODO Make better.
+                st.markdown(summary_note)
+
+                # Count the number of speeches by party.
+                party_counts = df["Party"].value_counts()
+                # Sort the dataframe by party.
+                df_sorted = df.loc[df["Party"].isin(party_counts.index)].sort_values(
+                    by=["Party"]
+                )
+                # Loop through each party and their respective speeches.
+                for party, df_party in df_sorted.groupby("Party"):
+                    with st.spinner(
+                        f"Summarizing speeches by **{party}** parliamentarians..."
+                    ):
+                        summary = summarize(df_party, user_input)
+                    st.markdown(
+                        f"##### {parties[party]} ({party})",
+                        unsafe_allow_html=True,
+                    )
+                    st.write(summary)
+
+        feedback_column, download_column = st.columns([3, 1])
+
+        with feedback_column:
+            # * Get feedback.
+            feedback_container = st.empty()
 
-        with feedback_container.container():
-            feedback = st.text_area(
-                "*Feel free to write suggestions for functions and improvements here!*"
-            )
-            send = st.button("Send")
-            if len(feedback) > 2 and send:
-                doc = {
-                    "feedback": feedback,
-                    "params": st.experimental_get_query_params(),
-                    "where": ("streamlit", title),
-                    "timestamp": str(datetime.now()),
-                    "date": str(datetime.date(datetime.now())),
-                }
-                arango_db.collection("feedback").insert(doc)
-
-                feedback_container.write("*Thanks*")
-                params.update()
-
-        # st.markdown("##")
+            with feedback_container.container():
+                feedback = st.text_area(
+                    "Feedback",
+                    placeholder="Feel free to write suggestions for functions and improvements here!",
+                    label_visibility="collapsed",
+                )
+                send = st.button("Send")
+                if len(feedback) > 2 and send:
+                    doc = {
+                        "feedback": feedback,
+                        "params": st.experimental_get_query_params(),
+                        "where": ("streamlit", title),
+                        "timestamp": str(datetime.now()),
+                        "date": str(datetime.date(datetime.now())),
+                    }
+                    arango_db.collection("feedback").insert(doc)
+
+                    feedback_container.write("*Thanks*")
+                    params.update()
+
+        with download_column:
+            # * Download all data in df.
+            st.download_button(
+                "Download data as a CSV file",
+                data=df.to_csv(
+                    index=False,
+                    sep=";",
+                    columns=[
+                        "speech_id",
+                        "Text",
+                        "Party",
+                        "Speaker",
+                        "Date",
+                        "url",
+                    ],
+                ).encode("utf-8"),
+                file_name=f"{user_input}.csv",
+                mime="text/csv",
+            )
 
 except Exception as e:
     if (
diff --git a/streamlit_info.py b/streamlit_info.py
index bc1a8ca..6efa9a9 100644
--- a/streamlit_info.py
+++ b/streamlit_info.py
@@ -69,7 +69,7 @@ It's a summary of the ten most relevant speeches from each party based on the se
 Please make sure to check the original text before you use the summary in any way.
 """
-explainer = """This is a database of what members of the European Parliamen have said in various debates in the parliament since 2019.
-The data comes partly from the EU.
+explainer = """This is a database of what members of the European Parliament have said in various debates in the parliament since 2019.
+The data comes from the EU and has been translated when not in English.
-- Start by typing one or more keywords below. You can use asterix (*), minus(-), quotation marks (""), OR and year\:yyyy-yyyy.
+- Start by typing one or more keywords below. You can use asterisk (*), minus (-), quotation marks (""), OR and year\:yyyy-yyyy.
   The search `energy crisis* basic power OR nuclear power "fossil-free energy sources" -wind power year:2019-2022` is looking for quotes like\:
   - mentions "energy crisis" (incl. e.g. "energy crisis*")
   - mentions the *exact phrase* "fossil-free energy sources"
   - *does* not mention "wind power"
   - found during the years 2019-2022
-- When you have received your results, you can then click away matches or change which years and debate types you are interested in.
-- Under "Longer excerpt" you can choose to see the entire speech in text, and under the text there are links to the Riksdag's Web TV and downloadable audio (in the cases
-where the debate has been broadcast).
+- You can also ask a specific question, like `What have parliamentarians said about the energy crisis?` Remember to put a question mark at the end of the question.
+- When you have received your results, you can filter on which years you are interested in.
+- Under "Excerpt" you can choose to see the entire speech in text, and under the text there are links to the official protocol.
 Please tell us how you would like to use the data and about things that don't work. [Email me](mailto:lasse@edfast.se) or [write to me on Twitter](https://twitter.com/lasseedfast).
My name is [Lasse Edfast and I'm a journalist](https://lasseedfast.se) based in Sweden.
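Note on the retrieval path: search_with_question() expects a Chroma index already persisted under chroma_db, with each speech's ArangoDB _id (and, for the party filter, its party) stored as metadata. The patch never shows that indexing step; a minimal offline sketch consistent with the lookups above, where collection, field and metadata names are assumptions:

    # One-off indexing sketch for the persisted Chroma store (hypothetical names).
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.vectorstores import Chroma

    speeches = [  # toy stand-in for documents fetched from ArangoDB
        {"_id": "speeches/1", "translation": "Example speech text.", "party": "EPP"},
    ]

    vectordb = Chroma.from_texts(
        texts=[s["translation"] for s in speeches],
        embedding=HuggingFaceEmbeddings(),
        metadatas=[{"_id": s["_id"], "party": s["party"]} for s in speeches],
        persist_directory="chroma_db",
    )
    vectordb.persist()  # write the index to disk so the app can reopen it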