commit
01df43bba2
26 changed files with 4022 additions and 0 deletions
@@ -0,0 +1,66 @@
# _arango.py
import os
import re

from arango import ArangoClient
from dotenv import load_dotenv

import env_manager

load_dotenv()  # Install with: pip install python-dotenv


class ArangoDB:
    def __init__(self, user=None, password=None, db_name=None):
        """
        Initializes an ArangoDB instance.

        Args:
            user (str): The username for authentication.
            password (str): The password for authentication.
            db_name (str): The name of the database.
        """

        host = os.getenv("ARANGO_HOST")
        if not user:
            user = os.getenv("ARANGO_USER")
        if not password:
            password = os.getenv("ARANGO_PASSWORD")
        if not db_name:
            db_name = os.getenv("ARANGO_DB")

        self.client = ArangoClient(hosts=host)
        self.db = self.client.db(db_name, username=user, password=password)

    def fix_key(self, _key):
        """
        Sanitize a given key by replacing every character that is not alphanumeric,
        underscore, hyphen, dot, at sign, parenthesis, plus, equals, semicolon,
        dollar sign, exclamation mark, asterisk, single quote, percent, or colon
        with an underscore.

        Args:
            _key (str): The key to be sanitized.

        Returns:
            str: The sanitized key with disallowed characters replaced by underscores.
        """

        return re.sub(r"[^A-Za-z0-9_\-\.@()+=;$!*\'%:]", "_", _key)

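# A minimal usage sketch (the DOI is hypothetical): fix_key makes a DOI safe to
# use as an ArangoDB _key, matching how article _ids are built elsewhere in this
# commit (e.g. _id=f"sci_articles/{fix_key(doi)}"):
#
#   arango = ArangoDB(db_name="base")
#   arango.fix_key("10.1000/abc#1")   # -> "10.1000_abc_1"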

if __name__ == "__main__":
    arango = ArangoDB(db_name="base")
    articles = arango.db.collection("sci_articles").all()
    for article in articles:
        if "metadata" in article and article["metadata"]:
            if "abstract" in article["metadata"]:
                abstract = article["metadata"]["abstract"]
                if isinstance(abstract, str):
                    # Remove text within <> brackets and the brackets themselves
                    article["metadata"]["abstract"] = re.sub(r"<[^>]*>", "", abstract)
                    arango.db.collection("sci_articles").update_match(
                        filters={"_key": article["_key"]},
                        body={"metadata": article["metadata"]},
                        merge=True,
                    )
                    print(f"Updated abstract for {article['_key']}")

@@ -0,0 +1,140 @@
# _base_class.py
import os
import re

import streamlit as st

from _arango import ArangoDB
from _llm import LLM
from _chromadb import ChromaDB


class BaseClass:
    def __init__(self, username: str, **kwargs) -> None:
        self.username: str = username
        self.project_name: str = kwargs.get("project_name", None)
        self.collection: str = kwargs.get("collection_name", None)
        self.user_arango: ArangoDB = self.get_arango()

    def get_arango(self, admin: bool = False, db_name: str = None) -> ArangoDB:
        if db_name:
            return ArangoDB(db_name=db_name)
        elif admin:
            return ArangoDB()
        else:
            return ArangoDB(db_name=st.session_state["username"])

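    # The three connection modes at a glance (a sketch; assumes a logged-in
    # Streamlit session, i.e. st.session_state["username"] is set):
    #
    #   base = BaseClass("alice")
    #   base.get_arango(db_name="base")   # a shared database by explicit name
    #   base.get_arango(admin=True)       # ARANGO_DB from the environment
    #   base.get_arango()                 # the per-user database ("alice")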
    def get_settings(self):
        settings = self.user_arango.db.document("settings/settings")
        if not settings:
            # No settings document yet: create a default one and use it.
            settings = {"_key": "settings", "current_collection": None, "current_page": None}
            self.user_arango.db.collection("settings").insert(settings)
        for i in ["current_collection", "current_page"]:
            if i not in settings:
                settings[i] = None
        st.session_state["settings"] = settings
        return settings

    def update_settings(self, key, value) -> None:
        self.user_arango.db.collection("settings").update_match(
            filters={"_key": "settings"},
            body={key: value},
            merge=True,
        )
        st.session_state["settings"][key] = value

    def get_article_collections(self) -> list:
        article_collections = self.user_arango.db.aql.execute(
            'FOR doc IN article_collections RETURN doc["name"]'
        )
        return list(article_collections)

    def get_projects(self) -> list:
        projects = self.user_arango.db.aql.execute(
            'FOR doc IN projects RETURN doc["name"]'
        )
        return list(projects)

    def choose_collection(self, text="Select a collection of favorite articles") -> str:
        collections = self.get_article_collections()
        collection = st.selectbox(text, collections, index=None)
        if collection:
            self.project = None
            self.collection = collection
            self.update_settings("current_collection", collection)
            self.update_session_state()
        return collection

    def choose_project(self, text="Select a project") -> "Project":
        projects = self.get_projects()

        project = st.selectbox(text, projects, index=None)
        if project:
            from _classes import Project

            self.project = Project(self.username, project, self.user_arango)
            self.collection = None
            self.update_settings("current_project", self.project.name)
            self.update_session_state()
            return self.project

    def get_chromadb(self):
        return ChromaDB()

    def get_project(self, project_name: str):
        doc = self.user_arango.db.aql.execute(
            f'FOR doc IN projects FILTER doc["name"] == "{project_name}" RETURN doc',
            count=True,
        )
        return next(doc, None)

    def update_current_page(self, page_name):
        if st.session_state.get("current_page") != page_name:
            st.session_state["current_page"] = page_name
            self.update_settings("current_page", page_name)

    def update_session_state(self, page_name=None):
        """
        Updates the Streamlit session state with the attributes of the current instance.

        Parameters:
            page_name (str, optional): The name of the page to update in the session state.
                If not provided, it defaults to the current page stored in the session state.

        The method iterates over the instance's attributes and updates the session state
        for the given page name with those attributes that are of type str, int, float,
        list, dict, or bool.
        """

        if not page_name:
            page_name = st.session_state.get("current_page")
        for attr, value in self.__dict__.items():
            if isinstance(value, (str, int, float, list, dict, bool)):
                st.session_state[page_name][attr] = value

    def set_filename(self, filename=None, folder="other_documents"):
        """
        Checks if a file with the given filename already exists in the download folder.
        If it does, appends or increments a numerical suffix to the filename until a
        unique name is found.

        Args:
            filename (str): The base name of the file to check.

        Sets:
            self.file_path (str): The unique file path generated.
        """

        download_folder = f"user_data/{self.username}/{self.document_type}"
        if self.is_sci and not self.document_type == "notes":
            self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_")
            return os.path.exists(self.file_path)
        else:
            file_path = f"{download_folder}/{filename}"
            while os.path.exists(file_path + ".pdf"):
                if not re.search(r"(_\d+)$", file_path):
                    file_path += "_1"
                else:
                    file_path = re.sub(
                        r"(\d+)$", lambda x: str(int(x.group()) + 1), file_path
                    )
            self.file_path = file_path + ".pdf"
            return file_path
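
    # How the suffix loop behaves (a sketch; the paths are hypothetical):
    #   "report"    -> "report_1.pdf"   if "report.pdf" already exists
    #   "report_1"  -> "report_2.pdf"   if "report_1.pdf" also exists
    # i.e. the first collision appends "_1", later collisions increment the
    # trailing number until the path is unique.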

@@ -0,0 +1,135 @@
# _chromadb.py
import os

import chromadb
from chromadb.config import Settings
from dotenv import load_dotenv

from colorprinter.print_color import *

load_dotenv(".env")


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )

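    # Connection modes at a glance (a sketch; the remote mode assumes CHROMA_HOST,
    # CHROMA_CLIENT_AUTH_CREDENTIALS and CHROMA_AUTH_TOKEN_TRANSPORT_HEADER are
    # set in .env):
    #
    #   ChromaDB(local_deployment=True)   # on-disk client in ./chroma_sci_articles
    #   ChromaDB()                        # token-authenticated HTTP client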
    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results
        col = self.db.get_collection(collection)
        sources = []
        n = 0

        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        while True:
            n += 1
            if n > max_retries:
                break
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break

            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])

            # Order result by distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))
            # Recompute the unique source documents seen so far (assigning, not
            # appending, so earlier iterations are not double-counted).
            sources = list(set([i["_id"] for i in result["metadatas"][0]]))
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                for k, v in result.items():
                    if k not in r["included"]:
                        continue
                    result[k][0] = v[0][: n_results - (n_sources - len(sources))]
                if where and "_id" in where:
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
        return result
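
    # A sketch of restricting a query to specific articles via the where filter
    # (Chroma's $in operator; the _id values are hypothetical):
    #
    #   chroma = ChromaDB()
    #   chroma.query(
    #       query="carbon capture",
    #       collection="sci_articles",
    #       where={"_id": {"$in": ["sci_articles/10.1000_a", "sci_articles/10.1000_b"]}},
    #   )
    # query() shrinks this $in list itself as sources are consumed between retries.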


if __name__ == "__main__":
    from colorprinter.print_color import *

    chroma = ChromaDB()

    result = chroma.query(
        query="What is Open Science?",
        collection="sci_articles",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )
    print(result)
    exit()

    # One-off maintenance below (unreachable unless the exit() above is removed):
    # backfill an ArangoDB-style "_id" onto every stored metadata record.
    chroma_collection = chroma.db.get_collection("sci_articles")
    all_records = chroma_collection.get()

    ids = all_records.get("ids", [])
    metadatas = all_records.get("metadatas", [])

    combined_list = list(zip(ids, metadatas))

    ids = []
    metadatas = []
    for id, metadata in combined_list:
        ids.append(id)
        metadata["_id"] = f"sci_articles/{metadata['_key']}"
        metadatas.append(metadata)
    chroma_collection.update(ids=ids, metadatas=metadatas)

@@ -0,0 +1,929 @@
# streamlit_pages.py

import base64
import re
from datetime import datetime, timedelta
from io import BytesIO
from time import sleep

import pandas as pd
import streamlit as st
from PIL import Image

from colorprinter.print_color import *
from article2db import PDFProcessor
from streamlit_chatbot import Chat, EditorBot, ResearchAssistantBot, PodBot

from info import country_emojis
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import BaseClass

from prompts import get_note_summary_prompt, get_image_system_prompt


class ArticleCollectionsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)
        self.collection = None
        self.page_name = "Article Collections"

        # Initialize attributes from session state if available
        for k, v in st.session_state.get(self.page_name, {}).items():
            setattr(self, k, v)

    def run(self):
        if self.user_arango.db.collection("article_collections").count() == 0:
            self.create_new_collection()

        self.update_current_page(self.page_name)

        self.choose_collection_method()
        self.choose_project_method()

        if self.collection:
            self.display_collection()
            self.sidebar_actions()

        if st.session_state.get("new_collection"):
            self.create_new_collection()

        # Persist state to session_state
        self.update_session_state(page_name=self.page_name)

    def choose_collection_method(self):
        self.collection = self.choose_collection()
        # Persist state after choosing collection
        self.update_session_state(page_name=self.page_name)

    def choose_project_method(self):
        # If you have a project selection similar to collection, implement here
        pass  # Placeholder for project-related logic

    def choose_collection(self):
        collections = self.get_article_collections()
        current_collection = self.collection
        preselected = (
            collections.index(current_collection)
            if current_collection in collections
            else None
        )
        with st.sidebar:
            collection = st.selectbox(
                "Select a collection of favorite articles",
                collections,
                index=preselected,
            )
        if collection:
            self.collection = collection
            self.update_settings("current_collection", collection)
        return self.collection

    def create_new_collection(self):
        with st.form("create_collection_form", clear_on_submit=True):
            new_collection_name = st.text_input("Enter the name of the new collection")
            submitted = st.form_submit_button("Create Collection")
            if submitted:
                if new_collection_name:
                    self.user_arango.db.collection("article_collections").insert(
                        {"name": new_collection_name, "articles": []}
                    )
                    st.success(f'New collection "{new_collection_name}" created')
                    self.collection = new_collection_name
                    self.update_settings("current_collection", new_collection_name)
                    # Persist state after creating a new collection
                    self.update_session_state(page_name=self.page_name)
                    sleep(1)
                    st.rerun()

    def display_collection(self):
        with st.sidebar:
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Create new collection"):
                    st.session_state["new_collection"] = True
            with col2:
                if st.button(f'Remove collection "{self.collection}"'):
                    self.user_arango.db.collection("article_collections").delete_match(
                        {"name": self.collection}
                    )
                    st.success(f'Collection "{self.collection}" removed')
                    self.collection = None
                    self.update_settings("current_collection", None)
                    # Persist state after removing a collection
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()

        self.show_articles_in_collection()

    def show_articles_in_collection(self):
        collection_articles_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{self.collection}" RETURN doc["articles"]'
        )
        collection_articles = list(collection_articles_cursor)
        if collection_articles and collection_articles[0]:
            articles = collection_articles[0]
            st.markdown(f"#### Articles in *{self.collection}*:")

            for article in articles:
                if article is None:
                    continue
                metadata = article.get("metadata")
                if metadata is None:
                    continue

                title = metadata.get("title", "No Title").strip()
                journal = metadata.get("journal", "No Journal").strip()
                published_year = metadata.get("published_year", "No Year")
                published_date = metadata.get("published_date", None)
                language = metadata.get("language", "No Language")
                icon = country_emojis.get(language.upper(), "") if language else ""

                expander_title = f"**{title}** *{journal}* ({published_year}) {icon}"
                with st.expander(expander_title):
                    if not title == "No Title":
                        st.markdown(f"**Title:**  \n{title}")
                    if not journal == "No Journal":
                        st.markdown(f"**Journal:**  \n{journal}")

                    if published_date:
                        st.markdown(f"**Published Date:**  \n{published_date}")
                    for key, value in article.items():
                        if key in [
                            "_key",
                            "text",
                            "file",
                            "_rev",
                            "chunks",
                            "user_access",
                            "_id",
                            "metadata",
                            "doi",
                        ]:
                            continue
                        if isinstance(value, list):
                            value = ", ".join(value)
                        st.markdown(f"**{key.capitalize()}**:  \n{value} ")
                    if "doi" in article:
                        st.markdown(
                            f"**DOI:**  \n[{article['doi']}](https://doi.org/{article['doi']}) "
                        )

                    st.button(
                        key=f'delete_{article["_id"]}',
                        label="Delete article from collection",
                        on_click=self.delete_article,
                        args=(self.collection, article["_id"]),
                    )
        else:
            st.write("No articles in this collection.")

    def sidebar_actions(self):
        with st.sidebar:
            st.markdown(f"### Add new articles to {self.collection}")
            with st.form("add_articles_form", clear_on_submit=True):
                pdf_files = st.file_uploader(
                    "Upload PDF file(s)", type=["pdf"], accept_multiple_files=True
                )
                is_sci = st.checkbox("All articles are from scientific journals")
                submitted = st.form_submit_button("Upload")
                if submitted and pdf_files:
                    self.add_articles(pdf_files, is_sci)
                    # Persist state after adding articles
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()

            help_text = 'Paste a text containing DOIs, e.g., the reference section of a paper, and click "Add Articles" to add them to the collection.'
            new_articles = st.text_area(
                "Add articles to this collection", help=help_text
            )
            if st.button("Add Articles"):
                with st.spinner("Processing..."):
                    self.process_dois(
                        article_collection_name=self.collection, text=new_articles
                    )
                    # Persist state after processing DOIs
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()

            self.write_not_downloaded()

    def add_articles(self, pdf_files: list, is_sci: bool) -> None:
        for pdf_file in pdf_files:
            status_container = st.empty()
            with status_container:
                is_sci = is_sci if is_sci else None
                with st.status(f"Processing {pdf_file.name}..."):
                    processor = PDFProcessor(
                        pdf_file=pdf_file,
                        filename=pdf_file.name,
                        process=False,
                        username=st.session_state["username"],
                        document_type="other_documents",
                        is_sci=is_sci,
                    )
                    _id, db, doi = processor.process_document()
                    print_rainbow(_id, db, doi)
                    if doi in st.session_state.get("not_downloaded", {}):
                        st.session_state["not_downloaded"].pop(doi)
                    self.articles2collection(collection=self.collection, db=db, _id=_id)
                st.success("Done!")
                sleep(1.5)

    def articles2collection(self, collection: str, db: str, _id: str = None) -> None:
        info = self.get_article_info(db, _id=_id)
        info = {
            k: v for k, v in info.items() if k in ["_id", "doi", "title", "metadata"]
        }
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = doc.get("articles", [])
            keys = [i["_id"] for i in articles]
            if info["_id"] not in keys:
                articles.append(info)
                self.user_arango.db.collection("article_collections").update_match(
                    filters={"name": collection},
                    body={"articles": articles},
                    merge=True,
                )
        # Persist state after updating articles
        self.update_session_state(page_name=self.page_name)

    def get_article_info(self, db: str, _id: str = None, doi: str = None) -> dict:
        assert _id or doi, "Either _id or doi must be provided."
        arango = self.get_arango(db_name=db)
        if _id:
            query = """
            RETURN {
                "_id": DOCUMENT(@doc_id)._id,
                "doi": DOCUMENT(@doc_id).doi,
                "title": DOCUMENT(@doc_id).title,
                "metadata": DOCUMENT(@doc_id).metadata,
                "summary": DOCUMENT(@doc_id).summary
            }
            """

            info_cursor = arango.db.aql.execute(query, bind_vars={"doc_id": _id})
        elif doi:
            info_cursor = arango.db.aql.execute(
                f'FOR doc IN sci_articles FILTER doc["doi"] == "{doi}" LIMIT 1 RETURN {{"_id": doc["_id"], "doi": doc["doi"], "title": doc["title"], "metadata": doc["metadata"], "summary": doc["summary"]}}'
            )
        return next(info_cursor, None)

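    # The _id branch above already uses bind variables; the doi branch interpolates
    # into the AQL string. A sketch of the equivalent bound form (projection
    # trimmed to RETURN doc for brevity):
    #
    #   arango.db.aql.execute(
    #       "FOR doc IN sci_articles FILTER doc.doi == @doi LIMIT 1 RETURN doc",
    #       bind_vars={"doi": doi},
    #   )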
    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url

            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        # Persist state after processing DOIs
        self.update_session_state(page_name=self.page_name)

    def write_not_downloaded(self):
        not_downloaded = st.session_state.get("not_downloaded", {})
        if not_downloaded:
            st.markdown(
                "*The articles below were not downloaded. Download them yourself and add them to the collection by dropping them in the area above. Some of them can be downloaded using the link.*"
            )
            for doi, url in not_downloaded.items():
                if url:
                    st.markdown(f"- [{doi}]({url})")
                else:
                    st.markdown(f"- {doi}")

    def delete_article(self, collection, _id):
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = [
                article for article in doc.get("articles", []) if article["_id"] != _id
            ]
            self.user_arango.db.collection("article_collections").update_match(
                filters={"_id": doc["_id"]},
                body={"articles": articles},
            )
        # Persist state after deleting an article
        self.update_session_state(page_name=self.page_name)


class BotChatPage(BaseClass):
    def __init__(self, username):
        super().__init__(username=username)
        self.collection = None
        self.project_name = None
        self.project = None
        self.chat = None
        self.role = "Research Assistant"  # Default persona
        self.page_name = "Bot Chat"

        # Initialize attributes from session state if available
        if self.page_name in st.session_state:
            for k, v in st.session_state[self.page_name].items():
                setattr(self, k, v)

    def run(self):
        bot = None
        self.update_current_page("Bot Chat")
        self.remove_old_unsaved_chats()
        self.sidebar_actions()

        if self.collection or self.project:
            # If no chat exists, create a new Chat instance
            if not self.chat:
                self.chat = Chat(username=self.username, role=self.role)
            self.chat.show_chat_history()
            # Create a Bot instance with the Chat object
            if self.role == "Research Assistant":
                bot = ResearchAssistantBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                )
            elif self.role == "Editor":
                bot = EditorBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                )
            elif self.role == "Podcast":
                st.session_state["make_podcast"] = True
                with st.sidebar:
                    with st.form("make_podcast_form"):
                        instructions = st.text_area(
                            "What should the podcast be about? Give a brief description, as if you were the producer."
                        )
                        start = st.form_submit_button("Make Podcast!")
                        if start:
                            bot = PodBot(
                                subject=self.project.name,
                                username=self.username,
                                chat=self.chat,
                                collection=self.collection,
                                project=self.project,
                                instructions=instructions,
                            )

        # Run the bot (this will display chat history and process user input)
        if bot:
            bot.run()

        # Save updated chat state to session state
        st.session_state[self.page_name] = {
            "collection": self.collection,
            "project": self.project,
            "chat": self.chat,
            "role": self.role,
        }

    def sidebar_actions(self):
        with st.sidebar:
            self.collection = self.choose_collection(
                "Article collection to use for chat:"
            )
            self.choose_project("Project to use for chat:")

            if self.collection or self.project:
                st.write("---")
                if self.project:
                    self.role = st.selectbox(
                        "Choose Bot Role",
                        options=["Research Assistant", "Editor", "Podcast"],
                        index=0,
                    )
                elif self.collection:
                    self.role = "Research Assistant"

                # Load existing chats from the database
                if self.project:
                    chat_history = list(
                        self.user_arango.db.aql.execute(
                            f'FOR doc IN chats FILTER doc["project"] == "{self.project.name}" RETURN {{"_key": doc["_key"], "name": doc["name"]}}'
                        )
                    )
                elif self.collection:
                    chat_history = list(
                        self.user_arango.db.aql.execute(
                            f'FOR doc IN chats FILTER doc["collection"] == "{self.collection}" RETURN {{"_key": doc["_key"], "name": doc["name"]}}'
                        )
                    )

                chats = {i["name"]: i["_key"] for i in chat_history}
                selected_chat = st.selectbox(
                    "Continue another chat", options=[""] + list(chats.keys()), index=0
                )
                if selected_chat:
                    chat_data = self.user_arango.db.collection("chats").get(
                        chats[selected_chat]
                    )
                    chat_dict = chat_data.get("chat_data", {})
                    self.chat = Chat.from_dict(chat_dict)
                else:
                    self.chat = None

    def remove_old_unsaved_chats(self):
        two_weeks_ago = datetime.now() - timedelta(weeks=2)
        old_chats = self.user_arango.db.aql.execute(
            f'FOR doc IN chats FILTER doc.saved == false AND doc.last_updated < "{two_weeks_ago.isoformat()}" RETURN doc'
        )
        for chat in old_chats:
            self.user_arango.db.collection("chats").delete(chat["_key"])


class ProjectsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)
        self.projects = []
        self.selected_project_name = None
        self.project = None
        self.page_name = "Projects"

        # Initialize attributes from session state if available
        page_state = st.session_state.get(self.page_name, {})
        for k, v in page_state.items():
            setattr(self, k, v)

    def run(self):
        self.update_current_page(self.page_name)
        self.load_projects()
        self.display_projects()
        # Update session state
        self.update_session_state(self.page_name)

    def load_projects(self):
        projects_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects RETURN doc", count=True
        )
        self.projects = list(projects_cursor)

    def display_projects(self):
        with st.sidebar:
            self.new_project_button()
            self.selected_project_name = st.selectbox(
                "Select a project to manage",
                options=[proj["name"] for proj in self.projects],
            )
        if self.selected_project_name:
            self.project = Project(
                username=self.username,
                project_name=self.selected_project_name,
                user_arango=self.user_arango,
            )
            self.manage_project()
        # Update session state
        self.update_session_state(self.page_name)

    def new_project_button(self):
        st.session_state.setdefault("new_project", False)
        with st.sidebar:
            if st.button("New project", type="primary"):
                st.session_state["new_project"] = True
        if st.session_state["new_project"]:
            self.create_new_project()
        # Update session state
        self.update_session_state(self.page_name)

    def create_new_project(self):
        new_project_name = st.text_input("Enter the name of the new project")
        new_project_description = st.text_area(
            "Enter the description of the new project"
        )
        if st.button("Create Project"):
            if new_project_name:
                self.user_arango.db.collection("projects").insert(
                    {
                        "name": new_project_name,
                        "description": new_project_description,
                        "collections": [],
                        "notes": [],
                        "note_keys_hash": hash(""),
                        "settings": {},
                    }
                )
                st.success(f'New project "{new_project_name}" created')
                st.session_state["new_project"] = False
                self.update_settings("current_project", new_project_name)
                sleep(1)
                st.rerun()

    def show_project_notes(self):
        with st.expander("Show summarised notes"):
            st.markdown(self.project.notes_summary)

        with st.expander("Show project notes"):
            notes_cursor = self.user_arango.db.aql.execute(
                "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc",
                bind_vars={"note_ids": self.project.notes},
            )
            notes = list(notes_cursor)
            if notes:
                for note in notes:
                    st.markdown(f'_{note.get("timestamp", "")}_')
                    st.markdown(note["text"].replace("\n", "  \n"))
                    st.button(
                        key=f'delete_note_{note["_id"]}',
                        label="Delete note",
                        on_click=self.project.delete_note,
                        args=(note["_id"],),
                    )
                    st.write("---")
            else:
                st.write("No notes in this project.")

    def manage_project(self):
        self.update_settings("current_project", self.selected_project_name)
        # Initialize the Project instance
        self.project = Project(
            self.username, self.selected_project_name, self.user_arango
        )
        st.write(f"## {self.project.name}")
        self.show_project_notes()
        self.relate_collections()
        self.sidebar_actions()
        self.project.update_notes_hash()
        if st.button(f"Remove project *{self.project.name}*"):
            self.user_arango.db.collection("projects").delete_match(
                {"name": self.project.name}
            )
            self.update_settings("current_project", None)
            st.success(f'Project "{self.project.name}" removed')
            st.rerun()
        # Update session state
        self.update_session_state(self.page_name)

    def relate_collections(self):
        collections = [
            col["name"]
            for col in self.user_arango.db.collection("article_collections").all()
        ]
        selected_collections = st.multiselect(
            "Relate existing collections", options=collections
        )
        if st.button("Relate Collections"):
            self.project.add_collections(selected_collections)
            st.success("Collections related to the project")
            # Update session state
            self.update_session_state(self.page_name)

        new_collection_name = st.text_input(
            "Enter the name of the new collection to create and relate"
        )
        if st.button("Create and Relate Collection"):
            if new_collection_name:
                self.user_arango.db.collection("article_collections").insert(
                    {"name": new_collection_name, "articles": []}
                )
                self.project.add_collection(new_collection_name)
                st.success(
                    f'New collection "{new_collection_name}" created and related to the project'
                )
                # Update session state
                self.update_session_state(self.page_name)

    def sidebar_actions(self):
        self.sidebar_notes()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_notes(self):
        with st.sidebar:
            st.markdown(f"### Add new notes to {self.project.name}")
            self.upload_notes_form()
            self.add_text_note()
            self.add_wikipedia_data()
        # Update session state
        self.update_session_state(self.page_name)

    def upload_notes_form(self):
        with st.form("add_notes", clear_on_submit=True):
            files = st.file_uploader(
                "Upload PDF or image",
                type=["png", "jpg", "pdf"],
                accept_multiple_files=True,
            )
            submitted = st.form_submit_button("Upload")
            if submitted:
                self.project.process_uploaded_files(files)
        # Update session state
        self.update_session_state(self.page_name)

    def add_text_note(self):
        help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
        note_text = st.text_area("Write or paste anything.", help=help_text)
        if st.button("Add Note"):
            self.project.add_note(
                {
                    "text": note_text,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                }
            )
            st.success("Note added to the project")
        # Update session state
        self.update_session_state(self.page_name)

    def add_wikipedia_data(self):
        wiki_url = st.text_input(
            "Paste the address to a Wikipedia page to add its summary as a note",
            placeholder="Paste Wikipedia URL",
        )
        if st.button("Add Wikipedia data"):
            with st.spinner("Fetching Wikipedia data..."):
                wiki_data = self.project.get_wikipedia_data(wiki_url)
                if wiki_data:
                    self.project.process_wikipedia_data(wiki_data, wiki_url)
                    st.success("Wikipedia data added to notes")
                    # Update session state
                    self.update_session_state(self.page_name)
                    st.rerun()


class Project(BaseClass):
    def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
        super().__init__(username=username)
        self.name = project_name
        self.user_arango = user_arango
        self.description = ""
        self.collections = []
        self.notes = []
        self.note_keys_hash = 0
        self.settings = {}
        self.notes_summary = ""

        # Initialize attributes from the ArangoDB document if available
        self.load_project()

    def load_project(self):
        print_blue("Project name:", self.name)
        print(self.user_arango, type(self.user_arango))
        project_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects FILTER doc.name == @name RETURN doc",
            bind_vars={"name": self.name},
        )
        project = next(project_cursor, None)
        if not project:
            raise ValueError(f"Project '{self.name}' not found.")
        self._key = project["_key"]
        self.name = project.get("name", "")
        self.description = project.get("description", "")
        self.collections = project.get("collections", [])
        self.notes = project.get("notes", [])
        self.note_keys_hash = project.get("note_keys_hash", 0)
        self.settings = project.get("settings", {})
        self.notes_summary = project.get("notes_summary", "")

    def update_project(self):
        updated_doc = {
            "_key": self._key,
            "name": self.name,
            "description": self.description,
            "collections": self.collections,
            "notes": self.notes,
            "note_keys_hash": self.note_keys_hash,
            "settings": self.settings,
            "notes_summary": self.notes_summary,
        }
        self.user_arango.db.collection("projects").update(updated_doc, check_rev=False)
        self.update_session_state()

    def add_collections(self, collections):
        self.collections.extend(collections)
        self.update_project()

    def add_collection(self, collection_name):
        self.collections.append(collection_name)
        self.update_project()

    def add_note(self, note: dict):
        assert note["text"], "Note text cannot be empty"
        note["text"] = note["text"].strip().strip("\n")
        if "timestamp" not in note:
            note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M")
        note_doc = self.user_arango.db.collection("notes").insert(note)
        if note_doc["_id"] not in self.notes:
            self.notes.append(note_doc["_id"])
        self.update_project()

    def delete_note(self, note_id):
        if note_id in self.notes:
            self.notes.remove(note_id)
        self.update_project()

    def update_notes_hash(self):
        current_hash = self.make_project_notes_hash()
        if current_hash != self.note_keys_hash:
            self.note_keys_hash = current_hash
            with st.spinner("Summarizing notes for chatbot..."):
                self.create_notes_summary()
            self.update_project()

    def make_project_notes_hash(self):
        if not self.notes:
            return hash("")
        note_keys_str = "".join(self.notes)
        return hash(note_keys_str)

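    # Note: Python's built-in hash() for strings is randomized per process
    # (PYTHONHASHSEED), so a note_keys_hash stored in ArangoDB will not match
    # across restarts, and the summary may be regenerated more often than
    # strictly needed. A stable alternative, if that matters, is a digest
    # (a sketch):
    #
    #   import hashlib
    #   int(hashlib.sha256("".join(self.notes).encode()).hexdigest(), 16)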
    def create_notes_summary(self):
        notes_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text",
            bind_vars={"note_ids": self.notes},
        )
        notes = list(notes_cursor)
        notes_string = "\n---\n".join(notes)
        llm = LLM(model="small")
        query = get_note_summary_prompt(self, notes_string)
        summary = llm.generate(query)
        print_purple("New summary of notes:", summary)
        self.notes_summary = summary
        self.update_session_state()

    def analyze_image(self, image_base64, text=None):
        project_data = {"name": self.name}
        llm = LLM(system_message=get_image_system_prompt(project_data))
        prompt = (
            f'Analyze the image. The text found in it read: "{text}"'
            if text
            else "Analyze the image."
        )
        description = llm.generate(query=prompt, images=[image_base64], stream=False)
        print_green("Image description:", description)
        return description

    def process_uploaded_files(self, files):
        with st.spinner("Processing files..."):
            for file in files:
                st.write("Processing...")
                filename = fix_key(file.name)

                image_file = self.file2img(file)
                pdf_file = self.convert_image_to_pdf(image_file)
                pdf = PDFProcessor(
                    pdf_file=pdf_file,
                    is_sci=False,
                    document_type="notes",
                    is_image=True,
                    process=False,
                )
                text = pdf.process_pdf()
                # file2img consumed the upload stream, so rewind before re-reading,
                # and decode the base64 bytes to str for the vision model.
                file.seek(0)
                base64_str = base64.b64encode(file.read()).decode("utf-8")
                image_caption = self.analyze_image(base64_str, text=text)
                self.add_note(
                    {
                        "_id": f"notes/{filename}",
                        "text": f"## Image caption:  \n{image_caption}  \n#### Text extracted from image:  \n{text}",
                    }
                )
            st.success("Done!")
            sleep(1.5)
            self.update_session_state()
            st.rerun()

    def file2img(self, file):
        img_bytes = file.read()
        if not img_bytes:
            raise ValueError("Uploaded file is empty.")
        return Image.open(BytesIO(img_bytes))

    def convert_image_to_pdf(self, img):
        import pytesseract

        pdf_bytes = pytesseract.image_to_pdf_or_hocr(img)
        pdf_file = BytesIO(pdf_bytes)
        pdf_file.name = (
            "converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf"
        )
        return pdf_file

    def get_wikipedia_data(self, page_url: str) -> dict:
        import wikipedia
        from urllib.parse import urlparse

        parsed_url = urlparse(page_url)
        page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path)
        if page_name_match:
            page_name = page_name_match.group(0)
        else:
            st.warning("Invalid Wikipedia URL")
            return None

        try:
            page = wikipedia.page(page_name)
            data = {
                "title": page.title,
                "summary": page.summary,
                "content": page.content,
                "url": page.url,
                "references": page.references,
            }
            return data
        except Exception as e:
            st.error(f"Error fetching Wikipedia data: {e}")
            return None

    def process_wikipedia_data(self, wiki_data, wiki_url):
        llm = LLM(
            system_message="You are an assistant summarising Wikipedia data. Answer ONLY with the summary, nothing else!",
            model="small",
        )
        if wiki_data.get("summary"):
            query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.'''
            summary = llm.generate(query)
            wiki_data["text"] = (
                f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}"
            )
            wiki_data.pop("summary", None)
            wiki_data.pop("content", None)
            self.user_arango.db.collection("notes").insert(
                wiki_data, overwrite=True, silent=True
            )
            self.add_note(wiki_data)

        processor = PDFProcessor(process=False)
        dois = [
            processor.extract_doi(ref)
            for ref in wiki_data.get("references", [])
            if processor.extract_doi(ref)
        ]
        if dois:
            current_collection = st.session_state["settings"].get("current_collection")
            st.markdown(
                f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
            )
            if st.button("Add DOIs"):
                self.process_dois(current_collection, dois=dois)
        self.update_session_state()

    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url

            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        self.update_session_state()


class SettingsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)

    def run(self):
        self.update_current_page("Settings")
        self.set_profile_picture()

    def set_profile_picture(self):
        st.markdown("Profile picture")
        profile_picture = st.file_uploader(
            "Upload profile picture", type=["png", "jpg", "jpeg"]
        )
        if profile_picture:
            # Resize the image to 64x64 pixels
            from PIL import Image

            img = Image.open(profile_picture)
            img.thumbnail((64, 64))
            img_path = f"user_data/{st.session_state['username']}/profile_picture.png"
            img.save(img_path)
            self.update_settings("avatar", img_path)
            st.success("Profile picture uploaded")
            sleep(1)

@@ -0,0 +1,253 @@
# _llm.py
import json
import os
import re
from typing import Literal, Optional

import requests
from requests.auth import HTTPBasicAuth
import tiktoken

from colorprinter.print_color import *
import env_manager

env_manager.set_env()

tokenizer = tiktoken.get_encoding("cl100k_base")


class LLM:
    def __init__(
        self,
        system_message="You are an assistant.",
        num_ctx=8192,
        temperature=0.01,
        model: Optional[Literal["small", "standard", "vision"]] = "standard",
        max_length_answer=4096,
        messages=None,
        chat=True,
        chosen_backend=None,
    ) -> None:
        """
        Initialize the assistant with specified parameters.

        Args:
            system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
            num_ctx (int): The number of context tokens to use. Defaults to 8192.
            temperature (float): The temperature setting for the model's response generation. Defaults to 0.01.
            model (str): The model type to use. Defaults to "standard". Alternatives: 'small', 'standard', 'vision'.
            max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
            messages (list): An existing message history to continue from. Defaults to None.
            chat (bool): Flag to indicate if the assistant is in chat mode. Defaults to True.
            chosen_backend (str): The chosen ollama server for the request. Defaults to None.

        Returns:
            None
        """
        self.model = self.get_model(model)
        self.system_message = system_message
        self.options = {"temperature": temperature, "num_ctx": num_ctx}
        self.messages = messages or [{"role": "system", "content": self.system_message}]
        self.max_length_answer = max_length_answer
        self.chat = chat
        self.chosen_backend = chosen_backend

    def get_model(self, model_alias):
        models = {
            "standard": "LLM_MODEL",
            "small": "LLM_MODEL_SMALL",
            "vision": "LLM_MODEL_VISION",
            "standard_64k": "LLM_MODEL_64K",
        }
        return os.getenv(models.get(model_alias, "LLM_MODEL"))

    def count_tokens(self):
        num_tokens = 0
        for i in self.messages:
            for k, v in i.items():
                if k == "content":
                    if not isinstance(v, str):
                        v = str(v)
                    tokens = tokenizer.encode(v)
                    num_tokens += len(tokens)
        return int(num_tokens)

    def read_stream(self, response):
        """
        Reads a stream of data from the given response object and yields the content of each message.

        Args:
            response (requests.Response): The response object to read the stream from.

        Yields:
            str: The content of each message in the stream.

        Notes:
            - The response is expected to provide data in chunks, which are decoded as UTF-8.
            - Lines are split by newline characters.
            - Each line is expected to be a JSON object containing a "message" key with a "content" field.
            - If a chunk cannot be decoded as UTF-8, it is skipped.
            - If a line cannot be parsed as JSON, it is skipped.
        """
        buffer = ""
        message = ""
        for chunk in response.iter_content(chunk_size=64):
            if chunk:
                try:
                    buffer += chunk.decode("utf-8")
                except UnicodeDecodeError:
                    continue
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    if line:
                        try:
                            json_data = json.loads(line)
                            content = json_data["message"]["content"]
                            # Accumulate the parsed content (not the raw JSON lines)
                            # so the stored assistant message matches what is yielded.
                            message += content
                            yield content
                        except (json.JSONDecodeError, KeyError):
                            continue
        self.messages.append({"role": "assistant", "content": message.strip('"')})

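    # Consuming the generator (a sketch; generate(..., stream=True) below returns
    # this generator when streaming is enabled):
    #
    #   llm = LLM()
    #   for token in llm.generate("Tell me a joke", stream=True):
    #       print(token, end="", flush=True)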
    def generate(
        self,
        query,
        stream=False,
        tools=None,
        function_call=None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature=None,
    ):
        """
        Generates a response from the language model based on the provided query and options.

        Args:
            query (str): The input query to be processed by the language model.
            stream (bool, optional): Whether to stream the response. Defaults to False.
            tools (list, optional): A list of tools to be used by the language model. Defaults to None.
            function_call (dict, optional): A dictionary specifying a function call to be made by the language model. Defaults to None.
            images (list, optional): A list of image paths or base64-encoded images to be included in the request. Defaults to None.
            model (str, optional): The model alias to be used for generating the response. Defaults to None. Alternatives: 'small', 'standard', 'vision'.
            temperature (float, optional): Overrides the temperature set at initialization. Defaults to None.

        Returns:
            str: The generated response from the language model. If streaming is enabled, returns the streamed response.
        """
        model = self.get_model(model) if model else self.model
        if temperature:
            self.options["temperature"] = temperature

        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        headers = {"Content-Type": "application/json"}

        if images:
            import base64

            # Use the vision model whenever images are attached
            model = self.get_model("vision")
            # Convert image paths to base64
            base64_images = []
            base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")

            # Handle images
            for image in images:
                if isinstance(image, str):
                    if base64_pattern.match(image):
                        base64_images.append(image)
                    else:
                        with open(image, "rb") as image_file:
                            base64_images.append(
                                base64.b64encode(image_file.read()).decode("utf-8")
                            )
            message["images"] = base64_images
            # Set the Content-Type header based on the presence of images
            headers = {"Content-Type": "application/json; images"}

        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend

        self.messages.append(message)

        # Set the number of tokens to be the sum of the tokens in the messages
        # and half of the max length of the answer
        if self.chat or len(self.messages) > 15000:
            num_tokens = self.count_tokens() + self.max_length_answer / 2
            if num_tokens < 8000 and "num_ctx" in self.options:
                del self.options["num_ctx"]
        else:
            # Route long, non-chat requests to the 64k-context model and add the
            # custom header so the backend can pick the right deployment
            model = self.get_model("standard_64k")
            headers["X-Model-Type"] = "standard_64k"

        if tools:
            stream = False

        data = {
            "messages": self.messages,
            "stream": stream,
            "keep_alive": 3600 * 24 * 7,
            "model": model if model else self.model,
            "options": self.options,
        }

        # Include tools if provided
        if tools:
            data["tools"] = tools

        # Include function_call if provided
        if function_call:
            data["function_call"] = function_call

        response = requests.post(
            os.getenv("LLM_API_URL"),
            headers=headers,
            json=data,
            auth=HTTPBasicAuth(
                os.getenv("LLM_API_USER"), os.getenv("LLM_API_PWD_LASSE")
            ),
            stream=stream,
            timeout=3600,
        )

        self.chosen_backend = response.headers.get("X-Chosen-Backend")

        if response.status_code != 200:
            print_red("Error!")
            print_rainbow(response.content.decode("utf-8"))

            if response.status_code == 404:
                return "Target endpoint not found"

            if response.status_code == 504:
                return f"Gateway Timeout: {response.content.decode('utf-8')}"

        if stream:
            return self.read_stream(response)
        else:
            try:
                response_json = response.json()
                if tools and not response_json["message"].get("tool_calls"):
                    print_yellow("No tool calls in response".upper())
                if "tool_calls" in response_json.get("message", {}):
                    # The LLM wants to invoke a function (tool); keep the whole
                    # message (a dict, so it must not be strip()ed like a string)
                    result = response_json["message"]
                    self.messages.append(result)
                else:
                    result = response_json["message"]["content"].strip('"')
                    self.messages.append({"role": "assistant", "content": result})
            except requests.exceptions.JSONDecodeError:
                print_red("Error: ", response.status_code, response.text)
                return "An error occurred."

        if not self.chat:
            self.messages = [self.messages[0]]
        return result
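
    # A sketch of a tool call, assuming the backend is Ollama-compatible (tools
    # use the usual JSON-schema function format; get_weather is hypothetical):
    #
    #   llm = LLM(chat=True)
    #   reply = llm.generate(
    #       "What is the weather in Oslo?",
    #       tools=[{
    #           "type": "function",
    #           "function": {
    #               "name": "get_weather",
    #               "description": "Get the current weather for a city",
    #               "parameters": {
    #                   "type": "object",
    #                   "properties": {"city": {"type": "string"}},
    #                   "required": ["city"],
    #               },
    #           },
    #       }],
    #   )
    #   # reply is the full message dict (with "tool_calls") when a tool is called.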


if __name__ == "__main__":
    llm = LLM()
    images = ["th-2182728540.jpeg"]
    print(
        llm.generate(
            query="Hi there",
        )
    )

@@ -0,0 +1,79 @@
import base64
import os
from time import sleep

import dotenv
from mailersend import emails


class MailSender:
    def __init__(self):
        dotenv.load_dotenv()
        self.mailersend_api_key = os.getenv("MAILERSEND_API_KEY")
        self.mailer = emails.NewEmail(mailersend_api_key=self.mailersend_api_key)
        self.mail_body = {}
        self.set_mail_to(recipients=input("Enter recipient email: "))
        self.set_mail_from()
        self.set_reply_to("No Reply", "noreply@assistant.fish")
        self.set_subject(input("Enter subject: "))
        plaintext_content = input("Enter plaintext content: ")
        self.set_plaintext_content(plaintext_content)
        html_content = input("Enter HTML content: ")
        if html_content == "":
            html_content = plaintext_content
        self.set_html_content(html_content)
        attachment_path = input("Path to image: ")
        if attachment_path != "":
            self.set_attachments(attachment_path)
        # TODO: Add support for multiple attachments and other types of attachments

        self.send_mail()

    def set_mail_from(self):
        mail_from = {"name": "SCI Fish", "email": "sci@assistant.fish"}
        self.mailer.set_mail_from(mail_from, self.mail_body)

    def set_mail_to(self, recipients):
        if isinstance(recipients, str):
            recipients = [{"name": recipients, "email": recipients}]
        elif isinstance(recipients, list):
            recipients = [{"name": i, "email": i} for i in recipients]
        self.mailer.set_mail_to(recipients, self.mail_body)

    def set_subject(self, subject):
        self.mailer.set_subject(subject, self.mail_body)

    def set_html_content(self, html_content):
        self.mailer.set_html_content(html_content, self.mail_body)

    def set_plaintext_content(self, plaintext_content):
        self.mailer.set_plaintext_content(plaintext_content, self.mail_body)

    def set_reply_to(self, name, email):
        reply_to = {"name": name, "email": email}
        self.mailer.set_reply_to(reply_to, self.mail_body)

    def set_attachments(self, file_path):
        with open(file_path, "rb") as attachment:
            att_read = attachment.read()
            att_base64 = base64.b64encode(att_read)
        attachments = [
            {
                "id": os.path.basename(file_path),
                "filename": os.path.basename(file_path),
                "content": f"{att_base64.decode('ascii')}",
                "disposition": "attachment",
            }
        ]
        self.mailer.set_attachments(attachments, self.mail_body)

    def send_mail(self):
        r = self.mailer.send(self.mail_body)
        sleep(4)  # wait for email to be sent
        if r.split("\n")[0].strip() != "202":
            print("Error sending email")
        else:
            print("Email sent successfully")


# Example usage
if __name__ == "__main__":
    mail_sender = MailSender()

||||
@ -0,0 +1,781 @@ |
||||
import io |
||||
import os |
||||
import re |
||||
from time import sleep |
||||
from datetime import datetime |
||||
|
||||
import crossref_commons.retrieval as crossref |
||||
import pymupdf |
||||
import pymupdf4llm |
||||
import requests |
||||
from bs4 import BeautifulSoup |
||||
from pymupdf import Document as PyMuPDFDocument  # aliased: this module defines its own Document class below
||||
from semantic_text_splitter import MarkdownSplitter |
||||
from pyppeteer import launch |
||||
from arango.collection import StandardCollection as ArangoCollection |
||||
from arango.database import StandardDatabase as ArangoDatabase |
||||
import xml.etree.ElementTree as ET |
||||
from streamlit.runtime.uploaded_file_manager import UploadedFile |
||||
import streamlit as st |
||||
|
||||
from _arango import ArangoDB |
||||
from _chromadb import ChromaDB |
||||
from _llm import LLM |
||||
from colorprinter.print_color import * |
||||
from utils import fix_key |
||||
|
||||
|
||||
class Document: |
||||
def __init__( |
||||
self, |
||||
pdf_file=None, |
||||
filename=None, |
||||
doi=None, |
||||
username=None, |
||||
is_sci=None, |
||||
is_image=False, |
||||
): |
||||
self.filename = filename |
||||
self.pdf_file = pdf_file |
||||
self.doi = doi |
||||
self.username = username |
||||
self.is_sci = is_sci |
||||
self.is_image = is_image |
||||
self.pdf = None |
||||
self._key = None |
||||
self._id = None |
||||
|
||||
self.chunks = [] |
||||
self.metadata = None |
||||
self.title = None |
||||
self.open_access = False |
||||
self.file_path = None |
||||
self.download_folder = None |
||||
self.text = "" |
||||
self.arango_db_name = None |
||||
self.document_type = None |
||||
|
||||
if self.pdf_file: |
||||
self.open_pdf(self.pdf_file) |
||||
|
||||
|
||||
|
||||
def make_summary_in_background(self): |
||||
        data = {
            "text": self.text,
            "arango_db_name": self.arango_db_name,
            # The summarisation API's DocumentData model expects "arango_id"
            "arango_id": self._id,
            "is_sci": self.is_sci,
        }
||||
|
||||
# Send the data to the FastAPI server |
||||
url = "http://192.168.1.11:8100/summarise_document" |
||||
requests.post(url, json=data) |
||||
|
||||
def open_pdf(self, pdf_file): |
||||
st.write(f"Reading the file...") |
||||
        if isinstance(pdf_file, bytes):
            pdf_file = io.BytesIO(pdf_file)  # io is already imported at module level

        if isinstance(pdf_file, str):
            self.pdf: PyMuPDFDocument = pymupdf.open(pdf_file)
        elif isinstance(pdf_file, io.BytesIO):
            try:
                self.pdf: PyMuPDFDocument = pymupdf.open(stream=pdf_file, filetype="pdf")
            except Exception:
                # Some uploaded streams only open after being re-read into a fresh buffer
                pdf_stream = io.BytesIO(pdf_file.read())
                self.pdf: PyMuPDFDocument = pymupdf.open(stream=pdf_stream, filetype="pdf")
||||
|
||||
def extract_text(self): |
||||
md_pages = pymupdf4llm.to_markdown( |
||||
self.pdf, page_chunks=True, show_progress=False |
||||
) |
||||
md_text = "" |
||||
for page in md_pages: |
||||
md_text += f"{page['text'].strip()}\n@{page['metadata']['page']}@\n" |
||||
|
||||
md_text = re.sub(r"[-]{3,}", "", md_text) |
||||
md_text = re.sub(r"\n{3,}", "\n\n", md_text) |
||||
md_text = re.sub(r"\s{2,}", " ", md_text) |
||||
md_text = re.sub(r"\s*\n\s*", "\n", md_text) |
||||
|
||||
self.text = md_text |
||||
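    # Illustrative sketch of the marker format produced above: each page's text
    # is followed by "@<page number>@", so a two-page PDF becomes
    #
    #   First page text...
    #   @1@
    #   Second page text...
    #   @2@
    #
    # chunks2chroma() and chunks2arango() later recover the page numbers with
    # the pattern r"@(\d+)@" before stripping the markers out.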
|
||||
def make_chunks(self, len_chunks=2200): |
||||
better_chunks = [] |
||||
|
||||
ts = MarkdownSplitter(len_chunks) |
||||
chunks = ts.chunks(self.text) |
||||
for chunk in chunks: |
||||
if len(chunk) < 40 and len(chunks) > 1: |
||||
continue |
||||
            # Fold a short chunk into the previous kept chunk, but only while
            # that chunk has not already grown too long. Chained `and` is used
            # so the better_chunks[-1] lookup is never evaluated on an empty list.
            elif (
                len(chunk) < int(len_chunks / 3)
                and better_chunks
                and len(better_chunks[-1]) < int(len_chunks * 1.5)
            ):
||||
better_chunks[-1] += chunk |
||||
else: |
||||
better_chunks.append(chunk.strip()) |
||||
|
||||
self.chunks = better_chunks |
||||
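    # A minimal usage sketch (assuming `doc` is a Document whose text has been
    # extracted): MarkdownSplitter caps chunks at roughly `len_chunks` characters
    # while respecting Markdown structure, and the loop above folds very short
    # chunks into their predecessor.
    #
    #   doc.extract_text()
    #   doc.make_chunks(len_chunks=2200)
    #   print(len(doc.chunks), "chunks,", max(len(c) for c in doc.chunks), "chars max")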
|
||||
def get_title(self, only_meta=False): |
||||
""" |
||||
Extracts the title from the PDF metadata or generates a title based on the filename. |
||||
|
||||
Args: |
||||
only_meta (bool): If True, only attempts to retrieve the title from metadata. |
||||
If False, generates a title from the filename if metadata is not available. |
||||
|
||||
Returns: |
||||
str: The title of the PDF if found in metadata or generated from the filename. |
||||
Returns None if only_meta is True and no title is found in metadata. |
||||
|
||||
Raises: |
||||
AssertionError: If only_meta is False and no PDF file is provided to generate a title. |
||||
""" |
||||
xml_metadata = self.pdf.get_xml_metadata() |
||||
|
||||
if not xml_metadata.strip(): |
||||
return None |
||||
|
||||
try: |
||||
root = ET.fromstring(xml_metadata) |
||||
except ET.ParseError: |
||||
return None |
||||
|
||||
namespaces = {} |
||||
for elem in root.iter(): |
||||
if elem.tag.startswith("{"): |
||||
uri, tag = elem.tag[1:].split("}") |
||||
prefix = uri.split("/")[-1] |
||||
namespaces[prefix] = uri |
||||
|
||||
namespaces["rdf"] = "http://www.w3.org/1999/02/22-rdf-syntax-ns#" |
||||
namespaces["dc"] = "http://purl.org/dc/elements/1.1/" |
||||
|
||||
title_element = root.find( |
||||
".//rdf:Description/dc:title/rdf:Alt/rdf:li", namespaces |
||||
) |
||||
|
||||
if title_element is not None: |
||||
self.title = title_element.text |
||||
return title_element.text |
||||
else: |
||||
if only_meta: |
||||
return None |
||||
else: |
||||
assert ( |
||||
self.pdf_file |
||||
), "PDF file must be provided to generate a title if no title in metadata." |
||||
                try:
                    filename = self.pdf_file.split("/")[-1].replace(".pdf", "")
                except AttributeError:
                    # Uploaded files are file-like objects with a .name attribute
                    filename = self.pdf_file.name.split("/")[-1].replace(".pdf", "")
||||
self.title = f"{filename}_{datetime.now().strftime('%Y%m%d%H%M%S')}" |
||||
return self.title |
||||
|
||||
def save_pdf(self, document_type): |
||||
assert ( |
||||
self.is_sci or self.username |
||||
), "To save a PDF username must be provided for non-sci articles." |
||||
|
||||
if self.is_sci: |
||||
download_folder = "sci_articles" |
||||
else: |
||||
download_folder = f"user_data/{self.username}/{document_type}" |
||||
|
||||
if not os.path.exists(download_folder): |
||||
os.makedirs(download_folder) |
||||
self.download_folder = download_folder |
||||
|
||||
if self.doi and not document_type == "notes": |
||||
self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_") |
||||
if not os.path.exists(self.file_path): |
||||
self.file_path = f"{self.download_folder}/{fix_key(self.doi)}.pdf" |
||||
self.pdf.save(self.file_path) |
||||
else: |
||||
self.file_path = self.set_filename(self.get_title()) |
||||
if not self.file_path: |
||||
try: |
||||
self.file_path = self.pdf_file.name |
||||
                except AttributeError:
||||
self.file_path = self.pdf_file.split("/")[-1] |
||||
self.pdf.save(self.file_path) |
||||
|
||||
return self.file_path |
||||
|
||||
def set_filename(self, filename=None): |
||||
if self.is_sci and not self.document_type == "notes": |
||||
self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_") |
||||
return os.path.exists(self.file_path) |
||||
else: |
||||
file_path = f"{self.download_folder}/{filename}" |
||||
while os.path.exists(file_path + ".pdf"): |
||||
if not re.search(r"(_\d+)$", file_path): |
||||
file_path += "_1" |
||||
else: |
||||
file_path = re.sub( |
||||
r"(\d+)$", lambda x: str(int(x.group()) + 1), file_path |
||||
) |
||||
self.file_path = file_path + ".pdf" |
||||
return file_path |
||||
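    # Example of the collision handling above (paths are illustrative): if
    # "user_data/lasse/notes/report.pdf" already exists, later saves become
    # "report_1.pdf", then "report_2.pdf", and so on, via the trailing-counter
    # substitution in the while loop.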
|
||||
|
||||
class Processor: |
||||
def __init__( |
||||
self, |
||||
document: Document, |
||||
filename: str = None, |
||||
chroma_db: str = "sci_articles", |
||||
len_chunks: int = 2200, |
||||
local_chroma_deployment: bool = False, |
||||
process: bool = True, |
||||
document_type: str = None, |
||||
): |
||||
self.document = document |
||||
self.chromadb = ChromaDB(local_deployment=local_chroma_deployment, db=chroma_db) |
||||
self.len_chunks = len_chunks |
||||
self.document_type = document_type |
||||
|
||||
self._id = None |
||||
|
||||
if process: |
||||
self.process_document() |
||||
|
||||
def get_arango(self, db_name=None, document_type=None): |
||||
if db_name and document_type: |
||||
arango = ArangoDB(db_name=db_name) |
||||
arango_collection = arango.db.collection(document_type) |
||||
elif self.document.is_sci: |
||||
arango = ArangoDB(db_name="base") |
||||
arango_collection = arango.db.collection("sci_articles") |
||||
elif self.document.open_access: |
||||
arango = ArangoDB(db_name="base") |
||||
arango_collection = arango.db.collection("other_documents") |
||||
else: |
||||
arango = ArangoDB(db_name=self.document.username) |
||||
arango_collection: ArangoCollection = arango.db.collection( |
||||
self.document_type |
||||
) |
||||
self.document.arango_db_name = arango.db.name |
||||
self.arango_collection = arango_collection |
||||
return arango_collection |
||||
|
||||
|
||||
def extract_doi(self, text, multi=False): |
||||
doi_pattern = r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+" |
||||
|
||||
if multi: |
||||
dois = re.findall(doi_pattern, text) |
||||
processed_dois = [doi.strip(".").replace(".pdf", "") for doi in dois] |
||||
return processed_dois if processed_dois else None |
||||
else: |
||||
doi = re.search(doi_pattern, text) |
||||
if doi: |
||||
doi = doi.group() |
||||
doi = doi.strip(".").replace(".pdf", "") |
||||
            metadata = self.get_crossref(doi)  # avoid a duplicate Crossref API call
            if metadata:
                self.document.metadata = metadata
                self.document.doi = doi
||||
else: |
||||
for page in self.document.pdf.pages(0, 6): |
||||
text = page.get_text() |
||||
if re.search(doi_pattern, text): |
||||
llm = LLM( |
||||
temperature=0.01, |
||||
system_message='You are an assistant helping a user to extract the DOI from a scientific article. \ |
||||
A DOI always starts with "10." and is followed by a series of numbers and letters, and a "/" in the middle.\ |
||||
Sometimes the DOI is split by a line break, so be sure to check for that.', |
||||
max_length_answer=50, |
||||
) |
||||
prompt = f''' |
||||
This is the text of an article: |
||||
""" |
||||
{text} |
||||
""" |
||||
                    I want you to find the DOI of the article. Answer ONLY with the DOI, nothing else.
||||
If you can't find the DOI, answer "not_found". |
||||
''' |
||||
st.write('Trying to extract DOI from text using LLM...') |
||||
doi = llm.generate(prompt).replace('https://doi.org/', '') |
||||
if doi == "not_found": |
||||
return None |
||||
else: |
||||
doi = re.search(doi_pattern, doi).group() |
||||
break |
||||
|
||||
return doi |
||||
else: |
||||
return None |
||||
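    # Illustrative strings the DOI pattern above is meant to handle:
    #
    #   "10.1007/s10584-019-02646-9"      -> matched as-is
    #   "doi: 10.1016/j.jas.2021.103086." -> matched; the trailing "." is stripped
    #   "10.1234/example.pdf"             -> matched; the ".pdf" suffix is removed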
|
||||
def chunks2chroma(self, _id, key): |
||||
st.write("Adding to vector database...") |
||||
assert self.document.text, "Document must have 'text' attribute." |
||||
|
||||
ids = [] |
||||
documents = [] |
||||
metadatas = [] |
||||
|
||||
last_page = 1 |
||||
for i, chunk in enumerate(self.document.chunks): |
||||
page_numbers = re.findall(r"@(\d+)@", chunk) |
||||
if page_numbers == []: |
||||
page_numbers = [last_page] |
||||
else: |
||||
last_page = page_numbers[-1] |
||||
id = fix_key(f"{key}_{i}") |
||||
ids.append(id) |
||||
|
||||
metadata = { |
||||
"_key": id, |
||||
"file": self.document.file_path, |
||||
"chunk_nr": i, |
||||
"pages": ",".join([str(i) for i in page_numbers]), |
||||
"_id": _id, |
||||
} |
||||
if self.document.doi: |
||||
metadata["doi"] = self.document.doi |
||||
metadatas.append(metadata) |
||||
|
||||
chunk = re.sub(r"@(\d+)@", "", chunk) |
||||
documents.append(chunk) |
||||
|
||||
if self.document.is_sci: |
||||
chroma_collection = self.chromadb.db.get_or_create_collection( |
||||
"sci_articles" |
||||
) |
||||
else: |
||||
chroma_collection = self.chromadb.db.get_or_create_collection( |
||||
"other_documents" |
||||
) |
||||
|
||||
chroma_collection.add(ids=ids, documents=documents, metadatas=metadatas) |
||||
|
||||
def chunks2arango(self): |
||||
st.write("Adding to document database...") |
||||
assert self.document.text, "Document must have 'text' attribute." |
||||
if self.document.is_sci: |
||||
for key in ["doi", "metadata"]: |
||||
assert getattr( |
||||
self.document, key |
||||
), f"Document must have '{key}' attribute." |
||||
else: |
||||
assert ( |
||||
getattr(self.document, "_key", None) or self.document.doi |
||||
), "Document must have '_key' attribute or DOI." |
||||
|
||||
arango_collection = self.get_arango() |
||||
|
||||
if self.document.doi: |
||||
key = self.document.doi |
||||
else: |
||||
key = self.document._key |
||||
|
||||
arango_chunks = [] |
||||
|
||||
last_page = 1 |
||||
for i, chunk in enumerate(self.document.chunks): |
||||
page_numbers = re.findall(r"@(\d+)@", chunk) |
||||
if page_numbers == []: |
||||
page_numbers = [last_page] |
||||
else: |
||||
last_page = page_numbers[-1] |
||||
id = fix_key(key) + f"_{i}" |
||||
|
||||
chunk = re.sub(r"@(\d+)@", "", chunk) |
||||
|
||||
arango_chunks.append({"text": chunk, "pages": page_numbers, "id": id}) |
||||
|
||||
if not hasattr(self.document, "_key"): |
||||
self.document._key = fix_key(key) |
||||
|
||||
user_access = [self.document.username] |
||||
if not self.document.open_access: |
||||
if arango_collection.has(self.document._key): |
||||
doc = arango_collection.get(self.document._key) |
||||
if "user_access" in doc: |
||||
if doc["user_access"]: |
||||
if self.document.username not in doc["user_access"]: |
||||
user_access = doc["user_access"] + [self.document.username] |
||||
else: |
||||
user_access = [self.document.username] |
||||
if self.document.open_access: |
||||
user_access = None |
||||
|
||||
arango_document = { |
||||
"_key": fix_key(self.document._key), |
||||
"file": self.document.file_path, |
||||
"chunks": arango_chunks, |
||||
"text": self.document.text, |
||||
"open_access": self.document.open_access, |
||||
"user_access": user_access, |
||||
"doi": self.document.doi, |
||||
"metadata": self.document.metadata, |
||||
"filename": self.document.filename, |
||||
} |
||||
|
||||
        if self.document.metadata and self.document.is_sci:
            abstract = self.document.metadata.get("abstract")
            if isinstance(abstract, str):
                # Strip markup from the abstract and reuse it as the stored summary
                self.document.metadata["abstract"] = re.sub(r"<[^>]*>", "", abstract)
                arango_document["metadata"] = self.document.metadata
                arango_document["summary"] = {
                    "text_sum": self.document.metadata["abstract"],
                    "meta": {"model": "from_metadata"},
                }

            arango_document["crossref"] = True
||||
|
||||
doc = arango_collection.insert( |
||||
arango_document, overwrite=True, overwrite_mode="update", keep_none=False |
||||
) |
||||
self.document._id = doc["_id"] |
||||
|
||||
if "summary" not in arango_document: |
||||
# Make a summary in the background |
||||
self.document.make_summary_in_background() |
||||
|
||||
return doc["_id"], key |
||||
|
||||
def llm2metadata(self): |
||||
st.write("Extracting metadata using LLM...") |
||||
llm = LLM( |
||||
temperature=0.01, |
||||
system_message="You are an assistant helping a user to extract metadata from a scientific article.", |
||||
model="small", |
||||
max_length_answer=500, |
||||
) |
||||
        # Only request page 1 if the PDF actually has more than one page
        pages = [0] if len(self.document.pdf) == 1 else [0, 1]
        text = pymupdf4llm.to_markdown(
            self.document.pdf, page_chunks=False, show_progress=False, pages=pages
        )
||||
prompt = f''' |
||||
Below is the beginning of an article. I want to know when it's published, the title, and the journal. |
||||
|
||||
""" |
||||
{text} |
||||
""" |
||||
|
||||
Answer ONLY with the information requested. |
||||
I want to know the published date on the form "YYYY-MM-DD". |
||||
I want the full title of the article and the journal. |
||||
Be sure to answer on the form "published_date;title;journal" as the answer will be used in a CSV. |
||||
If you can't find the information, answer "not_found". |
||||
''' |
||||
result = llm.generate(prompt) |
||||
print_blue(result) |
||||
if result == "not_found": |
||||
return None |
||||
else: |
||||
parts = result.split(";", 2) |
||||
if len(parts) != 3: |
||||
return None |
||||
published_date, title, journal = parts |
||||
if published_date == "not_found": |
||||
published_date = "[Unknown date]" |
||||
else: |
||||
try: |
||||
published_year = int(published_date.split("-")[0]) |
||||
except: |
||||
published_year = None |
||||
if title == "not_found": |
||||
title = "[Unknown title]" |
||||
if journal == "not_found": |
||||
journal = "[Unknown publication]" |
||||
return { |
||||
"published_date": published_date, |
||||
"published_year": published_year, |
||||
"title": title, |
||||
"journal": journal, |
||||
} |
||||
|
||||
def get_crossref(self, doi): |
||||
try: |
||||
print(f"Retrieving metadata for DOI {doi}...") |
||||
work = crossref.get_publication_as_json(doi) |
||||
print_green(f"Metadata retrieved for DOI {doi}.") |
||||
if "published-print" in work: |
||||
publication_date = work["published-print"]["date-parts"][0] |
||||
elif "published-online" in work: |
||||
publication_date = work["published-online"]["date-parts"][0] |
||||
elif "issued" in work: |
||||
publication_date = work["issued"]["date-parts"][0] |
||||
else: |
||||
publication_date = [None] |
||||
publication_year = publication_date[0] |
||||
|
||||
metadata = { |
||||
"doi": work.get("DOI", None), |
||||
"title": work.get("title", [None])[0], |
||||
"authors": [ |
||||
f"{author['given']} {author['family']}" |
||||
for author in work.get("author", []) |
||||
], |
||||
"abstract": work.get("abstract", None), |
||||
"journal": work.get("container-title", [None])[0], |
||||
"volume": work.get("volume", None), |
||||
"issue": work.get("issue", None), |
||||
"pages": work.get("page", None), |
||||
"published_date": "-".join(map(str, publication_date)), |
||||
"published_year": publication_year, |
||||
"url_doi": work.get("URL", None), |
||||
"link": ( |
||||
work.get("link", [None])[0]["URL"] |
||||
if work.get("link", None) |
||||
else None |
||||
), |
||||
"language": work.get("language", None), |
||||
} |
||||
if "abstract" in metadata and isinstance(metadata["abstract"], str): |
||||
metadata["abstract"] = re.sub(r"<[^>]*>", "", metadata["abstract"]) |
||||
self.document.metadata = metadata |
||||
self.document.is_sci = True |
||||
return metadata |
||||
|
||||
        except Exception as e:
            print_red(f"Could not fetch Crossref metadata for DOI {doi}: {e}")
            if self.document.is_sci is None:
                self.document.is_sci = False
            return None
||||
|
||||
def check_doaj(self, doi): |
||||
url = f"https://doaj.org/api/search/articles/{doi}" |
||||
response = requests.get(url) |
||||
if response.status_code == 200: |
||||
data = response.json() |
||||
if data.get("results", []) == []: |
||||
print(f"DOI {doi} not found in DOAJ.") |
||||
return False |
||||
else: |
||||
return data |
||||
else: |
||||
print( |
||||
f"Error fetching metadata for DOI from DOAJ: {doi}. HTTP Status Code: {response.status_code}" |
||||
) |
||||
            return None
||||
|
||||
def process_document(self): |
||||
assert self.document.pdf_file or self.document.pdf, "PDF file must be provided." |
||||
if not self.document.pdf: |
||||
self.document.open_pdf(self.document.pdf_file) |
||||
|
||||
if self.document.is_image: |
||||
return pymupdf4llm.to_markdown( |
||||
self.document.pdf, page_chunks=False, show_progress=False |
||||
) |
||||
self.document.title = self.document.get_title() |
||||
|
||||
if not self.document.doi and self.document.filename: |
||||
self.document.doi = self.extract_doi(self.document.filename) |
||||
if not self.document.doi: |
||||
text = "" |
||||
for page in self.document.pdf.pages(0, 6): |
||||
text += page.get_text() |
||||
self.document.doi = self.extract_doi(text) |
||||
|
||||
if self.document.doi: |
||||
self.document._key = fix_key(self.document.doi) |
||||
if self.check_doaj(self.document.doi): |
||||
self.document.open_access = True |
||||
self.document.is_sci = True |
||||
self.document.metadata = self.get_crossref(self.document.doi) |
||||
if not self.document.is_sci: |
||||
self.document.is_sci = bool(self.document.metadata) |
||||
|
||||
|
||||
arango_collection = self.get_arango() |
||||
|
||||
doc = arango_collection.get(self.document._key) if self.document.doi else None |
||||
|
||||
if doc: |
||||
print_green(f"Document with key {self.document._key} already in database.") |
||||
self.document.doc = doc |
||||
            # Use a local name that does not shadow the imported crossref module
            crossref_metadata = self.get_crossref(self.document.doi)
            if crossref_metadata:
                self.document.doc["metadata"] = crossref_metadata
||||
elif "metadata" not in doc or not doc["metadata"]: |
||||
self.document.doc["metadata"] = { |
||||
"title": self.document.get_title(only_meta=True) |
||||
} |
||||
|
||||
elif 'title' not in doc['metadata']: |
||||
self.document.doc["metadata"]["title"] = self.document.get_title(only_meta=True) |
||||
|
||||
|
||||
if "user_access" not in doc or doc['user_access'] == None: |
||||
self.document.doc["user_access"] = [self.document.username] |
||||
else: |
||||
if self.document.username not in doc['user_access']: |
||||
self.document.doc["user_access"] = doc.get("user_access", []) + [ |
||||
self.document.username |
||||
] |
||||
self.metadata = self.document.doc["metadata"] |
||||
arango_collection.update(self.document.doc) |
||||
return doc["_id"], arango_collection.db_name, self.document.doi |
||||
|
||||
else: |
||||
self.document.doc = ( |
||||
{"doi": self.document.doi, "_key": fix_key(self.document.doi)} |
||||
if self.document.doi |
||||
else {} |
||||
) |
||||
if self.document.doi: |
||||
if not self.document.metadata: |
||||
self.document.metadata = self.get_crossref(self.document.doi) |
||||
if self.document.metadata: |
||||
self.document.doc["metadata"] = self.document.metadata or { |
||||
"title": self.document.get_title(only_meta=True) |
||||
} |
||||
else: |
||||
self.document.doc["metadata"] = self.llm2metadata() |
||||
if self.document.get_title(only_meta=True): |
||||
self.document.doc["metadata"]["title"] = ( |
||||
self.document.get_title(only_meta=True) |
||||
) |
||||
else: |
||||
self.document.doc["metadata"] = self.llm2metadata() |
||||
if self.document.get_title(only_meta=True): |
||||
self.document.doc["metadata"]["title"] = self.document.get_title( |
||||
only_meta=True |
||||
) |
||||
if "_key" not in self.document.doc: |
||||
_key = ( |
||||
self.document.doi |
||||
or self.document.title |
||||
or self.document.get_title() |
||||
) |
||||
print_yellow(f"Document key: {_key}") |
||||
print(self.document.doi, self.document.title, self.document.get_title()) |
||||
self.document.doc["_key"] = fix_key(_key) |
||||
self.document._key = fix_key(_key) |
||||
self.document.metadata = self.document.doc["metadata"] |
||||
if not self.document.text: |
||||
self.document.extract_text() |
||||
|
||||
if self.document.doi: |
||||
self.document.doc["doi"] = self.document.doi |
||||
self.document.doc["doi"] = self.document.doi |
||||
self.document._key = fix_key(self.document.doi) |
||||
|
||||
self.document.save_pdf(self.document_type) |
||||
|
||||
self.document.make_chunks() |
||||
|
||||
_id, key = self.chunks2arango() |
||||
self.chunks2chroma(_id=_id, key=key) |
||||
|
||||
self._id = _id |
||||
return _id, arango_collection.db_name, self.document.doi |
||||
|
||||
async def dl_pyppeteer(self, doi, url): |
||||
browser = await launch( |
||||
headless=True, args=["--no-sandbox", "--disable-setuid-sandbox"] |
||||
) |
||||
page = await browser.newPage() |
||||
await page.setUserAgent( |
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X x.y; rv:10.0) Gecko/20100101 Firefox/10.0" |
||||
) |
||||
await page.goto(url) |
||||
await page.waitFor(5000) |
||||
content = await page.content() |
||||
await page.pdf({"path": f"{doi}.pdf".replace("/", "_"), "format": "A4"}) |
||||
|
||||
await browser.close() |
||||
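    # dl_pyppeteer is a coroutine, so a synchronous caller would drive it with
    # asyncio, e.g. (sketch):
    #
    #   import asyncio
    #   asyncio.get_event_loop().run_until_complete(processor.dl_pyppeteer(doi, url))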
|
||||
def doi2pdf(self, doi): |
||||
url = None |
||||
downloaded = False |
||||
path = None |
||||
in_db = False |
||||
sci_articles = self.get_arango(db_name="base", document_type="sci_articles") |
||||
if sci_articles.has(fix_key(doi)): |
||||
in_db = True |
||||
downloaded = True |
||||
doc = sci_articles.get(fix_key(doi)) |
||||
url = doc["metadata"]["link"] |
||||
path = doc["file"] |
||||
print_green(f"Article {doi} already in database.") |
||||
            return downloaded, url, path, in_db  # path was set from doc["file"] above
||||
|
||||
doaj_data = self.check_doaj(doi) |
||||
sleep(0.5) |
||||
if doaj_data: |
||||
for link in doaj_data.get("bibjson", {}).get("link", []): |
||||
if "mdpi.com" in link["url"]: |
||||
r = requests.get(link["url"]) |
||||
soup = BeautifulSoup(r.content, "html.parser") |
||||
pdf_link_html = soup.find("a", {"class": "UD_ArticlePDF"}) |
||||
pdf_url = "https://www.mdpi.com" + pdf_link_html["href"] |
||||
pdf = requests.get(pdf_url) |
||||
|
||||
path = f"sci_articles/{doi}.pdf".replace("/", "_") |
||||
|
||||
with open(path, "wb") as f: |
||||
f.write(pdf.content) |
||||
self.process_document() |
||||
print(f"Downloaded PDF for {doi}") |
||||
downloaded = True |
||||
url = link["url"] |
||||
|
||||
else: |
||||
downloaded = False |
||||
|
||||
else: |
||||
metadata = self.get_crossref(doi) |
||||
if metadata: |
||||
url = metadata["link"] |
||||
else: |
||||
print(f"Error fetching metadata for DOI: {doi}") |
||||
|
||||
return downloaded, url, path, in_db |
||||
|
||||
|
||||
class PDFProcessor(Processor): |
||||
def __init__( |
||||
self, |
||||
pdf_file=None, |
||||
filename=None, |
||||
chroma_db: str = "sci_articles", |
||||
document_type: str = None, |
||||
len_chunks: int = 2200, |
||||
local_chroma_deployment: bool = False, |
||||
process: bool = True, |
||||
doi=False, |
||||
username=None, |
||||
is_sci=None, |
||||
is_image=False, |
||||
): |
||||
self.document = Document( |
||||
pdf_file=pdf_file, |
||||
filename=filename, |
||||
doi=doi, |
||||
username=username, |
||||
is_sci=is_sci, |
||||
is_image=is_image, |
||||
) |
||||
super().__init__( |
||||
document=self.document, |
||||
filename=filename, |
||||
chroma_db=chroma_db, |
||||
len_chunks=len_chunks, |
||||
local_chroma_deployment=local_chroma_deployment, |
||||
process=process, |
||||
document_type=document_type, |
||||
) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
doi = "10.1007/s10584-019-02646-9" |
||||
print(f"Processing article with DOI: {doi}") |
||||
ap = PDFProcessor(doi=doi, process=False) |
||||
print(f"Downloading article with DOI: {doi}") |
||||
ap.doi2pdf(doi) |
||||
@ -0,0 +1,4 @@ |
||||
from PyPaperBot import __main__ as ppb |
||||
|
||||
ppb.start(DOIs=["10.1016/j.jas.2021.103086", "10.1016/j.jas.2021.103087", "10.1016/j.jas.2021.103088"]) |
||||
|
||||
@ -0,0 +1,192 @@ |
||||
country_emojis = { |
||||
"ad": "🇦🇩", |
||||
"ae": "🇦🇪", |
||||
"af": "🇦🇫", |
||||
"ag": "🇦🇬", |
||||
"ai": "🇦🇮", |
||||
"al": "🇦🇱", |
||||
"am": "🇦🇲", |
||||
"ao": "🇦🇴", |
||||
"aq": "🇦🇶", |
||||
"ar": "🇦🇷", |
||||
"as": "🇦🇸", |
||||
"at": "🇦🇹", |
||||
"au": "🇦🇺", |
||||
"aw": "🇦🇼", |
||||
"ax": "🇦🇽", |
||||
"az": "🇦🇿", |
||||
"ba": "🇧🇦", |
||||
"bb": "🇧🇧", |
||||
"bd": "🇧🇩", |
||||
"be": "🇧🇪", |
||||
"bf": "🇧🇫", |
||||
"bg": "🇧🇬", |
||||
"bh": "🇧🇭", |
||||
"bi": "🇧🇮", |
||||
"bj": "🇧🇯", |
||||
"bl": "🇧🇱", |
||||
"bm": "🇧🇲", |
||||
"bn": "🇧🇳", |
||||
"bo": "🇧🇴", |
||||
"bq": "🇧🇶", |
||||
"br": "🇧🇷", |
||||
"bs": "🇧🇸", |
||||
"bt": "🇧🇹", |
||||
"bv": "🇧🇻", |
||||
"bw": "🇧🇼", |
||||
"by": "🇧🇾", |
||||
"bz": "🇧🇿", |
||||
"ca": "🇨🇦", |
||||
"cc": "🇨🇨", |
||||
"cd": "🇨🇩", |
||||
"cf": "🇨🇫", |
||||
"cg": "🇨🇬", |
||||
"ch": "🇨🇭", |
||||
"ci": "🇨🇮", |
||||
"ck": "🇨🇰", |
||||
"cl": "🇨🇱", |
||||
"cm": "🇨🇲", |
||||
"cn": "🇨🇳", |
||||
"co": "🇨🇴", |
||||
"cr": "🇨🇷", |
||||
"cu": "🇨🇺", |
||||
"cv": "🇨🇻", |
||||
"cw": "🇨🇼", |
||||
"cx": "🇨🇽", |
||||
"cy": "🇨🇾", |
||||
"cz": "🇨🇿", |
||||
"de": "🇩🇪", |
||||
"dj": "🇩🇯", |
||||
"dk": "🇩🇰", |
||||
"dm": "🇩🇲", |
||||
"do": "🇩🇴", |
||||
"dz": "🇩🇿", |
||||
"ec": "🇪🇨", |
||||
"ee": "🇪🇪", |
||||
"eg": "🇪🇬", |
||||
"eh": "🇪🇭", |
||||
"er": "🇪🇷", |
||||
"es": "🇪🇸", |
||||
"et": "🇪🇹", |
||||
"fi": "🇫🇮", |
||||
"fj": "🇫🇯", |
||||
"fk": "🇫🇰", |
||||
"fm": "🇫🇲", |
||||
"fo": "🇫🇴", |
||||
"fr": "🇫🇷", |
||||
"ga": "🇬🇦", |
||||
"gb": "🇬🇧", |
||||
"gd": "🇬🇩", |
||||
"ge": "🇬🇪", |
||||
"gf": "🇬🇫", |
||||
"gg": "🇬🇬", |
||||
"gh": "🇬🇭", |
||||
"gi": "🇬🇮", |
||||
"gl": "🇬🇱", |
||||
"gm": "🇬🇲", |
||||
"gn": "🇬🇳", |
||||
"gp": "🇬🇵", |
||||
"gq": "🇬🇶", |
||||
"gr": "🇬🇷", |
||||
"gs": "🇬🇸", |
||||
"gt": "🇬🇹", |
||||
"gu": "🇬🇺", |
||||
"gw": "🇬🇼", |
||||
"gy": "🇬🇾", |
||||
"hk": "🇭🇰", |
||||
"hm": "🇭🇲", |
||||
"hn": "🇭🇳", |
||||
"hr": "🇭🇷", |
||||
"ht": "🇭🇹", |
||||
"hu": "🇭🇺", |
||||
"id": "🇮🇩", |
||||
"ie": "🇮🇪", |
||||
"il": "🇮🇱", |
||||
"im": "🇮🇲", |
||||
"in": "🇮🇳", |
||||
"io": "🇮🇴", |
||||
"iq": "🇮🇶", |
||||
"ir": "🇮🇷", |
||||
"is": "🇮🇸", |
||||
"it": "🇮🇹", |
||||
"je": "🇯🇪", |
||||
"jm": "🇯🇲", |
||||
"jo": "🇯🇴", |
||||
"jp": "🇯🇵", |
||||
"ke": "🇰🇪", |
||||
"kg": "🇰🇬", |
||||
"kh": "🇰🇭", |
||||
"ki": "🇰🇮", |
||||
"km": "🇰🇲", |
||||
"kn": "🇰🇳", |
||||
"kp": "🇰🇵", |
||||
"kr": "🇰🇷", |
||||
"kw": "🇰🇼", |
||||
"ky": "🇰🇾", |
||||
"kz": "🇰🇿", |
||||
"la": "🇱🇦", |
||||
"lb": "🇱🇧", |
||||
"lc": "🇱🇨", |
||||
"li": "🇱🇮", |
||||
"lk": "🇱🇰", |
||||
"lr": "🇱🇷", |
||||
"ls": "🇱🇸", |
||||
"lt": "🇱🇹", |
||||
"lu": "🇱🇺", |
||||
"lv": "🇱🇻", |
||||
"ly": "🇱🇾", |
||||
"ma": "🇲🇦", |
||||
"mc": "🇲🇨", |
||||
"md": "🇲🇩", |
||||
"me": "🇲🇪", |
||||
"mf": "🇲🇫", |
||||
"mg": "🇲🇬", |
||||
"mh": "🇲🇭", |
||||
"mk": "🇲🇰", |
||||
"ml": "🇲🇱", |
||||
"mm": "🇲🇲", |
||||
"mn": "🇲🇳", |
||||
"mo": "🇲🇴", |
||||
"mp": "🇲🇵", |
||||
"mq": "🇲🇶", |
||||
"mr": "🇲🇷", |
||||
"ms": "🇲🇸", |
||||
"mt": "🇲🇹", |
||||
"mu": "🇲🇺", |
||||
"mv": "🇲🇻", |
||||
"mw": "🇲🇼", |
||||
"mx": "🇲🇽", |
||||
"my": "🇲🇾", |
||||
"mz": "🇲🇿", |
||||
"na": "🇳🇦", |
||||
"nc": "🇳🇨", |
||||
"ne": "🇳🇪", |
||||
"nf": "🇳🇫", |
||||
"ng": "🇳🇬", |
||||
"ni": "🇳🇮", |
||||
"nl": "🇳🇱", |
||||
"no": "🇳🇴", |
||||
"np": "🇳🇵", |
||||
"nr": "🇳🇷", |
||||
"nu": "🇳🇺", |
||||
"nz": "🇳🇿", |
||||
"om": "🇴🇲", |
||||
"pa": "🇵🇦", |
||||
"pe": "🇵🇪", |
||||
"pf": "🇵🇫", |
||||
"pg": "🇵🇬", |
||||
"ph": "🇵🇭", |
||||
"pk": "🇵🇰", |
||||
"pl": "🇵🇱", |
||||
"pm": "🇵🇲", |
||||
"pn": "🇵🇳", |
||||
"pr": "🇵🇷", |
||||
"ps": "🇵🇸", |
||||
"pt": "🇵🇹", |
||||
"pw": "🇵🇼", |
||||
"py": "🇵🇾", |
||||
"qa": "🇶🇦", |
||||
"re": "🇷🇪", |
||||
"ro": "🇷🇴", |
||||
"rs": "🇷🇸", |
||||
} |
||||
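# The table above is mechanical: a flag emoji is just the two Unicode regional
# indicator symbols for the ISO 3166-1 alpha-2 code. A sketch that derives the
# same value on the fly (equivalent to the lookup for any two-letter code):
#
#   def country_code_to_emoji(code: str) -> str:
#       return "".join(chr(0x1F1E6 + ord(c) - ord("a")) for c in code.lower())
#
#   assert country_code_to_emoji("se") == "🇸🇪"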
@ -0,0 +1,53 @@ |
||||
from fastapi import FastAPI, BackgroundTasks |
||||
from pydantic import BaseModel |
||||
from typing import Optional |
||||
|
||||
from prompts import get_summary_prompt |
||||
from _llm import LLM |
||||
from _arango import ArangoDB |
||||
|
||||
app = FastAPI() |
||||
|
||||
|
||||
class DocumentData(BaseModel): |
||||
text: str |
||||
arango_db_name: str |
||||
arango_id: str |
||||
is_sci: Optional[bool] = False |
||||
|
||||
|
||||
@app.post("/summarise_document") |
||||
async def summarize_document(doc_data: DocumentData, background_tasks: BackgroundTasks): |
||||
background_tasks.add_task(summarise_document_task, doc_data.dict()) |
||||
return {"message": "Document summarization has started."} |
||||
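# A minimal client sketch for the endpoint above (the host/port must match the
# uvicorn invocation at the bottom of this file; the JSON field names must
# match DocumentData):
#
#   import requests
#   requests.post(
#       "http://localhost:8100/summarise_document",
#       json={
#           "text": "Full document text...",
#           "arango_db_name": "base",
#           "arango_id": "sci_articles/some_key",
#           "is_sci": True,
#       },
#   )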
|
||||
|
||||
def summarise_document_task(doc_data: dict): |
||||
text = doc_data.get("text") |
||||
is_sci = doc_data.get("is_sci", False) |
||||
|
||||
system_message = "You are summarising scientific articles. It is very important that you keep to what is written and do not add any of your own opinions or interpretations. Always answer in English." |
||||
llm = LLM(system_message=system_message) |
||||
|
||||
summary = llm.generate(query=get_summary_prompt(text, is_sci)) |
||||
|
||||
summary_doc = { |
||||
"text_sum": summary, |
||||
"meta": { |
||||
"model": llm.model, |
||||
"system_message": system_message, |
||||
"temperature": llm.options["temperature"], |
||||
}, |
||||
} |
||||
|
||||
arango = ArangoDB(db_name=doc_data.get("arango_db_name")) |
||||
arango.db.update_document( |
||||
{"summary": summary_doc, "_id": doc_data.get("arango_id")}, |
||||
silent=True, |
||||
check_rev=False, |
||||
) |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
import uvicorn |
||||
uvicorn.run(app, host="0.0.0.0", port=8100) |
||||
@ -0,0 +1,27 @@ |
||||
|
||||
from typing import Callable, Dict, Any, List |
||||
|
||||
class ToolRegistry: |
||||
_tools = [] |
||||
|
||||
@classmethod |
||||
def register(cls, name: str, description: str, parameters: Dict[str, Any] = None): |
||||
def decorator(func: Callable): |
||||
cls._tools.append({ |
||||
"type": "function", |
||||
"function": { |
||||
"name": name, |
||||
"description": description, |
||||
"parameters": parameters or {} |
||||
} |
||||
}) |
||||
# No need for the wrapper since we're not adding any extra logic |
||||
return func |
||||
return decorator |
||||
|
||||
@classmethod |
||||
def get_tools(cls, tools: list = None) -> List[Dict[str, Any]]: |
||||
if tools: |
||||
return [tool for tool in cls._tools if tool['function']['name'] in tools] |
||||
else: |
||||
return cls._tools |
||||
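# Hypothetical usage sketch of the registry (the tool name matches one referenced
# elsewhere in this repo; the parameter schema follows the OpenAI-style function
# format the decorator stores):
#
#   @ToolRegistry.register(
#       name="fetch_science_articles",
#       description="Search the vector database of scientific articles.",
#       parameters={
#           "type": "object",
#           "properties": {"query": {"type": "string"}},
#           "required": ["query"],
#       },
#   )
#   def fetch_science_articles(query: str):
#       ...
#
#   tools = ToolRegistry.get_tools(["fetch_science_articles"])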
@ -0,0 +1,97 @@ |
||||
import yaml |
||||
import sys |
||||
import bcrypt |
||||
from _arango import ArangoDB |
||||
import os |
||||
import dotenv |
||||
import getpass |
||||
|
||||
dotenv.load_dotenv() |
||||
|
||||
|
||||
def read_yaml(file_path): |
||||
with open(file_path, "r") as file: |
||||
return yaml.safe_load(file) |
||||
|
||||
|
||||
def write_yaml(file_path, data): |
||||
with open(file_path, "w") as file: |
||||
yaml.safe_dump(data, file) |
||||
|
||||
|
||||
def add_user(data, username, email, name, password): |
||||
# Check for existing username |
||||
if username in data["credentials"]["usernames"]: |
||||
print(f"Error: Username '{username}' already exists.") |
||||
sys.exit(1) |
||||
|
||||
# Check for existing email |
||||
for user in data["credentials"]["usernames"].values(): |
||||
if user["email"] == email: |
||||
print(f"Error: Email '{email}' already exists.") |
||||
sys.exit(1) |
||||
|
||||
# Hash the password using bcrypt |
||||
hashed_password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode( |
||||
"utf-8" |
||||
) |
||||
|
||||
# Add the new user |
||||
data["credentials"]["usernames"][username] = { |
||||
"email": email, |
||||
"name": name, |
||||
"password": hashed_password, |
||||
} |
||||
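# Verification counterpart to the hashing above (sketch): bcrypt embeds the salt
# in the hash, so a later login check is a single call:
#
#   bcrypt.checkpw(password.encode("utf-8"), stored_hash.encode("utf-8"))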
|
||||
|
||||
def make_arango(username): |
||||
root_user = os.getenv("ARANGO_ROOT_USER") |
||||
root_password = os.getenv("ARANGO_ROOT_PASSWORD") |
||||
arango = ArangoDB(user=root_user, password=root_password, db_name="_system") |
||||
|
||||
if not arango.db.has_database(username): |
||||
arango.db.create_database( |
||||
username, |
||||
users=[ |
||||
{ |
||||
"username": os.getenv("ARANGO_USER"), |
||||
"password": os.getenv("ARANGO_PASSWORD"), |
||||
"active": True, |
||||
"extra": {}, |
||||
} |
||||
] |
||||
) |
||||
arango = ArangoDB(user=root_user, password=root_password, db_name=username) |
||||
for collection in ["projects", "favorite_articles", "article_collections", "settings", 'chats', 'notes', 'other_documents']: |
||||
if not arango.db.has_collection(collection): |
||||
arango.db.create_collection(collection) |
||||
user_arango = ArangoDB(db_name=username) |
||||
user_arango.db.collection("settings").insert( |
||||
{"current_page": 'Bot Chat', "current_project": None} |
||||
) |
||||
|
||||
|
||||
def main(): |
||||
|
||||
yaml_file = "streamlit_users.yaml" |
||||
if len(sys.argv) == 5: |
||||
username = sys.argv[1] |
||||
email = sys.argv[2] |
||||
name = sys.argv[3] |
||||
password = sys.argv[4] |
||||
else: |
||||
username = input("Enter username: ") |
||||
email = input("Enter email: ") |
||||
name = input("Enter name: ") |
||||
password = getpass.getpass("Enter password: ") |
||||
|
||||
|
||||
data = read_yaml(yaml_file) |
||||
add_user(data, username, email, name, password) |
||||
make_arango(username) |
||||
write_yaml(yaml_file, data) |
||||
print(f"User {username} added successfully.") |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
main() |
||||
@ -0,0 +1,187 @@ |
||||
|
||||
import re |
||||
#from _classes import Project |
||||
|
||||
def description_string(project: "Project"): |
||||
if project.description != "": |
||||
description = f'The project is about "{project.description}".' |
||||
else: |
||||
description = '' |
||||
return description |
||||
|
||||
def use_tools(use: bool = False): |
||||
if use: |
||||
return 'If you need to use a tool to fetch information, you can do that as well.' |
||||
else: |
||||
return '' |
||||
def get_assistant_prompt(): |
||||
""" |
||||
Returns a multi-line string that serves as a prompt for a research assistant AI. |
||||
|
||||
Returns: |
||||
str: The prompt for the research assistant AI. |
||||
""" |
||||
return """You are a research assistant chatting with a researcher. Only use the information from scientific articles you are provided with to answer questions. |
||||
The articles are ordered by publication date, so the first article is the oldest one. Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account. |
||||
Be sure to reference the source of the information with the number of the article inside square brackets (e.g. "<answer based on an article> [article number]"). |
||||
If you have to reference the articles in running text, e.g. in a headline or the beginning of a bullet point, use the title of the article. Do NOT write something like "Article 1" but use the actual title of the article. |
||||
You should not write a reference section as this will be added later. |
||||
Format your answers in Markdown format. """ |
||||
|
||||
def get_editor_prompt(project: "Project", tools: bool = False): |
||||
"""Generates a coaching prompt for an editor to assist a reporter with a specific project. |
||||
|
||||
Args: |
||||
project (dict): A dictionary containing information about the project the reporter is working on. The dictionary should have the following keys: |
||||
- name (str): The name of the project. |
||||
- description (str): The description of the project. |
||||
- notes_summary (str): A summary of notes about the project and |
||||
Returns: |
||||
str: A formatted string containing the coaching prompt for the editor.""" |
||||
|
||||
|
||||
if project.notes_summary: |
||||
notes_string = f'''Here are other important things you should know about the project and the topic: |
||||
""" |
||||
{project.notes_summary} |
||||
""" |
||||
''' |
||||
else: |
||||
notes_string = '' |
||||
|
||||
return f'''You are an editor coaching a journalist who is working on the project "{project.name}". {description_string(project)} |
||||
{notes_string} |
||||
When writing with the reporter you will also get other information, like excerpts from articles and other documents. Use the notes to put the information in context and help the reporter to move forward. |
||||
The project is a journalistic piece, so it is important that you help the reporter to be critical of the sources and to provide a balanced view of the topic. |
||||
Be sure to understand what the reporter is asking and provide the information in a way that is helpful for the reporter to move forward. Try to understand if the reporter is asking for a specific piece of information or if they are looking for guidance on how to move forward, or just want to discuss the topic. |
||||
If you need more information to answer the question, try to get it. |
||||
''' |
||||
|
||||
def get_chat_prompt(user_input, content_string, role): |
||||
if role == "Research Assistant": |
||||
|
||||
prompt = f'''{user_input} |
||||
Below are snippets from different articles, often with title and date of publication. |
||||
ONLY use the information below to answer the question. Do not use any other information. |
||||
|
||||
""" |
||||
{content_string} |
||||
""" |
||||
Remember: |
||||
- Reference the source of the information with the number of the article inside square brackets (e.g. "<answer based on an article> [article number]"). |
||||
- If you have to reference the articles in running text, e.g. in a headline or the beginning of a bullet point, use the title of the article. Do NOT write something like "Article 1" but use the actual title of the article. |
||||
- The articles are ordered by publication date, so the first article is the oldest one. Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account. |
||||
|
||||
{user_input} |
||||
''' |
||||
|
||||
elif role == "Editor": |
||||
prompt = f'''The reporter has asked: "{user_input}". Try to answer the question or provide guidance based on the information below, and your knowledge of the project. |
||||
|
||||
""" |
||||
{content_string} |
||||
""" |
||||
Remember: |
||||
- Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account. |
||||
- If you think you need more information to answer the question, ask the reporter for more context. |
||||
|
||||
{user_input} |
||||
''' |
||||
|
||||
elif role == "Guest": |
||||
prompt = f'''The podcast host has asked: "{user_input}". Try to answer the question based on the information below. |
||||
|
||||
""" |
||||
{content_string} |
||||
""" |
||||
Remember: |
||||
- Answer in a way that is easy to understand for a general audience. |
||||
- Only answer based on the information above. |
||||
- Answer in a "spoken" way, formatting the answer as if you were speaking it out loud. |
||||
|
||||
{user_input} |
||||
''' |
||||
|
||||
elif role == "Host": |
||||
prompt = f'''The expert has stated: "{user_input}". Try to come up with a new question based on the information below. |
||||
|
||||
""" |
||||
{content_string} |
||||
""" |
||||
Remember: |
||||
- The information above is the context for the expert's statement, so the new question should be relevant to that context, as well as the conversation as a whole. |
||||
- You are a critical journalist, so try to come up with a question that challenges the expert's statement or asks for more information. |
||||
- Make sure not to repeat yourself! Check what questions you have already asked to avoid repetition. |
||||
''' |
||||
return prompt |
||||
|
||||
def get_query_builder_system_message(): |
||||
system_message = """ |
||||
Take the user input and write it as a sentence that could be used as a query for a vector database. |
||||
The vector database will return text snippets that semantically match the query, so you CAN'T USE NEGATIONS or other complex language constructs. If there is a negation in the user input, exclude that part from the query. |
||||
If the user input seems to be a follow-up question or comment, use the context from the chat history to make a relevant query. |
||||
Answer ONLY with the query, no explanation or reasoning! |
||||
""" |
||||
return re.sub(r"\s*\n\s*", "\n", system_message) |
||||
|
||||
def get_note_summary_prompt(project: "Project", notes_string: str): |
||||
query = f''' |
||||
    Below are notes from a project called "{project.name}". {description_string(project)}
||||
""" |
||||
{notes_string} |
||||
""" |
||||
I want you to summarize the notes in a concise manner. The summary will be used to create a system message for chatting with LLMs about the project. |
||||
Make sure to include the most important points, and include any key terms or concepts. |
||||
Try to gather at least something from each note, but don't reference the notes directly. |
||||
Answer ONLY with the summary, nothing else. |
||||
''' |
||||
return re.sub(r"\s*\n\s*", "\n", query) |
||||
|
||||
|
||||
def get_image_system_prompt(project: "Project"): |
||||
|
||||
system_message = f""" |
||||
You are an assistant to a journalist who is working on a project called {project.name}. Your task is to analyze and describe the images that are part of the project. |
||||
The images you get might show a graph, chart or other visual representation of data. If so, answer in a way that describes the data and the trends or patterns that can be seen in the image. |
||||
The images might also show a photo or illustration, in that case describe the content of the image in a way that could be used in a caption. |
||||
- Don't include any information that is not visible in the image. |
||||
    - Don't focus too much on the individual parts of the image/figure, but rather on its overall meaning.
    - Answer ONLY with the description of the image, with no further argument or explanation.
||||
""" |
||||
return re.sub(r"\s*\n\s*", "\n", system_message) |
||||
|
||||
def get_tools_prompt(user_input): |
||||
return f'''The reporter has asked: "{user_input}" |
||||
What information is needed to answer the question? Choose one or many tools in order to answer the question. Make sure to read the description of the tools carefully before choosing. |
||||
    If you are sure that you can answer the question correctly without fetching data, you can do that as well.
||||
|
||||
''' |
||||
|
||||
|
||||
def get_summary_prompt(text, is_sci): |
||||
text = re.sub(r"\s*\n\s*", "\n", text) |
||||
if is_sci: |
||||
        s = 'The text will be used as an abstract for a scientific article. Make sure to include the most important results and findings, as well as relevant information about methods, data and conclusions. Keep the summary concise.'
||||
else: |
||||
s = 'Make sure to include the most important points and facts, and to keep the summary concise.' |
||||
|
||||
prompt = f'''Summarise the text below: |
||||
""" |
||||
{text} |
||||
""" |
||||
{s} |
||||
Use ONLY the information from the text to write the summary. Do not include any additional information or your own interpretation. |
||||
Answer ONLY with the summary, nothing else like reasoning or explanation. |
||||
''' |
||||
return re.sub(r"\s*\n\s*", "\n", prompt) |
||||
|
||||
|
||||
def get_generate_vector_query_prompt(user_input: str, role: str): |
||||
print(role.upper()) |
||||
if role in ["Research Assistant", "Editor"]: |
||||
query = f"""A user asked this question: "{user_input}". Generate a query for the vector database. Make sure to follow the instructions you got earlier!""" |
||||
elif role == "Guest": |
||||
query = f"""A podcast host has asked this question in an interview: "{user_input}". Generate a query for the vector database to answer the actial question. Make sure to follow the instructions you got earlier!""" |
||||
elif role == "Host": |
||||
query = f"""An expert has stated: "{user_input}". Generate a query for the vector database to get context for that answer in order to come up with a new question. Make sure to follow the instructions you got earlier!""" |
||||
    else:
        # Defensive fallback so the function never returns an unbound name
        query = user_input
    return query
||||
@ -0,0 +1,10 @@ |
||||
from _arango import ArangoDB |
||||
from _chromadb import ChromaDB |
||||
|
||||
arango = ArangoDB(db_name='lasse') |
||||
admin_arango = ArangoDB(db_name='base') |
||||
arango.db.collection('other_documents').truncate() |
||||
|
||||
chroma = ChromaDB() |
||||
chroma.db.delete_collection("other_documents") |
||||
chroma.db.create_collection("other_documents") |
||||
@ -0,0 +1,87 @@ |
||||
import streamlit as st |
||||
import streamlit_authenticator as stauth |
||||
import yaml |
||||
from yaml.loader import SafeLoader |
||||
from streamlit_authenticator import LoginError |
||||
from time import sleep |
||||
|
||||
from colorprinter.print_color import * |
||||
from _arango import ArangoDB |
||||
|
||||
def get_settings(): |
||||
""" |
||||
Function to get the settings from the ArangoDB. |
||||
""" |
||||
arango = ArangoDB(db_name=st.session_state["username"]) |
||||
st.session_state["settings"] = arango.db.collection("settings").get("settings") |
||||
return st.session_state["settings"] |
||||
|
||||
|
||||
st.set_page_config(page_title="Science Assistant Fish", page_icon="🐟") |
||||
|
||||
with open("streamlit_users.yaml") as file: |
||||
config = yaml.load(file, Loader=SafeLoader) |
||||
|
||||
authenticator = stauth.Authenticate( |
||||
config["credentials"], |
||||
config["cookie"]["name"], |
||||
config["cookie"]["key"], |
||||
config["cookie"]["expiry_days"], |
||||
) |
||||
|
||||
try: |
||||
authenticator.login() |
||||
except LoginError as e: |
||||
st.error(e) |
||||
|
||||
if st.session_state["authentication_status"]: |
||||
sleep(0.1) |
||||
    # Retry mechanism for loading the user's settings
    for _ in range(3):
        try:
            get_settings()
            break
        except Exception as e:
            sleep(0.3)
            print_red(e)
            print("Retrying to load settings...")
||||
|
||||
# Retry mechanism for importing pages |
||||
for _ in range(3): |
||||
|
||||
try: |
||||
from streamlit_pages import Article_Collections, Bot_Chat, Projects, Settings |
||||
break |
||||
except ImportError as e: |
||||
# Write the full error traceback |
||||
sleep(0.3) |
||||
print_red(e) |
||||
print("Retrying to import pages...") |
||||
|
||||
get_settings() |
||||
if 'current_page' in st.session_state["settings"]: |
||||
st.session_state["current_page"] = st.session_state["settings"]["current_page"] |
||||
else: |
||||
if 'current_page' not in st.session_state: |
||||
st.session_state["current_page"] = None |
||||
|
||||
if "not_downloaded" not in st.session_state: |
||||
st.session_state["not_downloaded"] = {} |
||||
|
||||
# Pages |
||||
bot_chat = st.Page(Bot_Chat) |
||||
projects = st.Page(Projects) |
||||
article_collections = st.Page(Article_Collections) |
||||
settings = st.Page(Settings) |
||||
|
||||
|
||||
pg = st.navigation([bot_chat, projects, article_collections, settings]) |
||||
pg.run() |
||||
with st.sidebar: |
||||
st.write("---") |
||||
authenticator.logout() |
||||
|
||||
|
||||
elif st.session_state["authentication_status"] is False: |
||||
st.error("Username/password is incorrect") |
||||
elif st.session_state["authentication_status"] is None: |
||||
st.warning("Please enter your username and password") |
||||
@ -0,0 +1,726 @@ |
||||
from datetime import datetime |
||||
import streamlit as st |
||||
from _base_class import BaseClass |
||||
from _llm import LLM |
||||
from prompts import * |
||||
from colorprinter.print_color import * |
||||
from llm_tools import ToolRegistry |
||||
|
||||
|
||||
class Chat(BaseClass): |
||||
def __init__(self, username: str, role: str, **kwargs): |
||||
super().__init__(username=username, **kwargs) |
||||
self.name = kwargs.get("name", None) |
||||
self.chat_history = kwargs.get("chat_history", []) |
||||
self.role = role |
||||
|
||||
def add_message(self, role, content): |
||||
self.chat_history.append( |
||||
{"role": role, "content": content.strip().strip('"'), "role_type": self.role} |
||||
) |
||||
|
||||
def to_dict(self): |
||||
return { |
||||
"name": self.name, |
||||
"chat_history": self.chat_history, |
||||
"role": self.role, |
||||
} |
||||
|
||||
def show_chat_history(self): |
||||
for message in self.chat_history: |
||||
if message["role"] not in ["user", "assistant"]: |
||||
continue |
||||
avatar = self.get_avatar(message) |
||||
with st.chat_message(message["role"], avatar=avatar): |
||||
if message["content"]: |
||||
st.markdown(message["content"].strip('"')) |
||||
|
||||
def get_avatar(self, message: dict = None, role=None) -> str: |
||||
assert message or role, "Either message or role must be provided" |
||||
if message and message.get("role", None) == "user" or role == "user": |
||||
avatar = st.session_state["settings"].get("avatar", "user") |
||||
elif ( |
||||
message and message.get("role", None) == "assistant" or role == "assistant" |
||||
): |
||||
role_type = message.get("role_type", self.role) if message else self.role |
||||
if role_type == "Research Assistant": |
||||
avatar = "img/avatar_researcher.png" |
||||
elif role_type == "Editor": |
||||
avatar = "img/avatar_editor.png" |
||||
elif role_type == "Host": |
||||
avatar = "img/avatar_host.png" |
||||
elif role_type == "Guest": |
||||
avatar = "img/avatar_guest.png" |
||||
else: |
||||
avatar = None |
||||
else: |
||||
avatar = None |
||||
return avatar |
||||
|
||||
def set_name(self, user_input): |
||||
llm = LLM( |
||||
model="small", |
||||
max_length_answer=50, |
||||
temperature=0.4, |
||||
system_message="You are a chatbot who will be chatting with a user", |
||||
) |
||||
prompt = ( |
||||
f'Give a short name to the chat based on this user input: "{user_input}" ' |
||||
"No more than 30 characters. Answer ONLY with the name of the chat." |
||||
) |
||||
name = llm.generate(prompt) |
||||
print_blue("Chat Name", name) |
||||
name = f'{name} - {datetime.now().strftime("%B %d")}' |
||||
# Check if the chat name already exists |
||||
        existing_chat = self.user_arango.db.aql.execute(
            "FOR doc IN chats FILTER doc.name == @name RETURN doc",
            bind_vars={"name": name},  # bind variables avoid quoting/injection issues
            count=True,
        )
||||
if existing_chat.count() > 0: |
||||
name = f'{name} ({datetime.now().strftime("%H:%M")})' |
||||
|
||||
name += f" - [{self.role}]" |
||||
self.name = name |
||||
return name |
||||
|
||||
@classmethod |
||||
    def from_dict(cls, data):
        return cls(
            # __init__ requires a username; fall back to the active session user
            username=data.get("username", st.session_state.get("username")),
            name=data.get("name"),
            chat_history=data.get("chat_history"),
            role=data.get("role", "Research Assistant"),
        )
||||
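    # Round-trip sketch: a chat is persisted via to_dict() and restored with
    # from_dict() (the username value is illustrative):
    #
    #   saved = chat.to_dict()
    #   restored = Chat.from_dict({**saved, "username": "some_user"})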
|
||||
|
||||
class Bot(BaseClass): |
||||
def __init__(self, chat: Chat, username: str, **kwargs): |
||||
super().__init__(username=username, **kwargs) |
||||
self.chat = chat |
||||
self.project = kwargs.get("project", None) |
||||
self.collection: list = kwargs.get( |
||||
"collection", st.session_state["settings"]["current_collection"] |
||||
) |
||||
if not self.collection and self.project: |
||||
self.collection = self.project.collections |
||||
|
||||
if not isinstance(self.collection, list): |
||||
self.collection = [self.collection] |
||||
|
||||
# Load articles in the collections |
||||
self.arango_ids = [] |
||||
for collection in self.collection: |
||||
for _id in self.user_arango.db.aql.execute( |
||||
''' |
||||
FOR doc IN article_collections |
||||
FILTER doc.name == @collection |
||||
FOR article IN doc.articles |
||||
RETURN article._id |
||||
''', |
||||
bind_vars={"collection": collection}, |
||||
): |
||||
self.arango_ids.append(_id) |
||||
|
||||
self.chosen_backend = kwargs.get("chosen_backend", None) |
||||
|
||||
self.chatbot: LLM = None |
||||
self.tools: list[dict] = None |
||||
|
||||
self.chatbot_memory = None |
||||
self.helperbot_memory = None |
||||
|
||||
# Initialize LLM instances |
||||
self.helperbot = LLM( |
||||
temperature=0, |
||||
model="small", |
||||
max_length_answer=500, |
||||
system_message=get_query_builder_system_message(), |
||||
messages=self.helperbot_memory, |
||||
) |
||||
|
||||
self.toolbot = LLM( |
||||
temperature=0, |
||||
system_message="Choose one or many tools to use in order to assist the user. Make sure to read the description of the tools carefully.", |
||||
chat=False, |
||||
model="small", |
||||
) |
||||
|
||||
|
||||
|
||||
# self.sidebar_content() |
||||
|
||||
def sidebar_content(self): |
||||
with st.sidebar: |
||||
st.write("---") |
||||
st.markdown(f'#### {self.chat.name if self.chat.name else ""}') |
||||
st.button("Delete this chat", on_click=self.delete_chat) |
||||
|
||||
def delete_chat(self): |
||||
self.user_arango.db.collection("chats").delete_match( |
||||
filters={"name": self.chat.name} |
||||
) |
||||
        self.chat = Chat(username=self.username, role=self.chat.role)  # Chat requires these arguments
||||
|
||||
def get_chunks( |
||||
self, |
||||
user_input, |
||||
collections=["sci_articles", "other_documents"], |
||||
n_results=7, |
||||
n_sources=4, |
||||
filter=True, |
||||
): |
||||
|
||||
if not isinstance(n_sources, int): |
||||
n_sources = int(n_sources) |
||||
if not isinstance(n_results, int): |
||||
n_results = int(n_results) |
||||
|
||||
# Use self.chat.chat_history if needed |
||||
if self.chat.role == "Editor": |
||||
n_results = 4 |
||||
n_sources = 3 |
||||
|
||||
query = self.helperbot.generate( |
||||
get_generate_vector_query_prompt(user_input, self.chat.role) |
||||
) |
||||
print_rainbow("Vector query:", query) |
||||
|
||||
combined_chunks = [] |
||||
|
||||
for collection in collections: |
||||
where_filter = {"_id": {"$in": self.arango_ids}} if filter else {} |
||||
|
||||
chunks = self.get_chromadb().query( |
||||
query=query, |
||||
collection=collection, |
||||
n_results=n_results, |
||||
n_sources=n_sources, |
||||
where=where_filter, |
||||
max_retries=3, |
||||
) |
||||
|
||||
for doc, meta, dist in zip( |
||||
chunks["documents"][0], |
||||
chunks["metadatas"][0], |
||||
chunks["distances"][0], |
||||
): |
||||
combined_chunks.append( |
||||
{"document": doc, "metadata": meta, "distance": dist} |
||||
) |
||||
|
||||
combined_chunks.sort(key=lambda x: x["distance"]) |
||||
|
||||
sources = set() |
||||
closest_chunks = [] |
||||
for chunk in combined_chunks: |
||||
source_id = chunk["metadata"]["_id"] |
||||
if source_id not in sources: |
||||
sources.add(source_id) |
||||
closest_chunks.append(chunk) |
||||
if len(sources) >= n_sources: |
||||
break |
||||
|
||||
if len(closest_chunks) < n_results: |
||||
remaining_chunks = [ |
||||
chunk for chunk in combined_chunks if chunk not in closest_chunks |
||||
] |
||||
closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)]) |
||||
|
||||
for chunk in closest_chunks: |
||||
_id = chunk["metadata"]["_id"] |
||||
if _id.split("/")[0] == "sci_articles": |
||||
arango_doc = self.get_arango(admin=True).db.document(_id) |
||||
arango_metadata = arango_doc.get("metadata", {}) if arango_doc else {} |
||||
else: |
||||
arango_doc = self.user_arango.db.document(_id) |
||||
arango_metadata = arango_doc.get("metadata", {}) if arango_doc else {} |
||||
|
||||
chunk["metadata"] = arango_metadata or { |
||||
"title": "No title", |
||||
"published_date": "No published date", |
||||
"journal": "No journal", |
||||
} |
||||
|
||||
for k in ["published_date", "journal", "title"]: |
||||
chunk["metadata"].setdefault(k, f"No {k}") |
||||
|
||||
sorted_chunks = sorted( |
||||
closest_chunks, |
||||
key=lambda x: ( |
||||
x["metadata"]["published_date"], |
||||
x["metadata"]["title"], |
||||
), |
||||
) |
||||
|
||||
grouped_chunks = {} |
||||
article_number = 1 |
||||
for chunk in sorted_chunks: |
||||
title = chunk["metadata"]["title"] |
||||
chunk["article_number"] = article_number |
||||
if title not in grouped_chunks: |
||||
grouped_chunks[title] = { |
||||
"article_number": article_number, |
||||
"chunks": [], |
||||
} |
||||
article_number += 1 |
||||
grouped_chunks[title]["chunks"].append(chunk) |
||||
|
        return grouped_chunks

    def process_user_input(self, user_input):
        # Add user's message to chat history
        self.chat.add_message("user", user_input)

        # Generate response with tool support
        prompt = get_tools_prompt(user_input)
        response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
        print_yellow("Tool to use", response)
        # Check if the LLM wants to use a tool
        if isinstance(response, dict) and "tool_calls" in response:
            bot_response = self.answer_tool_call(response, user_input)
        else:
            # Use the LLM's direct response
            bot_response = response.strip('"')
            with st.chat_message(
                "assistant", avatar=self.chat.get_avatar(role="assistant")
            ):
                st.write(bot_response)

        # Add assistant's message to chat history
        if self.chat.chat_history[-1]["role"] != "assistant":
            self.chat.add_message("assistant", bot_response)
        self.chatbot_memory = self.chatbot.messages
        self.helperbot_memory = self.helperbot.messages

        # Save the chat data without heavy objects
        chat_data = self.chat.to_dict()
        self.user_arango.db.collection("chats").update_match(
            filters={"name": self.chat.name},
            body={
                "chat_data": chat_data,
                "chatbot_memory": self.chatbot_memory,
                "helperbot_memory": self.helperbot_memory,
            },
        )

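    # Expected shape of a tool-call response from the LLM backend (a sketch
    # inferred from the handling below; the real schema depends on the LLM
    # wrapper):
    # {"tool_calls": [{"function": {"name": "fetch_science_articles",
    #                               "arguments": {"n_documents": 5}}}]}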
    def answer_tool_call(self, response, user_input):
        bot_responses = []
        tool_calls = response["tool_calls"]
        for tool_call in tool_calls:
            function = tool_call["function"]
            function_name = function["name"]
            arguments = function.get("arguments", {})
            arguments["query"] = user_input

            # Find and execute the tool function
            with st.chat_message(
                "assistant", avatar=self.chat.get_avatar(role="assistant")
            ):
                if function_name in [
                    "fetch_other_documents",
                    "fetch_science_articles",
                    "fetch_science_articles_and_other_documents",
                ]:
                    chunks = getattr(self, function_name)(**arguments)
                    # Provide the tool's output back to the LLM
                    response = self.generate_from_chunks(user_input, chunks)
                    bot_response = st.write_stream(response)
                    bot_response = bot_response.strip('"')

                    if len(chunks) > 0:
                        sources = "###### Sources: \n"
                        for title, group in chunks.items():
                            sources += (
                                f"[{group['article_number']}] **{title}** "
                                f":gray[{group['chunks'][0]['metadata']['journal']} "
                                f"({group['chunks'][0]['metadata']['published_date']})] \n"
                            )
                        st.markdown(sources)
                        bot_response = f"{bot_response}\n\n{sources}"
                    bot_responses.append(bot_response)

                elif function_name == "fetch_notes":
                    notes = getattr(self, function_name)()
                    response = self.generate_from_notes(user_input, notes)
                    bot_response = st.write_stream(response)
                    bot_responses.append(bot_response.strip('"'))

                elif function_name == "conversational_response":
                    response = getattr(self, function_name)(user_input)
                    bot_response = st.write_stream(response)
                    bot_responses.append(bot_response.strip('"'))

        return "\n\n".join(bot_responses)

    def generate_from_notes(self, user_input, notes):
        notes_string = ""
        for note in notes:
            notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"

        prompt = get_chat_prompt(user_input, notes_string, role=self.chat.role)

        with st.spinner("Reading project notes..."):
            return self.chatbot.generate(prompt, stream=True)

    def generate_from_chunks(self, user_input, chunks):
        chunks_string = ""
        for title, group in chunks.items():
            chunks_content_string = "\n(...)\n".join(
                [chunk["document"] for chunk in group["chunks"]]
            )
            chunks_string += (
                f"\n# {title}\n"
                f"## Article number: {group['article_number']}\n"
                f"## {group['chunks'][0]['metadata']['published_date']} in "
                f"{group['chunks'][0]['metadata']['journal']}\n"
                f"{chunks_content_string}\n---\n"
            )

        prompt = get_chat_prompt(user_input, chunks_string, role=self.chat.role)

        magazines = list(
            set(
                f"*{group['chunks'][0]['metadata']['journal']}*"
                for group in chunks.values()
                if "metadata" in group["chunks"][0]
            )
        )
        if len(magazines) == 1:
            s = f"Reading articles from {magazines[0]}..."
        elif len(magazines) > 1:
            s = f"Reading articles from {', '.join(magazines[:-1])} and {magazines[-1]}..."
        else:
            s = "Reading articles..."
        with st.spinner(s):
            return (
                self.chatbot.generate(prompt, stream=True)
                if self.chatbot
                else self.llm.generate(prompt, stream=True)
            )

    def run(self):
        if not hasattr(st.session_state, "bot"):
            st.session_state.bot = self

        # Display chat history
        self.chat.show_chat_history()

        if user_input := st.chat_input("Write your message here..."):
            with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
                st.write(user_input)
            if not self.chat.name:
                self.chat.set_name(user_input)
                # Use bind variables instead of an f-string so the chat name
                # cannot break or inject into the AQL query
                existing_chat = self.user_arango.db.aql.execute(
                    'FOR doc IN chats FILTER doc.name == @name RETURN doc',
                    bind_vars={"name": self.chat.name},
                    count=True,
                )
                if existing_chat.count() == 0:
                    chat_data = self.chat.to_dict()
                    chat_doc = self.user_arango.db.collection("chats").insert(
                        {
                            "name": self.chat.name,
                            "collection": self.collection,
                            "project": self.project.name if self.project else None,
                            "last_updated": datetime.now().isoformat(),
                            "saved": False,
                            "chat_data": chat_data,
                        }
                    )
                    self.chat_key = chat_doc["_key"]
            self.process_user_input(user_input)
            self.update_session_state()

    def get_notes(self):
        notes = self.user_arango.db.aql.execute(
            'FOR doc IN notes FILTER doc.project == @project RETURN doc',
            bind_vars={"project": self.project.name},
        )
        return list(notes)

    # Register tools
    @ToolRegistry.register(
        name="fetch_science_articles",
        description="Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles.",
        parameters={
            "type": "object",
            "properties": {
                "n_documents": {
                    "type": "integer",
                    "description": "How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.",
                }
            },
            "required": ["n_documents"],
        },
    )
    def fetch_science_articles(self, query: str, n_documents: int):
        return self.get_chunks(
            query, collections=["sci_articles"], n_results=n_documents
        )

    @ToolRegistry.register(
        name="fetch_other_documents",
        description="Fetches information from other documents based on the user's query. Other documents can include reports, news articles, and other kinds of texts. Use this tool only when it's obvious that the user is not looking for scientific articles.",
        parameters={
            "type": "object",
            "properties": {
                "n_documents": {
                    "type": "integer",
                    "description": "How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.",
                }
            },
            "required": ["n_documents"],
        },
    )
    def fetch_other_documents(self, query: str, n_documents: int):
        return self.get_chunks(
            query, collections=["other_documents"], n_results=n_documents
        )

    @ToolRegistry.register(
        name="fetch_science_articles_and_other_documents",
        description="Fetches information from both scientific articles and other documents. This is often used when the user hasn't specified what kind of sources they are interested in.",
        parameters={
            "type": "object",
            "properties": {
                "n_documents": {
                    "type": "integer",
                    "description": "How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.",
                }
            },
            "required": ["n_documents"],
        },
    )
    def fetch_science_articles_and_other_documents(self, query: str, n_documents: int):
        return self.get_chunks(
            query,
            collections=["sci_articles", "other_documents"],
            n_results=n_documents,
        )

    @ToolRegistry.register(
        name="fetch_notes",
        description="Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools!",
    )
    def fetch_notes(self):
        return self.get_notes()

    @ToolRegistry.register(
        name="conversational_response",
        description="Generates a conversational response without fetching data. Use this ONLY if it is obvious that the user is not looking for information but only wants to chat.",
    )
    def conversational_response(self, query: str):
        query = f'User message: "{query}". Make your answer short and conversational. Include a very brief description of the project if you think that would be helpful.'
        result = (
            self.chatbot.generate(query, stream=True)
            if self.chatbot
            else self.llm.generate(query, stream=True)
        )
        return result


class EditorBot(Bot):
    def __init__(self, chat: Chat, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.role = "Editor"
        self.tools = ToolRegistry.get_tools(
            tools=[
                "fetch_notes",
                "conversational_response",
                "fetch_science_articles",
                "fetch_other_documents",
                "fetch_science_articles_and_other_documents",
            ]
        )
        self.chatbot = LLM(
            system_message=get_editor_prompt(kwargs.get("project")),
            messages=self.chatbot_memory,
            chosen_backend=kwargs.get("chosen_backend"),
        )


class ResearchAssistantBot(Bot):
    def __init__(self, chat: Chat, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.role = "Research Assistant"
        self.chatbot = LLM(
            system_message=get_assistant_prompt(),
            temperature=0.1,
            messages=self.chatbot_memory,
        )
        self.tools = ToolRegistry.get_tools(
            tools=[
                "fetch_science_articles",
                "fetch_other_documents",
                "fetch_science_articles_and_other_documents",
            ]
        )


class PodBot(Bot):
    """Two LLM agents construct a conversation using material from science articles."""

    def __init__(
        self,
        chat: Chat,
        subject: str,
        username: str,
        instructions: str = None,
        **kwargs,
    ):
        super().__init__(chat=chat, username=username, **kwargs)
        self.subject = subject
        self.instructions = instructions
        self.guest_name = kwargs.get("name_guest", "Merit")
        self.hostbot = HostBot(
            Chat(username=self.username, role="Host"),
            subject,
            username,
            instructions=instructions,
            **kwargs,
        )
        self.guestbot = GuestBot(
            Chat(username=self.username, role="Guest"),
            subject,
            username,
            name_guest=self.guest_name,
            **kwargs,
        )

    def run(self):
        notes = self.get_notes()
        notes_string = ""
        if self.instructions:
            instructions_string = f'''
            These are the instructions for the podcast from the producer:
            """
            {self.instructions}
            """
            '''
        else:
            instructions_string = ""

        for note in notes:
            notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"
        # `a` holds the guest's latest answer; it is seeded with the kickoff
        # prompt so the first host turn has something to react to
        a = f'''You will make a podcast interview with {self.guest_name}, an expert on "{self.subject}".
        {instructions_string}
        Below are notes on the subject that you can use to ask relevant questions:
        """
        {notes_string}
        """
        Say hello to the expert and start the interview. Remember to keep the interview to the subject of {self.subject} throughout the conversation.
        '''

        with st.sidebar:
            stop = st.button("Stop the podcast")
            if stop:
                st.session_state["make_podcast"] = False
        while st.session_state["make_podcast"]:
            self.chat.show_chat_history()

            # Wrap up the podcast once the chat reaches 14 messages
            if len(self.chat.chat_history) >= 14:
                result = self.hostbot.generate(
                    "The interview has ended. Say thank you to the expert and end the conversation."
                )
                self.chat.add_message("Host", result)
                with st.chat_message(
                    "assistant", avatar=self.chat.get_avatar(role="assistant")
                ):
                    st.write(result.strip('"'))
                st.stop()

            _q = self.hostbot.toolbot.generate(
                query=f"{self.guest_name} has answered: {a}. You have to choose a tool to help the host continue the interview.",
                tools=self.hostbot.tools,
                temperature=0.6,
                stream=False,
            )
            if isinstance(_q, dict) and "tool_calls" in _q:
                print_yellow("Tool call response (host)", _q)
                print_purple("HOST", self.hostbot.chat.role)
                q = self.hostbot.answer_tool_call(_q, a)
            else:
                q = _q

            self.chat.add_message("Host", q)

            _a = self.guestbot.toolbot.generate(
                f'The podcast host has asked: "{q}" Choose a tool to help the expert answer with relevant facts and information.',
                tools=self.guestbot.tools,
            )
            if isinstance(_a, dict) and "tool_calls" in _a:
                print_yellow("Tool call response (guest)", _a)
                print_yellow(self.guestbot.chat.role)
                a = self.guestbot.answer_tool_call(_a, q)
            else:
                a = _a
            self.chat.add_message("Guest", a)

        self.update_session_state()


class HostBot(Bot):
    def __init__(self, chat: Chat, subject: str, username: str, instructions: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Host")
        self.tools = ToolRegistry.get_tools(
            tools=[
                "fetch_notes",
                "conversational_response",
                "fetch_other_documents",
            ]
        )
        self.instructions = instructions
        self.llm = LLM(
            system_message=f'''
            You are the host of a podcast and an expert on {subject}. You will ask one question at a time about the subject, and then wait for the answer.
            Don't ask the guest to talk about herself/himself, only about the subject.
            These are the instructions for the podcast from the producer:
            """
            {self.instructions}
            """
            If the expert's answer is complicated, try to make a very brief summary of it for the audience to understand. You can also ask follow-up questions to clarify the answer, or ask for examples.
            '''
        )
        self.toolbot = LLM(
            temperature=0,
            system_message='''
            You are assisting a podcast host in asking questions to an expert.
            Choose one or many tools to use in order to assist the host in asking relevant questions.
            Often "conversational_response" is enough, but sometimes notes are needed or even other documents.
            Make sure to read the description of the tools carefully!''',
            chat=False,
            model="small",
        )

    def generate(self, query):
        return self.llm.generate(query)


class GuestBot(Bot):
    def __init__(self, chat: Chat, subject: str, username: str, **kwargs):
        super().__init__(chat=chat, username=username, **kwargs)
        self.chat.role = kwargs.get("role", "Guest")
        self.tools = ToolRegistry.get_tools(
            tools=[
                "fetch_notes",
                "fetch_science_articles",
            ]
        )
        self.llm = LLM(
            system_message=f"""
            You are {kwargs.get('name_guest', 'Merit')}, an expert on {subject}.
            Today you are a guest in a podcast about {subject}. A host will ask you questions about the subject and you will answer by using scientific facts and information.
            Try to be concise when answering, and remember that the audience of the podcast are not experts on the subject, so don't complicate things too much.
            It's very important that you answer in a "spoken" way, as if you were talking to someone in a conversation. That means you should avoid scientific jargon, complex terms, too many figures and abstract concepts.
            Lists are also not recommended; use "for the first reason", "secondly", and so on instead.
            Use "..." to indicate a pause and "-" to indicate a break in the sentence, as if you were speaking.
            """
        )
        self.toolbot = LLM(
            temperature=0,
            system_message=f"You are an assistant to an expert on {subject}. Choose one or many tools to use in order to assist the expert in answering questions. Make sure to read the description of the tools carefully.",
            chat=False,
            model="small",
        )

    def generate(self, query):
        return self.llm.generate(query)
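

# A minimal usage sketch (hypothetical values; assumes Chat, the LLM backends
# and st.session_state["make_podcast"] are set up elsewhere in the app):
#
#   bot = PodBot(
#       chat=Chat(username="alice", role="Host"),
#       subject="battery electric vehicles",
#       username="alice",
#       instructions="Keep the tone light and the episode short.",
#   )
#   bot.run()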
@ -0,0 +1,44 @@
# streamlit_pages.py

import streamlit as st
from colorprinter.print_color import *
from time import sleep


def Projects():
    """
    Function to handle the Projects page.
    """
    from _classes import ProjectsPage
    if 'Projects' not in st.session_state:
        st.session_state['Projects'] = {}
    projectpage = ProjectsPage(username=st.session_state["username"])
    projectpage.run()


def Bot_Chat():
    """
    Function to handle the Chat Bot page.
    """
    from _classes import BotChatPage
    if 'Bot Chat' not in st.session_state:
        st.session_state['Bot Chat'] = {}
    chatpage = BotChatPage(username=st.session_state["username"])
    chatpage.run()


def Article_Collections():
    """
    Function to handle the Article Collections page.
    """
    from _classes import ArticleCollectionsPage
    if 'Article Collections' not in st.session_state:
        st.session_state['Article Collections'] = {}

    article_collection = ArticleCollectionsPage(username=st.session_state["username"])
    article_collection.run()


def Settings():
    """
    Function to handle the Settings page.
    """
    from _classes import SettingsPage
    settings = SettingsPage(username=st.session_state["username"])
    settings.run()
@ -0,0 +1,31 @@
from TTS.api import TTS
import torch
from datetime import datetime

tts = TTS("tts_models/en/multi-dataset/tortoise-v2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts.to(device)

text = "There is, therefore, an increasing need to understand BEVs from a systems perspective. This involves an in-depth consideration of the environmental impact of the product using life cycle assessment (LCA) as well as taking a broader 'circular economy' approach. On the one hand, LCA is a means of assessing the environmental impact associated with all stages of a product's life from cradle to grave: from raw material extraction and processing to the product's manufacture to its use in everyday life and finally to its end of life."

# Cloning the `test` voice from the local `voices` directory
# with custom inference settings overriding the defaults.
time_now = datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"output/tortoise_{time_now}.wav"
tts.tts_to_file(text,
                file_path=output_path,
                voice_dir="voices",
                speaker="test",
                split_sentences=False,  # Change to True if context is not enough
                num_autoregressive_samples=20,
                diffusion_iterations=50)
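# Quality/speed trade-off: 20 autoregressive samples and 50 diffusion
# iterations sit near Tortoise's faster presets; raising both generally
# improves output quality at the cost of much longer generation time.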

# # Using presets with the same voice
# tts.tts_to_file(text,
#                 file_path="output.wav",
#                 voice_dir="path/to/tortoise/voices/dir/",
#                 speaker="lj",
#                 preset="ultra_fast")

# # Random voice generation
# tts.tts_to_file(text,
#                 file_path="output.wav")
@ -0,0 +1,51 @@
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
from fairseq import utils
import nltk
import torch
from scipy.io.wavfile import write

# Download the required NLTK resource
nltk.download('averaged_perceptron_tagger')

# Model loading
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
    "facebook/fastspeech2-en-ljspeech",
    arg_overrides={"vocoder": "hifigan", "fp16": False}
)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move all models to the correct device
for model in models:
    model.to(device)

# Update configuration and build generator after moving models
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)

# Ensure the vocoder is on the correct device
generator.vocoder.model.to(device)

# Define your text
text = """Hi there, thanks for having me! My interest in electric cars really started back when I was a teenager..."""

# Convert text to model input
sample = TTSHubInterface.get_model_input(task, text)

# Recursively move all tensors in sample to the correct device
sample = utils.move_to_cuda(sample) if torch.cuda.is_available() else sample

# Generate speech
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)

# If wav is a tensor, convert it to a NumPy array
if isinstance(wav, torch.Tensor):
    wav = wav.cpu().numpy()

# Save the audio to a WAV file
write('output_fair.wav', rate, wav)
@ -0,0 +1,45 @@
import torch
from TTS.api import TTS
from datetime import datetime

# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Init TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

text = """Hi there, thanks for having me! My interest in electric cars really started back when I was a teenager. I remember learning about the history of EVs and how they've been around since the late 1800s, even before gasoline cars took over. The fact that these vehicles could run on electricity instead of fossil fuels just fascinated me.

Then, in the 90s, General Motors introduced the EV1 - it was a real game-changer. It showed that electric cars could be practical and enjoyable to drive. And when Tesla came along with their Roadster in 2007, proving that EVs could have a long range, I was hooked.

But what really sealed my interest was learning about the environmental impact of EVs. They produce zero tailpipe emissions, which means they can help reduce air pollution and greenhouse gas emissions. That's something I'm really passionate about.
"""
text_se = """Antalet bilar ger dock bara en del av bilden. För att förstå bilberoendet bör vi framför allt titta på hur mycket bilarna faktiskt används.
Stockholmarnas genomsnittliga körsträcka med bil har minskat sedan millennieskiftet. Den är dock lägre i Göteborg och i Malmö.
I procent har bilanvändningen sedan år 2000 minskat lika mycket i Stockholm och Malmö, 9 procent. I Göteborg är minskningen 13 procent, i riket är minskningen 7 procent."""

# Run TTS
# ❗ Since this is a multilingual voice-cloning model, we must set the target speaker_wav and language
# Text to speech list of amplitude values as output
# wav = tts.tts(text=text, speaker_wav="my/cloning/audio.wav", language="en")
# Text to speech to a file
time_now = datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"output/tts_{time_now}.wav"
tts.tts_to_file(text=text, speaker_wav='voices/test/test_en.wav', language="en", file_path=output_path)


# api = TTS("tts_models/se/fairseq/vits")

# api.tts_with_vc_to_file(
#     text_se,
#     speaker_wav="test_audio_se.wav",
#     file_path="output_se.wav"
# )
@ -0,0 +1,22 @@
import requests

# Define the server URL
server_url = "http://localhost:5002/api/tts"

# Define the payload
payload = {
    "text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
    "speaker": "Ana Florence",
    "language": "en",
    "split_sentences": True
}

# Send the request to the TTS server
response = requests.post(server_url, json=payload)

# Save the response audio to a file
if response.status_code == 200:
    with open("output.wav", "wb") as f:
        f.write(response.content)
else:
    print(f"Error: {response.status_code}")
@ -0,0 +1,33 @@
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import Tortoise
import torch
import os
import torchaudio

# Initialize Tortoise model
config = TortoiseConfig()
model = Tortoise.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="tts_models/en/multi-dataset/tortoise-v2", eval=True)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

# Define the text and voice directory
text = "There is, therefore, an increasing need to understand BEVs from a systems perspective."
voice_dir = "voices"
speaker = "test"

# Load voice samples
voice_samples = []
for file_name in os.listdir(os.path.join(voice_dir, speaker)):
    file_path = os.path.join(voice_dir, speaker, file_name)
    waveform, sample_rate = torchaudio.load(file_path)
    voice_samples.append(waveform)

# Get conditioning latents
conditioning_latents = model.get_conditioning_latents(voice_samples)

# Save conditioning latents to a file
torch.save(conditioning_latents, "conditioning_latents.pth")
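# The saved latents can be loaded back later to skip recomputing them for the
# same voice (a sketch; how they are passed to inference depends on the
# installed TTS version):
#   conditioning_latents = torch.load("conditioning_latents.pth")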
@ -0,0 +1,16 @@
# utils.py
import re


def fix_key(_key: str) -> str:
    """
    Sanitize a given key by replacing all characters that are not alphanumeric,
    underscore, hyphen, dot, at symbol, parentheses, plus, equals, semicolon,
    dollar sign, exclamation mark, asterisk, single quote, percent, or colon
    with an underscore.

    Args:
        _key (str): The key to be sanitized.

    Returns:
        str: The sanitized key with disallowed characters replaced by underscores.
    """
    return re.sub(r"[^A-Za-z0-9_\-\.@()+=;$!*\'%:]", "_", _key)
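
# Example: slashes and spaces are outside the allowed set, so
#   fix_key("doi:10.1000/xyz 123")  ->  "doi:10.1000_xyz_123"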