@@ -8,7 +8,6 @@ import matplotlib.pyplot as plt
 import pandas as pd
 import streamlit as st
-import yaml
 from arango_things import arango_db
 from streamlit_info import (
     party_colors,
@@ -23,20 +22,6 @@ from streamlit_info import (
 from langchain.schema import Document
 from llama_server import LLM
-from langchain.vectorstores import Chroma
-from langchain.retrievers.multi_query import MultiQueryRetriever
-from langchain.embeddings import HuggingFaceEmbeddings
-from langchain.chat_models import ChatOpenAI
-class Session():
-    def __init__(self) -> None:
-        self.mode = ''
-        self.user_input = ''
-        self.query = ''
-        self.df = None
 class Params:
     """Class containing parameters for URL.
@@ -65,6 +50,7 @@ class Params:
         self.persons = self.set_param("persons")
         self.from_year = self.set_param("from_year")
         self.to_year = self.set_param("to_year")
+        self.debates = self.set_param("debates")
         self.excerpt_page = self.set_param("excerpt_page")

     def set_param(self, key):
@@ -101,6 +87,7 @@ class Params:
             from_year=self.from_year,
             to_year=self.to_year,
             parties=",".join(self.parties),
+            debates=",".join(self.debates),
             persons=",".join(self.persons),
             excerpt_page=self.excerpt_page,
         )
@@ -189,6 +176,7 @@ def summarize(df_party, user_input):
         """,
     )
+    # TODO: Include examples in the prompt?
     # Example 1:
    # Short summary: In 2020, Parliamentarian NN ({party}) wanted to decrease the budget for the EU parliament. Later, in 2022, she wanted to increase the budget.
@@ -200,21 +188,7 @@
     # TODO: Have to understand system_prompt in llama_server.
     # system_prompt = "You are a journalist and have been asked to write a short summary of the standpoints of the party in the EU parliament."
-    from langchain.schema import (
-        AIMessage,
-        HumanMessage,
-        SystemMessage
-    )
-    messages = [
-        SystemMessage(content="You are a journalist reporting from the EU."),
-        HumanMessage(content=prompt)
-    ]
-    llm = ChatOpenAI(temperature=0, openai_api_key="sk-xxx", openai_api_base="http://localhost:8081/v1", max_tokens=100)
-    response = llm(messages)
-    return response.content
-    # return llama.generate(prompt=prompt)
+    return llama.generate(prompt=prompt)
 def normalize_party_names(name):
@@ -248,12 +222,13 @@ def normalize_party_names(name):
         "Vacant": "NA",
         "NI": "NA",
         "Renew": "Renew"
     }
     return parties[name]
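     # NOTE: parties[name] raises KeyError for labels missing from the mapping;
     # parties.get(name, name) would fall back to the original label instead.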
-def make_snippet(row, input_tokens, mode):
+def make_snippet(text, input_tokens, token_text_list, token_input_list):
     """
     Find the word searched for and give it some context.
@@ -267,17 +242,11 @@
         str: A snippet of text containing the input tokens and some context.
     """
     # If the input token is "speaker", return the first 300 characters of the text.
-    if mode == "question":
-        snippet = row['summary_short']
-    elif input_tokens == "speaker":
+    if input_tokens == "speaker":
         snippet = str(text[:300])
         if len(text) > 300:
             snippet += "..."
     else:
-        text = row["Text"]
-        token_text_list = row["tokens"]
-        token_input_list = row["input_tokens"]
         snippet = []
         text_lower = text.lower()
         # Calculate snippet length in words.
@@ -381,6 +350,28 @@ def build_style_mps(mps):
     return style
+def fix_party(party):
+    """Replace old party codes with new ones."""
+    party = party.upper().replace("KDS", "KD").replace("FP", "L")
+    return party
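+# e.g. fix_party("kds") -> "KD", fix_party("fp") -> "L"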
+def build_style_debate_types(debates):
+    """Build a CSS style for debate type buttons."""
+    style = "<style>"
+    for debate in debates:
+        style += f'span[data-baseweb="tag"][aria-label="{debate}, close by backspace"] {{ background-color: #767676; }} .st-eg {{ min-width: 14px; }}'  # max-width: 328px;
+    style += "</style>"
+    return style
+def highlight_cells(party):
+    if party in party_colors:
+        color = party_colors[party]
+        return f"background-color: {color}; font-weight: bold"
+    return ""  # No styling for unknown parties (Styler expects a CSS string).
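+# Typical use (not shown in this diff) would be something like
+# df.style.applymap(highlight_cells, subset=["Party"]).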
 @st.cache_data
 def options_persons(df):
     d = {}
@@ -390,7 +381,7 @@ def options_persons(df):
 @st.cache_data
-def get_data(aql):
+def get_data(aql, input_tokens):
     """Get data from the ArangoDB database.

     Args:
@@ -411,19 +402,27 @@ def get_data(aql):
     else:
         status = "ok"
     df = pd.DataFrame([doc for doc in cursor])
     df.mep_id = df.mep_id.astype(int)
     df["Year"] = df["Date"].apply(lambda x: int(x[:4]))
     # df.drop_duplicates(ignore_index=True, inplace=True)
     df["Party"] = df["Party"].apply(lambda x: normalize_party_names(x))
     df.sort_values(["Date", "number"], axis=0, ascending=True, inplace=True)
+    df["len_translation"] = df["Text"].apply(lambda x: len(x.split(" ")))
+    df["len_tokens"] = df["tokens"].apply(lambda x: len(x))
     df.sort_values(["Date", "speech_id", "number"], axis=0, inplace=True)
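     # NOTE: this re-sort by ["Date", "speech_id", "number"] supersedes the
     # ["Date", "number"] ordering applied a few lines above.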
+    df.drop_duplicates(ignore_index=True, inplace=True)
     return df, status
+def user_input_to_db(user_input):
+    """Writes user input to db for debugging."""
+    arango_db.collection("streamlit_user_input").insert(
+        {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
+    )
 def create_aql_query(input_tokens):
     """Returns a valid AQL query."""
     # Split input into tokens
@@ -455,10 +454,21 @@
     str_tokens = str([re.sub(r"[^a-zA-Z0-9\s]", "", token) for token in input_tokens])
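     # NOTE: the sanitizer keeps only ASCII alphanumerics and whitespace before
     # the tokens are interpolated into the AQL string below; non-ASCII letters
     # (e.g. å, ä, ö in Swedish names) are stripped as well, which may be
     # unintended.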
     select_fields = f"""
-        {strings['arango']['select_fields']}
+        {{
+        "speech_id": doc._key,
+        "Text": doc.translation,
+        "number": doc.speech_number,
+        "debatetype": doc.debate_type,
+        "Speaker": doc.name,
+        "Date": doc.date,
+        "url": doc.url,
+        "Party": doc.party,
+        "mep_id": doc.mep_id,
+        "Summary": doc.summary,
+        "debate_id": doc.debate_id,
         'relevance_score': BM25(doc),
         'tokens': TOKENS(doc.translation, 'text_en'),
-        'input_tokens': TOKENS({str_tokens}, 'text_en')
+        'input_tokens': TOKENS({str_tokens}, 'text_en')}}
"""
"""
     # Add LIMIT and RETURN clause
@@ -471,10 +481,17 @@
     return aql_query.replace("SEARCH AND", "SEARCH")
-def fix_party(party):
-    """Replace old party codes with new ones."""
-    party = party.upper().replace("KDS", "KD").replace("FP", "L")
-    return party
+def error2db(error, user_input, engine):
+    """Write error to DB for debugging."""
+    doc = {
+        "error": error,
+        # str(): raw date objects are not JSON-serializable for the insert.
+        "time": str(datetime.date(datetime.now())),
+        "user_input": str(user_input),
+    }
+    arango_db.collection("errors").insert(doc)
 @st.cache_data
@@ -554,7 +571,9 @@ def show_excerpts():
     ].copy()
     df_excepts["Excerpt"] = df_excepts.apply(
-        lambda x: make_snippet(x, input_tokens, mode=mode),
+        lambda x: make_snippet(
+            x["Text"], input_tokens, x["tokens"], x["input_tokens"]
+        ),
         axis=1,
     )
@@ -600,16 +619,17 @@ def show_excerpts():
                 st.markdown(row["Speaker"])
             with col2:
-                try:
-                    snippet = (row["Excerpt"].replace(":", "\:"))
+                snippet = (
+                    row["Excerpt"]
+                    .replace(":", "\:")
+                    .replace("<p>", "")
+                    .replace("</p>", "")
+                )
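+                # Colons are escaped so Streamlit's markdown renderer does not
+                # treat :word: sequences as emoji shortcodes; leftover <p> tags
+                # from the source HTML are dropped.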
                 st.markdown(
                     f"""<span style="background-color: {party_colors_lighten[row['Party']]}; color:black;">{snippet}</span>""",
                     unsafe_allow_html=True,
                 )
-                except AttributeError:
-                    snippet = ''
             with col3:
                 full_text = st.button("Full text", key=row["speech_id"])
                 if full_text:
@@ -630,34 +650,6 @@
     st.markdown(f"📝 [Read the protocol]({row['url']})")
-@st.cache_data(show_spinner=False)
-def search_with_question(question: str, search_kwargs: dict = {}):
-    embeddings = HuggingFaceEmbeddings()
-    vectordb = Chroma(persist_directory='chroma_db', embedding_function=embeddings)
-    retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={'k': 20, 'fetch_k': 50})  # https://api.python.langchain.com/en/latest/vectorstores/langchain.vectorstores.chroma.Chroma.html?highlight=as_retriever#langchain.vectorstores.chroma.Chroma.as_retriever
-    # About MMR: https://medium.com/tech-that-works/maximal-marginal-relevance-to-rerank-results-in-unsupervised-keyphrase-extraction-22d95015c7c5
-    docs = retriever.get_relevant_documents(question)
-    # docs = vectordb.similarity_search(question, search_kwargs=search_kwargs)
-    docs_ids = set([doc.metadata["_id"] for doc in docs])
-    return get_data(f'''FOR doc IN speeches FILTER doc._id in {list(docs_ids)} RETURN DISTINCT {{ {strings['arango']['select_fields']} }}''')
-    print(r)
-    exit()
-    # llm = ChatOpenAI(temperature=0, openai_api_key="sk-xxx", openai_api_base="http://localhost:8081/v1", max_tokens=100)
-    # retriever_from_llm = MultiQueryRetriever.from_llm(
-    #     retriever=vectordb.as_retriever(search_kwargs=search_kwargs), llm=llm
-    # )
-    # # search_kwargs={"filter":{"party":"ID"}}
-    # unique_docs = retriever_from_llm.get_relevant_documents(query=question)
-    docs_ids = set([doc.metadata["_id"] for doc in unique_docs])
-    return get_data(f'''for doc in speeches filter doc._id in {list(docs_ids)} return {{ {strings['arango']['select_fields']} }}''')
 ###* PAGE LOADS *###
 # Title and explainer for streamlit
@@ -680,11 +672,7 @@ partycodes = list(party_colors.keys())  # List of partycodes
 return_limit = 10000
 # Initialize LLM model.
-#llama = LLM(temperature=0.5)
-# Get strings from yaml file.
-with open('strings.yml', 'r') as f:
-    strings = yaml.safe_load(f)
+llama = LLM(temperature=0.5)
 # Ask for word to search for.
 user_input = st.text_input(
@@ -692,60 +680,56 @@ user_input = st.text_input(
     value=params.q,
     placeholder="Search something",
     # label_visibility="hidden",
-    help='''You can use asterisk (*), minus (-), quotation marks ("") and OR.
-    You can also ask a question but make sure to use a question mark!''')
+    help='You can use asterisk (*), minus (-), quotation marks ("") and OR.\nYou can also try asking a question, like *What is the Parliament doing for the climate?*',
+)
 if len(user_input) > 3:
     params.q = user_input
-    try:  # To catch errors.
-        # print(user_input.upper())
-        # print(llama.generate(prompt=f'''A user wants to search in a database containing debates in the European Parliament and has made the input below. Take that input and write three questions that would generate a good result if used for querying a vector database. Answer with a python-style list containing the three questions.
-        # User input: {user_input}
-        # Questions: '''))
+    meps_df = get_speakers()
+    meps_ids = meps_df.mep_id.to_list()
+    try:
         user_input = user_input.replace("'", '"')
         input_tokens = re.findall(r'(?:"[^"]*"|\S)+', user_input)
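         # The regex treats a double-quoted phrase as a single token, e.g.
         # 'budget "green deal"' -> ['budget', '"green deal"'].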
         # Put user input in session state (first run).
         if "user_input" not in st.session_state:
             st.session_state["user_input"] = user_input
-            arango_db.collection("streamlit_user_input").insert(
-                {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
-            )
+            # Write user input to DB.
+            user_input_to_db(user_input)
         else:
             if st.session_state["user_input"] != user_input:
+                # Write user input to DB.
                 st.session_state["user_input"] = user_input
+                user_input_to_db(user_input)
                 # Reset url parameters.
                 params.reset(q=user_input)
-                # Write user input to DB.
-                arango_db.collection("streamlit_user_input").insert(
-                    {"timestamp": datetime.timestamp(datetime.now()), "input": user_input}
-                )
         params.update()
-        if user_input.strip()[-1] == '?':  # Search documents with AI.
-            mode = 'question'
-            meps_df = get_speakers()
-            with st.spinner(f"Trying to find answers..."):
-                meps_ids = meps_df.mep_id.to_list()
-                df, status = search_with_question(user_input)
-            input_tokens = None
-        else:
-            mode = "normal"
         # Check if user has searched for a specific politician.
-        if len(user_input.split(" ")) in [2, 3, 4,]:  # TODO Better way of telling if name?
+        if len(user_input.split(" ")) in [
+            2,
+            3,
+            4,
+        ]:  # TODO Better way of telling if name?
             list_persons = meps_df["name"].to_list()
             if user_input.lower() in list_persons:
                 sql = search_person(user_input, list_persons)
-                mode = "speaker"
+                search_terms = "speaker"
         aql = create_aql_query(input_tokens)
         ## Fetch data from DB.
-        df, status = get_data(aql)
+        df, status = get_data(aql, input_tokens)
         if status == "no hits":  # If no hits.
             st.write("No hits. Try again!")
@@ -760,7 +744,7 @@ if len(user_input) > 3:
         if type(party_labels) == "list":
             party_labels.sort()
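         # NOTE: type(party_labels) == "list" compares a type object to a
         # string and is always False; isinstance(party_labels, list) is
         # probably intended.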
-        if mode != "speaker":
+        if user_input != "speaker":
             # Let the user select parties to be included.
             container_parties = st.container()
             with container_parties:
@@ -774,17 +758,27 @@ if len(user_input) > 3:
                     default=party_labels,
                 )
             if params.parties != []:
-                if mode == "question":
-                    df, status = search_with_question(user_input, search_kwargs={"filter": {"party": {"$in": params.parties}}})
-                else:
-                    df = df.loc[df["Party"].isin(params.parties)]
+                df = df.loc[df["Party"].isin(params.parties)]
                 if len(df) == 0:
                     st.stop()
             # Let the user select type of debate.
             container_debate = st.container()
+            with container_debate:
+                debates = df["debatetype"].unique().tolist()
+                debates.sort()
+                style = build_style_debate_types(debates)
+                st.markdown(style, unsafe_allow_html=True)
+                params.debates = st.multiselect(
+                    label="Select type of debate",
+                    options=debates,
+                    default=debates,
+                )
+            if params.debates != []:
+                df = df.loc[df["debatetype"].isin(params.debates)]
+                if len(df) == 0:
+                    st.stop()
             params.update()
         # Let the user select a range of years.
@@ -808,7 +802,7 @@ if len(user_input) > 3:
         params.update()
-        if mode != "speaker":
+        if user_input != "speaker":
            # Let the user select talkers.
            options = options_persons(df)
            style_mps = build_style_mps(options)  # Make the options the right colors.
@@ -866,7 +860,7 @@ if len(user_input) > 3:
             del st.session_state["disable_next_page_button"]
         st.session_state["hits"] = df.shape[0]
-        # Show snippets.
+        ##! Show snippets.
         expand = st.expander(
             "Show excerpts from the speeches.",
             expanded=False,
@@ -874,7 +868,6 @@
         with expand:
             if "excerpt_page" not in st.session_state:
                 st.session_state["excerpt_page"] = 0
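             # excerpt_page is kept in session_state so the paging controls
             # below can advance it across Streamlit reruns.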
             excerpts_container = st.container()
             show_excerpts()
             st.session_state["excerpt_page"] += 1
@@ -891,14 +884,31 @@ if len(user_input) > 3:
                     st.session_state["excerpt_page"] += 1
                     show_excerpts()
-        if mode != 'question':
-            if mode != "speaker":
+        # * Download all data in df.
+        st.download_button(
+            "Download the data as CSV",
+            data=df.to_csv(
+                index=False,
+                sep=";",
+                columns=[
+                    "speech_id",
+                    "Text",
+                    "Party",
+                    "Speaker",
+                    "Date",
+                    "url",
+                ],
+            ).encode("utf-8"),
+            file_name=f"{user_input}.csv",
+            mime="text/csv",
+        )
         # Remove talks from same party within the same session to make the
         # statistics more representative.
         df_ = df[["speech_id", "Party", "Year"]].drop_duplicates()
+        if user_input != "speaker":
             ## Make pie chart.
             party_talks = pd.DataFrame(df_["Party"].value_counts())
@@ -956,10 +966,8 @@
         # Create an empty container
         container = st.empty()
         # * Make a summary of the standpoints of the parties.
-        summarize_please = st.button("Summarize please!",)
-        if summarize_please:
         # A general note about the summaries. #TODO Make better.
         st.markdown(summary_note)
@@ -985,16 +993,13 @@
         # A general note about the summaries. #TODO Make better.
-        feedback_column, download_column = st.columns([3, 1])
-        with feedback_column:
         # * Get feedback.
         st.empty()
         feedback_container = st.empty()
         with feedback_container.container():
             feedback = st.text_area(
-                'Feedback', placeholder='Feel free to write suggestions for functions and improvements here!', label_visibility='collapsed'
+                "*Feel free to write suggestions for functions and improvements here!*"
             )
             send = st.button("Send")
             if len(feedback) > 2 and send:
@@ -1012,27 +1017,6 @@
         # st.markdown("##")
-        with download_column:
-            # * Download all data in df.
-            st.download_button(
-                "Download data as CSV-file",
-                data=df.to_csv(
-                    index=False,
-                    sep=";",
-                    columns=[
-                        "speech_id",
-                        "Text",
-                        "Party",
-                        "Speaker",
-                        "Date",
-                        "url",
-                    ],
-                ).encode("utf-8"),
-                file_name=f"{user_input}.csv",
-                mime="text/csv",
-            )
     except Exception as e:
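         # NOTE: `e` is an exception object, so comparing it to a string with
         # == is always False; str(e) or an isinstance() check is probably
         # intended here.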
         if (
             e == "streamlit.runtime.scriptrunner.script_runner.StopException"
e == " streamlit.runtime.scriptrunner.script_runner.StopException "