From 275ae9ee133a4eed49248227ca13e86ea80a83bc Mon Sep 17 00:00:00 2001
From: lasseedfast <>
Date: Sun, 29 Oct 2023 18:49:54 +0100
Subject: [PATCH] Made it work with llama.cpp

---
 streamlit_app_talking_ep.py | 80 ++++++++++++++++++++++++++++---------
 1 file changed, 61 insertions(+), 19 deletions(-)

diff --git a/streamlit_app_talking_ep.py b/streamlit_app_talking_ep.py
index 19e234c..1afb116 100644
--- a/streamlit_app_talking_ep.py
+++ b/streamlit_app_talking_ep.py
@@ -8,9 +8,7 @@ import matplotlib.pyplot as plt
 import pandas as pd
 import streamlit as st
 
-from arango_things import arango_db
-from things import normalize_party_names
 
 from streamlit_info import (
     party_colors,
     explainer,
@@ -23,7 +21,7 @@ from streamlit_info import (
 )
 
 from langchain.schema import Document
-
+from llama_server import LLM
 
 class Params:
     """Class containing parameters for URL.
@@ -171,7 +169,8 @@ def summarize(df_party, user_input):
 
     {texts}
 
-    Please make a very short summary of the standpoints of the {party} party in the EU parliament, in relation to the user search. Make note of the speaker and the dates for the speeches.
+    You are a journalist and are going to write a short summary of the standpoints of the {party} party in the EU parliament, in relation to the user search.
+    Make note of the speaker and the dates for the speeches.
 
     Example 1:
     In 2020, Parliamentarian NN, ({party}) wanted to decrease the budget for the EU parliament. Later, 2022, she wanted to incrase the budget.
@@ -185,11 +184,48 @@ def summarize(df_party, user_input):
     # Generate a summary of the party's standpoints.
     # Return the summaries dictionary.
 
-    system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
+
+    # TODO: Have to understand system_prompt in llama_server
+    # system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
 
-    return ollama(
-        prompt=prompt, system_prompt=system_prompt, temperature=0.5, print_tokens=False
-    )
+    return llama.generate(prompt=prompt)
+
+
+def normalize_party_names(name):
+    """
+    Normalizes party names to the format used in the database.
+
+    Parameters:
+    name (str): The party name to be normalized.
+
+    Returns:
+    str: The normalized party name.
+    """
+
+    parties = {
+        "EPP": "EPP",
+        "PPE": "EPP",
+        "RE": "Renew",
+        "S-D": "S&D",
+        "S&D": "S&D",
+        "ID": "ID",
+        "ECR": "ECR",
+        "GUE/NGL": "GUE/NGL",
+        "The Left": "GUE/NGL",
+        "Greens/EFA": "Greens/EFA",
+        "G/EFA": "Greens/EFA",
+        "Verts/ALE": "Greens/EFA",
+        "NA": "NA",
+        "NULL": "NA",
+        None: "NA",
+        "-": "NA",
+        "Vacant": "NA",
+        "NI": "NA",
+        "Renew": "Renew"
+
+    }
+
+    return parties[name]
 
 
 def make_snippet(text, input_tokens, token_text_list, token_input_list):
@@ -214,7 +250,7 @@ def make_snippet(text, input_tokens, token_text_list, token_input_list):
     snippet = []
     text_lower = text.lower()
     # Calculate snippet length in words.
-    snippet_length = 40 * int(8 / len(input_tokens))  # * Change to another value?
+    snippet_length = 40 * int(10 / len(input_tokens) + 1)  # * Change to another value?
 
     # Loop through each input token.
     for token in input_tokens:
@@ -635,6 +671,9 @@ partycodes = list(party_colors.keys())  # List of partycodes
 # Max hits returned by db.
 return_limit = 10000
 
+# Initialize the LLM.
+llama = LLM(temperature=0.5)
+
 # Ask for word to search for.
 user_input = st.text_input(
     " ",
@@ -644,18 +683,17 @@ user_input = st.text_input(
     help='You can use asterix (*), minus (-), quotationmarks ("") and OR.',
 )
 
-
 if len(user_input) > 3:
     params.q = user_input
     # print(user_input.upper())
 
-    # print(ollama(prompt=f'''A user wants to search in a database containing debates in the European Parliament and have made the input below. Take that input and write three questions would generate a good result if used for quering a vector database. Answer with a python style list containing the three questions.
+    # print(llama.generate(prompt=f'''A user wants to search in a database containing debates in the European Parliament and have made the input below. Take that input and write three questions would generate a good result if used for quering a vector database. Answer with a python style list containing the three questions.
     # User input: {user_input}
 
     # Questions: '''))
 
 
-    try:  #! When in procution, uncomment this.
+    try:  #! When in production, uncomment this.
         user_input = user_input.replace("'", '"')
 
         input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
@@ -811,10 +849,15 @@ if len(user_input) > 3:
             st.session_state["hits"] = df.shape[0]
         else:
             if st.session_state["hits"] != df.shape[0]:
-                del st.session_state["df_excerpts"]
-                del st.session_state["excerpt_page"]
-                del st.session_state["text_next_page_button"]
+                if "df_excerpts" in st.session_state:
+                    del st.session_state["df_excerpts"]
+                if "excerpt_page" in st.session_state:
+                    del st.session_state["excerpt_page"]
+                if 'text_next_page_button' in st.session_state:
+                    del st.session_state["text_next_page_button"]
                 del st.session_state["disable_next_page_button"]
+                if 'disable_next_page_button' in st.session_state:
+                    del st.session_state["disable_next_page_button"]
                 st.session_state["hits"] = df.shape[0]
 
     ##! Show snippets.
@@ -871,9 +914,8 @@ if len(user_input) > 3:
            party_talks = pd.DataFrame(df_["Party"].value_counts())
            party_labels = party_talks.index.to_list()
            fig, ax1 = plt.subplots()
-
-           total = party_talks["Party"].sum()
-           mentions = party_talks["Party"]  #!
+           total = party_talks["count"].sum()
+           mentions = party_talks["count"]  #!
            ax1.pie(
                mentions,
                labels=party_labels,
@@ -904,7 +946,7 @@ if len(user_input) > 3:
            .mark_bar()
            .encode(
                x="Year",
-               y="Mentions",
+               y="count",
                color=alt.Color("color", scale=None),
                tooltip=["Party", "Mentions"],
            )
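
Note: the patch imports LLM from a local llama_server module and drives it through LLM(temperature=0.5) and llama.generate(prompt=...), but llama_server itself is not included in this diff. Below is a minimal sketch of what such a wrapper could look like, assuming llama.cpp's built-in HTTP server is running locally; the class layout, default URL and port, n_predict value, and the system-prompt handling are illustrative assumptions, not code from the commit.

# Hypothetical llama_server.py sketch (not part of the commit).
import requests


class LLM:
    """Minimal client for a locally running llama.cpp server."""

    def __init__(self, temperature=0.5, url="http://localhost:8080"):
        self.temperature = temperature
        self.url = url

    def generate(self, prompt, system_prompt=None, n_predict=512):
        # llama.cpp's server exposes POST /completion; if the model has no
        # chat template, a system prompt can simply be prepended to the prompt.
        if system_prompt:
            prompt = f"{system_prompt}\n\n{prompt}"
        response = requests.post(
            f"{self.url}/completion",
            json={
                "prompt": prompt,
                "temperature": self.temperature,
                "n_predict": n_predict,
            },
            timeout=600,
        )
        response.raise_for_status()
        return response.json()["content"]


# Usage mirroring the calls introduced in the patch:
#   llama = LLM(temperature=0.5)
#   summary = llama.generate(prompt=prompt)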