diff --git a/.gitignore b/.gitignore
index 0f536d0..8a6efbd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,11 +1,11 @@
 *
+!streamlit_info.py
 !download_debates.py
 !translate_speeches.py
-!arango_things
-!things
+!arango_things.py
+!things.py
 !streamlit_app_talking_ep.py
 !.gitignore
-!streamlit_info.py
 !notes.md
 !llama_server.py
 !requirements_streamlit.txt
\ No newline at end of file
diff --git a/arango_things.py b/arango_things.py
new file mode 100644
index 0000000..a043ae7
--- /dev/null
+++ b/arango_things.py
@@ -0,0 +1,67 @@
+from arango import ArangoClient, exceptions
+import pandas as pd
+import yaml
+
+
+def get_documents(query=False, collection=False, fields=None, filter='', df=False, index=False, field_names=False):
+    """
+    Retrieve documents from an ArangoDB collection, or run a raw AQL query.
+
+    Parameters:
+    query (str): If specified this will be the query. Defaults to False.
+    collection (str): The name of the collection from which to retrieve documents. Defaults to False.
+    fields (list): The fields of the documents to retrieve. If None or empty, all fields are retrieved.
+    filter (str): AQL filter to apply to the retrieval. Defaults to no filter.
+    df (bool): If True, the result is returned as a pandas DataFrame. Defaults to False.
+    index (bool): If True and df is True, the DataFrame is indexed by this column. Defaults to False.
+    field_names (dict): If provided, these field names will replace the original field names in the result.
+
+    Returns:
+    list or DataFrame: The retrieved documents as a list of dictionaries or a DataFrame.
+    """
+    # Fix: `fields` previously defaulted to a shared mutable list ([]).
+    if fields is None:
+        fields = []
+    if query:
+        pass
+    else:
+        if fields == []:
+            return_fields = 'doc'
+        else:
+            # Map output keys to document fields, applying any renames.
+            fields_dict = {field: field for field in fields}
+            if field_names:
+                for k, v in field_names.items():
+                    fields_dict[k] = v
+            fields_string = ', '.join(f'{v}: doc.{k}' for k, v in fields_dict.items())
+            return_fields = f"{{{fields_string}}}"
+        query = f'''
+        for doc
+        in {collection}
+        {filter}
+        return {return_fields}
+        '''
+    try:
+        cursor = arango_db.aql.execute(query)
+    except exceptions.AQLQueryExecuteError:
+        # Fix: re-raise instead of exit() so callers (e.g. a Streamlit app)
+        # can handle the failure; exit() killed the whole process.
+        print('ERROR:\n', query)
+        raise
+    result = list(cursor)
+    if df:
+        result = pd.DataFrame(result)
+        if index:
+            result.set_index(index, inplace=True)
+    return result
+
+
+with open('config.yml', 'r') as f:
+    config = yaml.safe_load(f)
+
+db = config['arango']['db']
+username = config['arango']['username']
+pwd = config['arango']['pwd_lasse']
+
+# Initialize the database connection for ArangoDB.
+client = ArangoClient(hosts=config['arango']['hosts'])
+arango_db = client.db(db, username=username, password=pwd)
diff --git a/things.py b/things.py
new file mode 100644
index 0000000..cc807d6
--- /dev/null
+++ b/things.py
@@ -0,0 +1,299 @@
+import nltk
+import re
+import sqlite3
+
+import tiktoken
+
+
+def normalize_party_names(name):
+    """
+    Normalizes party names to the format used in the database.
+
+    Parameters:
+    name (str): The party name to be normalized.
+
+    Returns:
+    str: The normalized party name.
+
+    Raises:
+    KeyError: If name is not a known party alias.
+    """
+    parties = {
+        "EPP": "EPP",
+        "PPE": "EPP",
+        "RE": "Renew",
+        "S-D": "S&D",
+        "S&D": "S&D",
+        "ID": "ID",
+        "ECR": "ECR",
+        "GUE/NGL": "GUE/NGL",
+        "The Left": "GUE/NGL",
+        "Greens/EFA": "Greens/EFA",
+        "G/EFA": "Greens/EFA",
+        "Verts/ALE": "Greens/EFA",
+        "NA": "NA",
+        "NULL": "NA",
+        None: "NA",
+        "-": "NA",
+        "Vacant": "NA",
+        "NI": "NA",
+        "Renew": "Renew",
+    }
+    return parties[name]
+
+
+def count_tokens(string: str) -> int:
+    """Returns the number of tokens in a text string."""
+    encoding = tiktoken.get_encoding("cl100k_base")
+    num_tokens = len(encoding.encode(string))
+    return num_tokens
+
+
+def whitespace_remover(text):
+    """Strip all leading whitespace from text.
+
+    Fix: the original pattern r'^[\w\s]' removed exactly one leading
+    character, and matched word characters as well as whitespace.
+    """
+    return re.sub(r'^\s+', '', text)
+
+
+def fix_date(string, pattern):
+    """
+    Args:
+    string (str): The input string containing the date.
+    pattern (str): The pattern to be used for parsing the date from the string. It should contain 'y', 'm', and 'd' to represent the year, month, and day respectively.
+
+    Returns:
+    str: The formatted date string in the 'yyyy-mm-dd' format.
+
+    Example:
+    >>> fix_date('20211231', 'yyyymmdd')
+    '2021-12-31'
+    """
+    # Locate the y/m/d runs in the pattern; their spans index into `string`.
+    y = re.search(r'y+', pattern)
+    m = re.search(r'm+', pattern)
+    d = re.search(r'd+', pattern)
+
+    year = string[y.span()[0]:y.span()[1]]
+    month = string[m.span()[0]:m.span()[1]]
+    day = string[d.span()[0]:d.span()[1]]
+
+    return f'{year}-{month}-{day}'
+
+
+def fix_doc_name(string, pattern):
+    """
+    Returns a dictionary with keys 'year', 'number', and 'letters', and their corresponding values from the string.
+
+    Args:
+    string (str): The string from which to extract the year, number, and letters.
+    pattern (str): The pattern which to search for in the string.
+
+    Returns:
+    dict: A dictionary with keys 'year', 'number', and 'letters' and their corresponding values from the string.
+
+    Example:
+    >>> fix_doc_name('COM/2021/570', 'lll/yyyy/nnn')
+    {'year': '2021', 'number': '570', 'letters': 'COM'}
+    """
+    # Find positions for the y, n and l runs in the pattern.
+    y = re.search(r'y+', pattern)
+    n = re.search(r'n+', pattern)
+    letters_match = re.search(r'l+', pattern)
+
+    # Extract the year, number and letters based on those positions.
+    year = string[y.span()[0]:y.span()[1]]
+    number = string[n.span()[0]:n.span()[1]]
+    letters = string[letters_match.span()[0]:letters_match.span()[1]]
+
+    return {'year': year, 'number': number, 'letters': letters}
+
+
+def text_splitter(text: str, max_tokens=2000):
+    """
+    Splits a given text into chunks of sentences where each chunk has a number of tokens less than or equal to max_tokens.
+
+    If the text exceeds max_tokens it is split into sentence chunks of
+    roughly equal token counts; otherwise the whole text is returned as
+    a single chunk.
+
+    Parameters:
+    text (str): The input text to be split into chunks.
+    max_tokens (int): The maximum number of tokens allowed in each chunk.
+
+    Returns:
+    chunks (list of str): A list of text chunks.
+    """
+    try:
+        tokens_in_text = count_tokens(text)
+    except Exception:
+        # Fix: was a bare except. Fall back to a rough chars-per-token
+        # estimate when the tokenizer is unavailable.
+        tokens_in_text = len(text) / 3
+    if tokens_in_text > max_tokens:
+        # Calculate maximal number of tokens per chunk to make the chunks even.
+        max_tokens_per_chunk = int(tokens_in_text / int(tokens_in_text / max_tokens))
+
+        sentences = nltk.sent_tokenize(text)
+
+        chunks = []
+        current_chunk = ''
+
+        # Greedily pack sentences while staying under the per-chunk limit.
+        for sentence in sentences:
+            if count_tokens(current_chunk + ' ' + sentence) <= max_tokens_per_chunk:
+                current_chunk += ' ' + sentence
+            else:
+                chunks.append(current_chunk)
+                current_chunk = sentence
+
+        # Add the last chunk.
+        if current_chunk:
+            chunks.append(current_chunk)
+
+    else:
+        chunks = [text]
+
+    return chunks
+
+
+parliamentary_term_now = 9  #* Update this every term.
+
+
+model_mistral = "mistral-openorca"
+
+eu_country_codes = {
+    "Belgium": "BE",
+    "Greece": "EL",
+    "Lithuania": "LT",
+    "Portugal": "PT",
+    "Bulgaria": "BG",
+    "Spain": "ES",
+    "Luxembourg": "LU",
+    "Romania": "RO",
+    "Czechia": "CZ",
+    "France": "FR",
+    "Hungary": "HU",
+    "Slovenia": "SI",
+    "Denmark": "DK",
+    "Croatia": "HR",
+    "Malta": "MT",
+    "Slovakia": "SK",
+    "Germany": "DE",
+    "Italy": "IT",
+    "Netherlands": "NL",
+    "Finland": "FI",
+    "Estonia": "EE",
+    "Cyprus": "CY",
+    "Austria": "AT",
+    "Sweden": "SE",
+    "Ireland": "IE",
+    "Latvia": "LV",
+    "Poland": "PL",
+}
+
+country_flags = {
+    "United Kingdom": "🇬🇧",
+    "Sweden": "🇸🇪",
+    "Spain": "🇪🇸",
+    "Slovenia": "🇸🇮",
+    "Slovakia": "🇸🇰",
+    "Romania": "🇷🇴",
+    "Portugal": "🇵🇹",
+    "Poland": "🇵🇱",
+    "Netherlands": "🇳🇱",
+    "Malta": "🇲🇹",
+    "Luxembourg": "🇱🇺",
+    "Lithuania": "🇱🇹",
+    "Latvia": "🇱🇻",
+    "Italy": "🇮🇹",
+    "Ireland": "🇮🇪",
+    "Hungary": "🇭🇺",
+    "Greece": "🇬🇷",
+    "Germany": "🇩🇪",
+    "France": "🇫🇷",
+    "Finland": "🇫🇮",
+    "Estonia": "🇪🇪",
+    "Denmark": "🇩🇰",
+    "Czechia": "🇨🇿",
+    "Cyprus": "🇨🇾",
+    "Croatia": "🇭🇷",
+    "Bulgaria": "🇧🇬",
+    "Belgium": "🇧🇪",
+    "Austria": "🇦🇹",
+}
+
+
+# From https://eur-lex.europa.eu/browse/summaries.html
+# (A shorter hand-written list that was immediately rebound — dead code — has been removed.)
+# NOTE(review): every entry except the first carries a leading space; strip before display — TODO confirm no caller depends on the exact strings.
+policy_areas = ['Agriculture', ' Audiovisual and media', ' Budget', ' Competition', ' Consumers', ' Culture', ' Customs', ' Development', ' Digital single market', ' Economic and monetary affairs', ' Education, training, youth, sport', ' Employment and social policy', ' Energy', ' Enlargement', ' Enterprise', ' Environment and climate change', ' External relations', ' External trade', ' Food safety', ' Foreign and security policy', ' Fraud and corruption', ' Humanitarian Aid and Civil Protection', ' Human rights', ' Institutional affairs', ' Internal market', ' Justice, freedom and security', ' Oceans and fisheries', ' Public health', ' Regional policy', ' Research and innovation', ' Taxation', ' Transport']
+
+countries = [
+    "Romania",
+    "Latvia",
+    "Slovenia",
+    "Denmark",
+    "Spain",
+    "Italy",
+    "Hungary",
+    "United Kingdom",
+    "Netherlands",
+    "Czechia",
+    "Finland",
+    "Belgium",
+    "Germany",
+    "France",
+    "Slovakia",
+    "Poland",
+    "Ireland",
+    "Malta",
+    "Cyprus",
+    "Luxembourg",
+    "Greece",
+    "Austria",
+    "Sweden",
+    "Portugal",
+    "Lithuania",
+    "Croatia",
+    "Bulgaria",
+    "Estonia",
+]
+
+parties = ["Renew", "S-D", "PPE", "Verts/ALE", "ECR", "NI", "The Left", "ID", "GUE/NGL"]
+
+party_colors = {
+    "EPP": "#3399FF",
+    "S-D": "#F0001C",
+    "Renew": "gold",
+    "ID": "#0E408A",
+    "G/EFA": "#57B45F",
+    "ECR": "#196CA8",
+    "GUE/NGL": "#B71C1C",
+    'The Left': "#B71C1C",  # Same as GUE/NGL
+    "NI": "white",
+    "Vacant": "white",
+    "PPE": "#3399FF",  # Same as EPP
+    "NULL": 'white',
+    'Verts/ALE': "#57B45F",  # Same as G/EFA
+    None: 'white'
+}
+
+
+def insert_in_db(query):
+    """Execute a single write query against the SQLite database and commit.
+
+    NOTE(review): `path_db` is not defined in this module — it must be
+    provided elsewhere before this is called; TODO confirm.
+    """
+    con = sqlite3.connect(path_db)
+    con.row_factory = sqlite3.Row
+    cursor = con.cursor()
+    try:
+        # Fix: try/finally so the connection is closed even if execute fails.
+        cursor.execute(query)
+        con.commit()
+    finally:
+        con.close()