lasseedfast 2 years ago
parent 9122f0c84f
commit 7a529055e8
  1. 6
      .gitignore
  2. 68
      arango_things.py
  3. 313
      things.py

6
.gitignore vendored

@ -1,11 +1,11 @@
* *
!streamlit_info.py
!download_debates.py !download_debates.py
!translate_speeches.py !translate_speeches.py
!arango_things !arango_things.py
!things !things.py
!streamlit_app_talking_ep.py !streamlit_app_talking_ep.py
!.gitignore !.gitignore
!streamlit_info.py
!notes.md !notes.md
!llama_server.py !llama_server.py
!requirements_streamlit.txt !requirements_streamlit.txt

@ -0,0 +1,68 @@
from arango import ArangoClient, exceptions
import pandas as pd
import yaml
def get_documents(query=False, collection=False, fields=[], filter = '', df=False, index=False, field_names=False):
"""
This function retrieves documents from a specified collection or based on a query in ArangoDB.
Parameters:
query (str): If specified this will be the query. Defaults to False.
collection (str): The name of the collection from which to retrieve documents. Defaults to False.
fields (list): The fields of the documents to retrieve. If empty, all fields are retrieved.
filter (str): AQL filter to apply to the retrieval. Defaults to no filter.
df (bool): If True, the result is returned as a pandas DataFrame. Defaults to False.
index (bool): If True and df is True, the DataFrame is indexed. Defaults to False.
field_names (dict): If provided, these field names will replace the original field names in the result.
Returns:
list or DataFrame: The retrieved documents as a list of dictionaries or a DataFrame.
"""
if query:
pass
else:
if fields == []:
return_fields = 'doc'
else:
fields_dict = {}
for field in fields:
fields_dict[field] = field
if field_names:
for k, v in field_names.items():
fields_dict[k] = v
fields_list = [f'{v}: doc.{k}' for k, v in fields_dict.items()]
fields_string = ', '.join(fields_list)
return_fields = f"{{{fields_string}}}"
query = f'''
for doc
in {collection}
{filter}
return {return_fields}
'''
try:
cursor = arango_db.aql.execute(query)
except exceptions.AQLQueryExecuteError:
print('ERROR:\n', query)
exit()
result = [i for i in cursor]
if df:
result = pd.DataFrame(result)
if index:
result.set_index(index, inplace=True)
return result
with open('config.yml', 'r') as f:
config = yaml.safe_load(f)
db = config['arango']['db']
username = config['arango']['username']
pwd = config['arango']['pwd_lasse']
# Initialize the database for ArangoDB.
client = ArangoClient(hosts=config['arango']['hosts'])
arango_db = client.db(db, username=username, password=pwd)

@ -0,0 +1,313 @@
import nltk
import tiktoken
import re
def normalize_party_names(name):
"""
Normalizes party names to the format used in the database.
Parameters:
name (str): The party name to be normalized.
Returns:
str: The normalized party name.
"""
parties = {
"EPP": "EPP",
"PPE": "EPP",
"RE": "Renew",
"S-D": "S&D",
"S&D": "S&D",
"ID": "ID",
"ECR": "ECR",
"GUE/NGL": "GUE/NGL",
"The Left": "GUE/NGL",
"Greens/EFA": "Greens/EFA",
"G/EFA": "Greens/EFA",
"Verts/ALE": "Greens/EFA",
"NA": "NA",
"NULL": "NA",
None: "NA",
"-": "NA",
"Vacant": "NA",
"NI": "NA",
"Renew": "Renew"
}
return parties[name]
def count_tokens(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = len(encoding.encode(string))
return num_tokens
def whitespace_remover(text):
return re.sub(r'^[\w\s]', '', text)
def fix_date(string, pattern):
"""
Args:
string (str): The input string containing the date.
pattern (str): The pattern to be used for parsing the date from the string. It should contain 'y', 'm', and 'd' to represent the year, month, and day respectively.
Returns:
str: The formatted date string in the 'yyyy-mm-dd' format.
Example:
>>> fix_date('20211231', 'yyyymmdd')
'2021-12-31'
"""
y = re.search(r'y+', pattern)
m = re.search(r'm+', pattern)
d = re.search(r'd+', pattern)
year = string[y.span()[0]:y.span()[1]]
month = string[m.span()[0]:m.span()[1]]
day = string[d.span()[0]:d.span()[1]]
return (f'{year}-{month}-{day}')
def fix_doc_name(string, pattern):
"""
Returns a dictionary with keys 'year', 'number', and 'letters', and their corresponding values from the string.
Args:
string (str): The string from which to extract the year, number, and letters.
pattern (str): The pattern which to search for in the string.
Returns:
dict: A dictionary with keys 'year', 'number', and 'letters' and their corresponding values from the string.
Example:
>>> fix_doc_name('COM/2021/570', 'lll/yyyy/nnn')
'{'year': 2023, 'number': 570, 'letters': COM}'
"""
# Find positions for y, n l.
y = re.search(r'y+', pattern)
n = re.search(r'n+', pattern)
l = re.search(r'l+', pattern)
# Extract the year, number and letters based on the positions.
year = string[y.span()[0]:y.span()[1]]
number = string[n.span()[0]:n.span()[1]]
letters = string[l.span()[0]:l.span()[1]]
return {'year': year, 'number': number, 'letters': letters}
def text_splitter(text: str, max_tokens=2000):
"""
Splits a given text into chunks of sentences where each chunk has a number of tokens less than or equal to the max_tokens.
The function first calculates the total number of tokens in the input text. If this number is greater than max_tokens,
it calculates the maximum number of tokens per chunk and splits the text into sentences. Then it iterates over the sentences,
adding them to the current chunk until the number of tokens in the current chunk reaches the maximum limit. When this limit
is reached, the current chunk is added to the list of chunks and a new chunk is started. This process continues until all
sentences have been processed. If the total number of tokens in the text is less than or equal to max_tokens, the function
returns the whole text as a single chunk.
Parameters:
text (str): The input text to be split into chunks.
max_tokens (int): The maximum number of tokens allowed in each chunk.
Returns:
chunks (list of str): A list of text chunks where each chunk has a number of tokens less than or equal to max_tokens.
"""
try:
tokens_in_text = count_tokens(text)
except:
tokens_in_text = len(text)/3
if tokens_in_text > max_tokens:
# Calculate maximal number of tokens in chunks to make the chunks even.
max_tokens_per_chunk = int(tokens_in_text/int(tokens_in_text / max_tokens))
# Split the text into sentences.
sentences = nltk.sent_tokenize(text)
# Initialize an empty list to hold chunks and a string to hold the current chunk.
chunks = []
current_chunk = ''
# Iterate over the sentences.
for sentence in sentences:
# If adding the next sentence doesn't exceed the max tokens limit, add the sentence to the current chunk.
if count_tokens(current_chunk + ' ' + sentence) <= max_tokens_per_chunk:
current_chunk += ' ' + sentence
else:
# If it does, add the current chunk to the chunks list and start a new chunk with the current sentence.
chunks.append(current_chunk)
current_chunk = sentence
# Add the last chunk to the chunks list.
if current_chunk:
chunks.append(current_chunk)
else:
chunks = [text]
return chunks
parliamentary_term_now = 9 #* Update this every term.
model_mistral = "mistral-openorca"
eu_country_codes = {
"Belgium": "BE",
"Greece": "EL",
"Lithuania": "LT",
"Portugal": "PT",
"Bulgaria": "BG",
"Spain": "ES",
"Luxembourg": "LU",
"Romania": "RO",
"Czechia": "CZ",
"France": "FR",
"Hungary": "HU",
"Slovenia": "SI",
"Denmark": "DK",
"Croatia": "HR",
"Malta": "MT",
"Slovakia": "SK",
"Germany": "DE",
"Italy": "IT",
"Netherlands": "NL",
"Finland": "FI",
"Estonia": "EE",
"Cyprus": "CY",
"Austria": "AT",
"Sweden": "SE",
"Ireland": "IE",
"Latvia": "LV",
"Poland": "PL",
}
country_flags = {
"United Kingdom": "🇬🇧",
"Sweden": "🇸🇪",
"Spain": "🇪🇸",
"Slovenia": "🇸🇮",
"Slovakia": "🇸🇰",
"Romania": "🇷🇴",
"Portugal": "🇵🇹",
"Poland": "🇵🇱",
"Netherlands": "🇳🇱",
"Malta": "🇲🇹",
"Luxembourg": "🇱🇺",
"Lithuania": "🇱🇹",
"Latvia": "🇱🇻",
"Italy": "🇮🇹",
"Ireland": "🇮🇪",
"Hungary": "🇭🇺",
"Greece": "🇬🇷",
"Germany": "🇩🇪",
"France": "🇫🇷",
"Finland": "🇫🇮",
"Estonia": "🇪🇪",
"Denmark": "🇩🇰",
"Czechia": "🇨🇿",
"Cyprus": "🇨🇾",
"Croatia": "🇭🇷",
"Bulgaria": "🇧🇬",
"Belgium": "🇧🇪",
"Austria": "🇦🇹",
}
policy_areas = [
"Agriculture",
"Business",
"Industry",
"Climate",
"Culture",
"Customs",
"Development",
"Education",
"Employment",
"Social Affairs",
"Energy",
"Environment",
"FoodSafety",
"SecurityPolicy",
"Health",
"Democracy",
"Humanitarian Aid",
"Justice",
"Research And Innovation",
"Market",
"Taxation",
"Trade",
"Transport",
]
# From https://eur-lex.europa.eu/browse/summaries.html
policy_areas = ['Agriculture', ' Audiovisual and media', ' Budget', ' Competition', ' Consumers', ' Culture', ' Customs', ' Development', ' Digital single market', ' Economic and monetary affairs', ' Education, training, youth, sport', ' Employment and social policy', ' Energy', ' Enlargement', ' Enterprise', ' Environment and climate change', ' External relations', ' External trade', ' Food safety', ' Foreign and security policy', ' Fraud and corruption', ' Humanitarian Aid and Civil Protection', ' Human rights', ' Institutional affairs', ' Internal market', ' Justice, freedom and security', ' Oceans and fisheries', ' Public health', ' Regional policy', ' Research and innovation', ' Taxation', ' Transport']
countries = [
"Romania",
"Latvia",
"Slovenia",
"Denmark",
"Spain",
"Italy",
"Hungary",
"United Kingdom",
"Netherlands",
"Czechia",
"Finland",
"Belgium",
"Germany",
"France",
"Slovakia",
"Poland",
"Ireland",
"Malta",
"Cyprus",
"Luxembourg",
"Greece",
"Austria",
"Sweden",
"Portugal",
"Lithuania",
"Croatia",
"Bulgaria",
"Estonia",
]
parties = ["Renew", "S-D", "PPE", "Verts/ALE", "ECR", "NI", "The Left", "ID", "GUE/NGL"]
party_colors = {
"EPP": "#3399FF",
"S-D": "#F0001C",
"Renew": "gold",
"ID": "#0E408A",
"G/EFA": "#57B45F",
"ECR": "#196CA8",
"GUE/NGL": "#B71C1C",
'The Left': "#B71C1C", # Same as GUE/NGL
"NI": "white",
"Vacant": "white",
"PPE": "#3399FF", # Same as EPP
"NULL": 'white',
'Verts/ALE':"#57B45F", # Same as G/EFA
None: 'white'
}
def insert_in_db(query):
con = sqlite3.connect(path_db)
con.row_factory = sqlite3.Row
cursor = con.cursor()
query = query
cursor.execute(query)
con.commit()
con.close()
Loading…
Cancel
Save