You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
313 lines
8.8 KiB
313 lines
8.8 KiB
import re
import sqlite3

import nltk
import tiktoken
|
|
|
def normalize_party_names(name):
    """
    Map a raw European-Parliament party label to its canonical database name.

    Parameters:
        name (str or None): The party label as found in the raw data.
            ``None`` is accepted and treated as "not affiliated" ("NA").

    Returns:
        str: The canonical party name used in the database.

    Raises:
        KeyError: If *name* is not one of the known aliases.
    """
    # Canonical name -> every raw alias that should map to it.
    aliases = {
        "EPP": ("EPP", "PPE"),
        "Renew": ("RE", "Renew"),
        "S&D": ("S-D", "S&D"),
        "ID": ("ID",),
        "ECR": ("ECR",),
        "GUE/NGL": ("GUE/NGL", "The Left"),
        "Greens/EFA": ("Greens/EFA", "G/EFA", "Verts/ALE"),
        "NA": ("NA", "NULL", None, "-", "Vacant", "NI"),
    }

    # Invert into a flat alias -> canonical lookup, then resolve.
    lookup = {raw: canonical for canonical, raws in aliases.items() for raw in raws}
    return lookup[name]
|
|
|
|
|
def count_tokens(string: str) -> int:
    """Return the number of tokens in *string* under the cl100k_base encoding."""
    return len(tiktoken.get_encoding("cl100k_base").encode(string))
|
|
|
def whitespace_remover(text):
    r"""
    Strip leading whitespace from *text*.

    Bug fix: the original pattern ``r'^[\w\s]'`` removed exactly one leading
    character, and because ``\w`` matches word characters it would eat the
    first letter or digit of the text rather than whitespace. Now only the
    run of whitespace at the start of the string is removed.

    Parameters:
        text (str): The input string.

    Returns:
        str: *text* without leading whitespace.
    """
    return re.sub(r'^\s+', '', text)
|
|
|
def fix_date(string, pattern):
    """
    Reformat a date string into ISO 'yyyy-mm-dd' form.

    The *pattern* marks, position by position, where the year ('y'),
    month ('m') and day ('d') digits sit inside *string*; the characters
    at those positions are extracted and re-joined with dashes.

    Args:
        string (str): The input string containing the date.
        pattern (str): Template containing runs of 'y', 'm' and 'd' that
            mark the year, month and day positions respectively.

    Returns:
        str: The formatted date string in the 'yyyy-mm-dd' format.

    Example:
        >>> fix_date('20211231', 'yyyymmdd')
        '2021-12-31'
    """
    parts = []
    # Locate each component's run inside the pattern and slice the same
    # positions out of the actual string, in year/month/day order.
    for marker in (r'y+', r'm+', r'd+'):
        span = re.search(marker, pattern)
        parts.append(string[span.start():span.end()])
    return '-'.join(parts)
|
|
|
|
|
def fix_doc_name(string, pattern):
    """
    Extract the year, number and letter components of a document identifier.

    Args:
        string (str): The document name, e.g. 'COM/2021/570'.
        pattern (str): Template using runs of 'y' (year), 'n' (number) and
            'l' (letters) marking where each component sits in *string*.

    Returns:
        dict: Keys 'year', 'number' and 'letters' mapped to the substrings
        of *string* found at the positions marked in *pattern*.

    Example:
        >>> fix_doc_name('COM/2021/570', 'lll/yyyy/nnn')
        {'year': '2021', 'number': '570', 'letters': 'COM'}
    """
    # Locate each component's run inside the pattern.
    year_match = re.search(r'y+', pattern)
    number_match = re.search(r'n+', pattern)
    letters_match = re.search(r'l+', pattern)

    # Slice the same positions out of the actual string.
    return {
        'year': string[year_match.start():year_match.end()],
        'number': string[number_match.start():number_match.end()],
        'letters': string[letters_match.start():letters_match.end()],
    }
|
|
|
|
|
|
|
def text_splitter(text: str, max_tokens=2000):
    """
    Split *text* into sentence-aligned chunks of roughly *max_tokens* tokens.

    If the whole text fits within *max_tokens* it is returned unchanged as a
    single-element list. Otherwise the text is sentence-tokenized with NLTK
    and sentences are packed greedily into chunks. The per-chunk budget is
    recomputed as total_tokens / floor(total_tokens / max_tokens) so chunks
    come out roughly even instead of leaving one small trailing chunk (note
    this budget can slightly exceed *max_tokens*).

    Parameters:
        text (str): The input text to be split into chunks.
        max_tokens (int): The approximate maximum number of tokens per chunk.

    Returns:
        list of str: Chunks of *text*; a single sentence longer than the
        budget still becomes its own (oversized) chunk.
    """
    try:
        tokens_in_text = count_tokens(text)
    except Exception:
        # Fixed: was a bare `except:` that also swallowed SystemExit and
        # KeyboardInterrupt. If the tokenizer is unavailable, fall back to
        # a rough ~3-characters-per-token estimate.
        tokens_in_text = len(text) / 3

    # Whole text fits in one chunk: nothing to split.
    if tokens_in_text <= max_tokens:
        return [text]

    # Evened-out per-chunk budget (see docstring).
    max_tokens_per_chunk = int(tokens_in_text / int(tokens_in_text / max_tokens))

    # Split the text into sentences.
    sentences = nltk.sent_tokenize(text)

    chunks = []
    current_chunk = ''
    for sentence in sentences:
        # Greedily extend the current chunk while it stays within budget.
        if count_tokens(current_chunk + ' ' + sentence) <= max_tokens_per_chunk:
            current_chunk += ' ' + sentence
        else:
            # Budget exceeded: close out the current chunk and start a new
            # one with this sentence.
            chunks.append(current_chunk)
            current_chunk = sentence

    # Flush the trailing chunk.
    if current_chunk:
        chunks.append(current_chunk)

    return chunks
|
|
|
|
|
# Current European Parliament term; bump at the start of each new term.
parliamentary_term_now = 9 #* Update this every term.


# Model identifier string (looks like an Ollama model tag — confirm where
# in the project it is consumed).
model_mistral = "mistral-openorca"
|
|
|
# EU member-state name -> two-letter country code. Note the dataset uses
# "EL" for Greece (the EU's own convention) rather than ISO's "GR".
eu_country_codes = {
    "Belgium": "BE",
    "Greece": "EL",
    "Lithuania": "LT",
    "Portugal": "PT",
    "Bulgaria": "BG",
    "Spain": "ES",
    "Luxembourg": "LU",
    "Romania": "RO",
    "Czechia": "CZ",
    "France": "FR",
    "Hungary": "HU",
    "Slovenia": "SI",
    "Denmark": "DK",
    "Croatia": "HR",
    "Malta": "MT",
    "Slovakia": "SK",
    "Germany": "DE",
    "Italy": "IT",
    "Netherlands": "NL",
    "Finland": "FI",
    "Estonia": "EE",
    "Cyprus": "CY",
    "Austria": "AT",
    "Sweden": "SE",
    "Ireland": "IE",
    "Latvia": "LV",
    "Poland": "PL",
}
|
|
|
# Country name -> flag emoji. Includes the United Kingdom (unlike
# eu_country_codes), presumably for pre-Brexit historical data — confirm.
country_flags = {
    "United Kingdom": "🇬🇧",
    "Sweden": "🇸🇪",
    "Spain": "🇪🇸",
    "Slovenia": "🇸🇮",
    "Slovakia": "🇸🇰",
    "Romania": "🇷🇴",
    "Portugal": "🇵🇹",
    "Poland": "🇵🇱",
    "Netherlands": "🇳🇱",
    "Malta": "🇲🇹",
    "Luxembourg": "🇱🇺",
    "Lithuania": "🇱🇹",
    "Latvia": "🇱🇻",
    "Italy": "🇮🇹",
    "Ireland": "🇮🇪",
    "Hungary": "🇭🇺",
    "Greece": "🇬🇷",
    "Germany": "🇩🇪",
    "France": "🇫🇷",
    "Finland": "🇫🇮",
    "Estonia": "🇪🇪",
    "Denmark": "🇩🇰",
    "Czechia": "🇨🇿",
    "Cyprus": "🇨🇾",
    "Croatia": "🇭🇷",
    "Bulgaria": "🇧🇬",
    "Belgium": "🇧🇪",
    "Austria": "🇦🇹",
}
|
|
|
|
|
# EU policy areas, from https://eur-lex.europa.eu/browse/summaries.html
# Fixes two defects in the original module:
#  * policy_areas was assigned twice — a short ad-hoc list that was
#    immediately shadowed by this one (dead code, now removed);
#  * the surviving entries carried stray leading spaces from a naive
#    comma-split (e.g. ' Audiovisual and media'), which would break any
#    exact-match lookup; the entries are now stripped.
policy_areas = [
    'Agriculture',
    'Audiovisual and media',
    'Budget',
    'Competition',
    'Consumers',
    'Culture',
    'Customs',
    'Development',
    'Digital single market',
    'Economic and monetary affairs',
    'Education, training, youth, sport',
    'Employment and social policy',
    'Energy',
    'Enlargement',
    'Enterprise',
    'Environment and climate change',
    'External relations',
    'External trade',
    'Food safety',
    'Foreign and security policy',
    'Fraud and corruption',
    'Humanitarian Aid and Civil Protection',
    'Human rights',
    'Institutional affairs',
    'Internal market',
    'Justice, freedom and security',
    'Oceans and fisheries',
    'Public health',
    'Regional policy',
    'Research and innovation',
    'Taxation',
    'Transport',
]
|
|
|
# EU member states plus the United Kingdom (in no particular order);
# superset of eu_country_codes' keys.
countries = [
    "Romania",
    "Latvia",
    "Slovenia",
    "Denmark",
    "Spain",
    "Italy",
    "Hungary",
    "United Kingdom",
    "Netherlands",
    "Czechia",
    "Finland",
    "Belgium",
    "Germany",
    "France",
    "Slovakia",
    "Poland",
    "Ireland",
    "Malta",
    "Cyprus",
    "Luxembourg",
    "Greece",
    "Austria",
    "Sweden",
    "Portugal",
    "Lithuania",
    "Croatia",
    "Bulgaria",
    "Estonia",
]
|
|
|
# Raw party labels as they appear in the source data, i.e. BEFORE
# normalize_party_names is applied.
parties = ["Renew", "S-D", "PPE", "Verts/ALE", "ECR", "NI", "The Left", "ID", "GUE/NGL"]


# Party label -> display color for charts. NOTE(review): keys mix raw
# labels ("S-D", "PPE", "Verts/ALE") with normalized ones ("EPP",
# "GUE/NGL"), so lookups work whether or not the label was normalized —
# aliases of the same party deliberately share a color.
party_colors = {
    "EPP": "#3399FF",
    "S-D": "#F0001C",
    "Renew": "gold",
    "ID": "#0E408A",
    "G/EFA": "#57B45F",
    "ECR": "#196CA8",
    "GUE/NGL": "#B71C1C",
    'The Left': "#B71C1C", # Same as GUE/NGL
    "NI": "white",
    "Vacant": "white",
    "PPE": "#3399FF", # Same as EPP
    "NULL": 'white',
    'Verts/ALE':"#57B45F", # Same as G/EFA
    None: 'white'
}
|
|
|
def insert_in_db(query):
    """
    Execute a single write statement against the application database and commit.

    Parameters:
        query (str): A complete SQL statement. NOTE(review): the string is
            executed as-is — never interpolate untrusted input into it;
            prefer parameterized queries for anything user-supplied.

    Notes:
        * Relies on the module-level ``path_db`` for the database location —
          presumably defined elsewhere in the project; confirm it is in scope.
        * Fixes in this revision: ``sqlite3`` was never imported by the
          module (the call would raise NameError); the connection leaked if
          ``execute`` raised; the no-op ``query = query`` and unused cursor
          are gone.
    """
    con = sqlite3.connect(path_db)
    con.row_factory = sqlite3.Row
    try:
        # The connection as a context manager commits on success and rolls
        # back if execute() raises; close() must still be called explicitly.
        with con:
            con.execute(query)
    finally:
        con.close()