first commit

main
lasseedfast 1 year ago
commit 01df43bba2
  1. _arango.py (+66)
  2. _base_class.py (+140)
  3. _chromadb.py (+135)
  4. _classes.py (+929)
  5. _llm.py (+253)
  6. _mailersend.py (+79)
  7. _project.py (+0)
  8. article2db.py (+781)
  9. dl_sci.py (+4)
 10. info.py (+192)
 11. llm_server.py (+53)
 12. llm_tools.py (+27)
 13. new_user.py (+97)
 14. pod_bot.py (+0)
 15. prompts.py (+187)
 16. reset_test.py (+10)
 17. streamlit_app.py (+87)
 18. streamlit_chatbot.py (+726)
 19. streamlit_pages.py (+44)
 20. test.py (+14)
 21. test_ tortoise.py (+31)
 22. test_fairseq.py (+51)
 23. test_tts.py (+45)
 24. test_tts_call_server.py (+22)
 25. tts_save_speaker.py (+33)
 26. utils.py (+16)

@@ -0,0 +1,66 @@
import re
import os

from arango import ArangoClient
from dotenv import load_dotenv

import env_manager

load_dotenv()  # Install with `pip install python-dotenv`


class ArangoDB:
    def __init__(self, user=None, password=None, db_name=None):
        """
        Initialize an ArangoDB connection.

        Args:
            user (str): The username for authentication.
            password (str): The password for authentication.
            db_name (str): The name of the database.
        """
        host = os.getenv("ARANGO_HOST")
        if not user:
            user = os.getenv("ARANGO_USER")
        if not password:
            password = os.getenv("ARANGO_PASSWORD")
        if not db_name:
            db_name = os.getenv("ARANGO_DB")
        self.client = ArangoClient(hosts=host)
        self.db = self.client.db(db_name, username=user, password=password)

    def fix_key(self, _key):
        """
        Sanitize a key by replacing every character that is not alphanumeric,
        underscore, hyphen, dot, at sign, parenthesis, plus, equals, semicolon,
        dollar sign, exclamation mark, asterisk, single quote, percent, or colon
        with an underscore.

        Args:
            _key (str): The key to be sanitized.

        Returns:
            str: The sanitized key with disallowed characters replaced by underscores.
        """
        return re.sub(r"[^A-Za-z0-9_\-\.@()+=;$!*\'%:]", "_", _key)
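    # Example (sketch): characters outside the allowed set, such as "/" or "#",
    # become "_", e.g. fix_key("10.1234/abc#1") -> "10.1234_abc_1".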
if __name__ == "__main__":
    arango = ArangoDB(db_name="base")
    articles = arango.db.collection("sci_articles").all()
    for article in articles:
        if "metadata" in article and article["metadata"]:
            if "abstract" in article["metadata"]:
                abstract = article["metadata"]["abstract"]
                if isinstance(abstract, str):
                    # Remove text within <> brackets and the brackets themselves
                    article["metadata"]["abstract"] = re.sub(r"<[^>]*>", "", abstract)
                    arango.db.collection("sci_articles").update_match(
                        filters={"_key": article["_key"]},
                        body={"metadata": article["metadata"]},
                        merge=True,
                    )
                    print(f"Updated abstract for {article['_key']}")

@@ -0,0 +1,140 @@
# _base_class.py
import os
import re

import streamlit as st

from _arango import ArangoDB
from _llm import LLM
from _chromadb import ChromaDB


class BaseClass:
    def __init__(self, username: str, **kwargs) -> None:
        self.username: str = username
        self.project_name: str = kwargs.get("project_name", None)
        self.collection: str = kwargs.get("collection_name", None)
        self.user_arango: ArangoDB = self.get_arango()

    def get_arango(self, admin: bool = False, db_name: str = None) -> ArangoDB:
        if db_name:
            return ArangoDB(db_name=db_name)
        elif admin:
            return ArangoDB()
        else:
            return ArangoDB(db_name=st.session_state["username"])
    def get_settings(self):
        settings = self.user_arango.db.document("settings/settings")
        if not settings:
            # Seed the settings document on first use
            settings = {"_key": "settings", "current_collection": None, "current_page": None}
            self.user_arango.db.collection("settings").insert(settings)
        for i in ["current_collection", "current_page"]:
            if i not in settings:
                settings[i] = None
        st.session_state["settings"] = settings
        return settings
    def update_settings(self, key, value) -> None:
        self.user_arango.db.collection("settings").update_match(
            filters={"_key": "settings"},
            body={key: value},
            merge=True,
        )
        st.session_state["settings"][key] = value

    def get_article_collections(self) -> list:
        article_collections = self.user_arango.db.aql.execute(
            'FOR doc IN article_collections RETURN doc["name"]'
        )
        return list(article_collections)

    def get_projects(self) -> list:
        projects = self.user_arango.db.aql.execute(
            'FOR doc IN projects RETURN doc["name"]'
        )
        return list(projects)

    def choose_collection(self, text="Select a collection of favorite articles") -> str:
        collections = self.get_article_collections()
        collection = st.selectbox(text, collections, index=None)
        if collection:
            self.project = None
            self.collection = collection
            self.update_settings("current_collection", collection)
            self.update_session_state()
        return collection

    def choose_project(self, text="Select a project") -> str:
        projects = self.get_projects()
        project = st.selectbox(text, projects, index=None)
        if project:
            from _classes import Project

            self.project = Project(self.username, project, self.user_arango)
            self.collection = None
            self.update_settings("current_project", self.project.name)
            self.update_session_state()
            return self.project

    def get_chromadb(self):
        return ChromaDB()

    def get_project(self, project_name: str):
        doc = self.user_arango.db.aql.execute(
            f'FOR doc IN projects FILTER doc["name"] == "{project_name}" RETURN doc',
            count=True,
        )
        if doc:
            return doc.next()

    def update_current_page(self, page_name):
        if st.session_state.get("current_page") != page_name:
            st.session_state["current_page"] = page_name
            self.update_settings("current_page", page_name)
    def update_session_state(self, page_name=None):
        """
        Update the Streamlit session state with the attributes of the current instance.

        Parameters:
            page_name (str, optional): The name of the page to update in the session state.
                Defaults to the current page stored in the session state.

        The method iterates over the instance's attributes and stores those of type
        str, int, float, list, dict, or bool in the session state for the given page.
        """
        if not page_name:
            page_name = st.session_state.get("current_page")
        for attr, value in self.__dict__.items():
            if any(isinstance(value, t) for t in [str, int, float, list, dict, bool]):
                st.session_state.setdefault(page_name, {})[attr] = value
    def set_filename(self, filename=None, folder="other_documents"):
        """
        Check whether a file with the given filename already exists in the download
        folder; if it does, append or increment a numerical suffix until the name
        is unique.

        Args:
            filename (str): The base name of the file to check.
            folder (str): Subfolder of the user's data directory to save into.

        Sets:
            self.file_path (str): The unique file path generated.
        """
        self.download_folder = f"user_data/{self.username}/{folder}"
        if getattr(self, "is_sci", False) and folder != "notes":
            self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_")
            return os.path.exists(self.file_path)
        else:
            file_path = f"{self.download_folder}/{filename}"
            while os.path.exists(file_path + ".pdf"):
                if not re.search(r"(_\d+)$", file_path):
                    file_path += "_1"
                else:
                    file_path = re.sub(
                        r"(\d+)$", lambda x: str(int(x.group()) + 1), file_path
                    )
            self.file_path = file_path + ".pdf"
            return file_path
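
Several lookups above interpolate names directly into AQL strings. A parameterized sketch (get_project_safe is a hypothetical variant, using the same bind_vars mechanism python-arango provides and this commit already uses elsewhere):

    def get_project_safe(self, project_name: str):
        # Hypothetical variant of get_project above: bind_vars escapes the
        # value instead of f-string interpolation.
        cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects FILTER doc.name == @name RETURN doc",
            bind_vars={"name": project_name},
            count=True,
        )
        if cursor:
            return cursor.next()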

@@ -0,0 +1,135 @@
import os

import chromadb
from chromadb.config import Settings
from dotenv import load_dotenv

from colorprinter.print_color import *

load_dotenv(".env")


class ChromaDB:
    def __init__(self, local_deployment: bool = False, db="sci_articles", host=None):
        if local_deployment:
            self.db = chromadb.PersistentClient(f"chroma_{db}")
        else:
            if not host:
                host = os.getenv("CHROMA_HOST")
            credentials = os.getenv("CHROMA_CLIENT_AUTH_CREDENTIALS")
            auth_token_transport_header = os.getenv(
                "CHROMA_AUTH_TOKEN_TRANSPORT_HEADER"
            )
            self.db = chromadb.HttpClient(
                host=host,
                settings=Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=credentials,
                    chroma_auth_token_transport_header=auth_token_transport_header,
                ),
            )
    def query(
        self,
        query,
        collection,
        n_results=6,
        n_sources=3,
        max_retries=None,
        where: dict = None,
        **kwargs,
    ):
        if not isinstance(n_sources, int):
            n_sources = int(n_sources)
        if not isinstance(n_results, int):
            n_results = int(n_results)
        if not max_retries:
            max_retries = n_sources
        if n_sources > n_results:
            n_sources = n_results
        col = self.db.get_collection(collection)
        sources = []
        n = 0
        result = {"ids": [[]], "metadatas": [[]], "documents": [[]], "distances": [[]]}
        while True:
            n += 1
            if n > max_retries:
                break
            r = col.query(
                query_texts=query,
                n_results=n_results - len(sources),
                where=where,
                **kwargs,
            )
            if r["ids"][0] == []:
                if result["ids"][0] == []:
                    print_red("No results found in vector database.")
                else:
                    print_red("No more results found in vector database.")
                break
            # Manually extend each list within the lists of lists
            for key in result:
                if key in r:
                    result[key][0].extend(r[key][0])
            # Order results by distance
            combined = sorted(
                zip(
                    result["distances"][0],
                    result["ids"][0],
                    result["metadatas"][0],
                    result["documents"][0],
                ),
                key=lambda x: x[0],
            )
            (
                result["distances"][0],
                result["ids"][0],
                result["metadatas"][0],
                result["documents"][0],
            ) = map(list, zip(*combined))
            sources += list(set([i["_id"] for i in result["metadatas"][0]]))
            if len(sources) >= n_sources:
                break
            elif n != max_retries:
                # Trim the kept results and exclude already-returned parent
                # documents from the next retry
                for k, v in result.items():
                    if k not in r["included"]:
                        continue
                    result[k][0] = v[0][: n_results - (n_sources - len(sources))]
                if where and "_id" in where:
                    where["_id"]["$in"] = [
                        i for i in where["_id"]["$in"] if i not in sources
                    ]
                    if where["_id"]["$in"] == []:
                        break
        return result
if __name__ == "__main__":
    from colorprinter.print_color import *

    chroma = ChromaDB()
    result = chroma.query(
        query="What is Open Science?",
        collection="sci_articles",
        n_results=2,
        n_sources=3,
        max_retries=4,
    )
    print(result)
    exit()
    # One-off migration (unreachable after the exit() above): backfill an
    # "_id" field on every metadata entry in the collection
    chroma_collection = chroma.db.get_collection("sci_articles")
    all_docs = chroma_collection.get()
    ids = all_docs.get("ids", [])
    metadatas = all_docs.get("metadatas", [])
    combined_list = list(zip(ids, metadatas))
    ids = []
    metadatas = []
    for id, metadata in combined_list:
        ids.append(id)
        metadata["_id"] = f"sci_articles/{metadata['_key']}"
        metadatas.append(metadata)
    chroma_collection.update(ids=ids, metadatas=metadatas)
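
The retry loop in ChromaDB.query narrows an optional where filter by dropping already-returned parent _ids, so callers are expected to pass the {"_id": {"$in": [...]}} shape. A usage sketch (the document ids are illustrative):

chroma = ChromaDB()
result = chroma.query(
    query="open science",
    collection="sci_articles",
    n_results=4,
    n_sources=2,
    where={"_id": {"$in": ["sci_articles/10.1234_abc", "sci_articles/10.5678_def"]}},
)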

@@ -0,0 +1,929 @@
# _classes.py
import re
import base64
from time import sleep
from datetime import datetime, timedelta
from io import BytesIO

import pandas as pd
import streamlit as st
from PIL import Image

from colorprinter.print_color import *
from article2db import PDFProcessor
from streamlit_chatbot import Chat, EditorBot, ResearchAssistantBot, PodBot
from info import country_emojis
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import BaseClass
from prompts import get_note_summary_prompt, get_image_system_prompt
class ArticleCollectionsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)
        self.collection = None
        self.page_name = "Article Collections"
        # Initialize attributes from session state if available
        for k, v in st.session_state.get(self.page_name, {}).items():
            setattr(self, k, v)

    def run(self):
        if self.user_arango.db.collection("article_collections").count() == 0:
            self.create_new_collection()
        self.update_current_page(self.page_name)
        self.choose_collection_method()
        self.choose_project_method()
        if self.collection:
            self.display_collection()
            self.sidebar_actions()
        if st.session_state.get("new_collection"):
            self.create_new_collection()
        # Persist state to session_state
        self.update_session_state(page_name=self.page_name)
    def choose_collection_method(self):
        self.collection = self.choose_collection()
        # Persist state after choosing a collection
        self.update_session_state(page_name=self.page_name)

    def choose_project_method(self):
        # If you have a project selection similar to collection, implement here
        pass  # Placeholder for project-related logic

    def choose_collection(self):
        collections = self.get_article_collections()
        current_collection = self.collection
        preselected = (
            collections.index(current_collection)
            if current_collection in collections
            else None
        )
        with st.sidebar:
            collection = st.selectbox(
                "Select a collection of favorite articles",
                collections,
                index=preselected,
            )
        if collection:
            self.collection = collection
            self.update_settings("current_collection", collection)
        return self.collection
    def create_new_collection(self):
        with st.form("create_collection_form", clear_on_submit=True):
            new_collection_name = st.text_input("Enter the name of the new collection")
            submitted = st.form_submit_button("Create Collection")
            if submitted:
                if new_collection_name:
                    self.user_arango.db.collection("article_collections").insert(
                        {"name": new_collection_name, "articles": []}
                    )
                    st.success(f'New collection "{new_collection_name}" created')
                    self.collection = new_collection_name
                    self.update_settings("current_collection", new_collection_name)
                    # Persist state after creating a new collection
                    self.update_session_state(page_name=self.page_name)
                    sleep(1)
                    st.rerun()
    def display_collection(self):
        with st.sidebar:
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Create new collection"):
                    st.session_state["new_collection"] = True
            with col2:
                if st.button(f'Remove collection "{self.collection}"'):
                    self.user_arango.db.collection("article_collections").delete_match(
                        {"name": self.collection}
                    )
                    st.success(f'Collection "{self.collection}" removed')
                    self.collection = None
                    self.update_settings("current_collection", None)
                    # Persist state after removing a collection
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()
        self.show_articles_in_collection()
    def show_articles_in_collection(self):
        collection_articles_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{self.collection}" RETURN doc["articles"]'
        )
        collection_articles = list(collection_articles_cursor)
        if collection_articles and collection_articles[0]:
            articles = collection_articles[0]
            st.markdown(f"#### Articles in *{self.collection}*:")
            for article in articles:
                if article is None:
                    continue
                metadata = article.get("metadata")
                if metadata is None:
                    continue
                title = metadata.get("title", "No Title").strip()
                journal = metadata.get("journal", "No Journal").strip()
                published_year = metadata.get("published_year", "No Year")
                published_date = metadata.get("published_date", None)
                language = metadata.get("language", "No Language")
                icon = country_emojis.get(language.upper(), "") if language else ""
                expander_title = f"**{title}** *{journal}* ({published_year}) {icon}"
                with st.expander(expander_title):
                    if not title == "No Title":
                        st.markdown(f"**Title:** \n{title}")
                    if not journal == "No Journal":
                        st.markdown(f"**Journal:** \n{journal}")
                    if published_date:
                        st.markdown(f"**Published Date:** \n{published_date}")
                    for key, value in article.items():
                        if key in [
                            "_key",
                            "text",
                            "file",
                            "_rev",
                            "chunks",
                            "user_access",
                            "_id",
                            "metadata",
                            "doi",
                        ]:
                            continue
                        if isinstance(value, list):
                            value = ", ".join(value)
                        st.markdown(f"**{key.capitalize()}**: \n{value} ")
                    if "doi" in article:
                        st.markdown(
                            f"**DOI:** \n[{article['doi']}](https://doi.org/{article['doi']}) "
                        )
                    st.button(
                        key=f'delete_{article["_id"]}',
                        label="Delete article from collection",
                        on_click=self.delete_article,
                        args=(self.collection, article["_id"]),
                    )
        else:
            st.write("No articles in this collection.")
    def sidebar_actions(self):
        with st.sidebar:
            st.markdown(f"### Add new articles to {self.collection}")
            with st.form("add_articles_form", clear_on_submit=True):
                pdf_files = st.file_uploader(
                    "Upload PDF file(s)", type=["pdf"], accept_multiple_files=True
                )
                is_sci = st.checkbox("All articles are from scientific journals")
                submitted = st.form_submit_button("Upload")
                if submitted and pdf_files:
                    self.add_articles(pdf_files, is_sci)
                    # Persist state after adding articles
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()
            help_text = 'Paste a text containing DOIs, e.g. the reference section of a paper, and click "Add Articles" to add them to the collection.'
            new_articles = st.text_area(
                "Add articles to this collection", help=help_text
            )
            if st.button("Add Articles"):
                with st.spinner("Processing..."):
                    self.process_dois(
                        article_collection_name=self.collection, text=new_articles
                    )
                # Persist state after processing DOIs
                self.update_session_state(page_name=self.page_name)
                st.rerun()
            self.write_not_downloaded()
    def add_articles(self, pdf_files: list, is_sci: bool) -> None:
        for pdf_file in pdf_files:
            status_container = st.empty()
            with status_container:
                is_sci = is_sci if is_sci else None
                with st.status(f"Processing {pdf_file.name}..."):
                    processor = PDFProcessor(
                        pdf_file=pdf_file,
                        filename=pdf_file.name,
                        process=False,
                        username=st.session_state["username"],
                        document_type="other_documents",
                        is_sci=is_sci,
                    )
                    _id, db, doi = processor.process_document()
                    print_rainbow(_id, db, doi)
                    if doi in st.session_state.get("not_downloaded", {}):
                        st.session_state["not_downloaded"].pop(doi)
                    self.articles2collection(collection=self.collection, db=db, _id=_id)
            st.success("Done!")
            sleep(1.5)
    def articles2collection(self, collection: str, db: str, _id: str = None) -> None:
        info = self.get_article_info(db, _id=_id)
        info = {
            k: v for k, v in info.items() if k in ["_id", "doi", "title", "metadata"]
        }
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = doc.get("articles", [])
            keys = [i["_id"] for i in articles]
            if info["_id"] not in keys:
                articles.append(info)
                self.user_arango.db.collection("article_collections").update_match(
                    filters={"name": collection},
                    body={"articles": articles},
                    merge=True,
                )
        # Persist state after updating articles
        self.update_session_state(page_name=self.page_name)
    def get_article_info(self, db: str, _id: str = None, doi: str = None) -> dict:
        assert _id or doi, "Either _id or doi must be provided."
        arango = self.get_arango(db_name=db)
        if _id:
            query = """
            RETURN {
                "_id": DOCUMENT(@doc_id)._id,
                "doi": DOCUMENT(@doc_id).doi,
                "title": DOCUMENT(@doc_id).title,
                "metadata": DOCUMENT(@doc_id).metadata,
                "summary": DOCUMENT(@doc_id).summary
            }
            """
            info_cursor = arango.db.aql.execute(query, bind_vars={"doc_id": _id})
        elif doi:
            info_cursor = arango.db.aql.execute(
                f'FOR doc IN sci_articles FILTER doc["doi"] == "{doi}" LIMIT 1 RETURN {{"_id": doc["_id"], "doi": doc["doi"], "title": doc["title"], "metadata": doc["metadata"], "summary": doc["summary"]}}'
            )
        return next(info_cursor, None)
    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url
            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        # Persist state after processing DOIs
        self.update_session_state(page_name=self.page_name)
    def write_not_downloaded(self):
        not_downloaded = st.session_state.get("not_downloaded", {})
        if not_downloaded:
            st.markdown(
                "*The articles below were not downloaded. Download them yourself and add them to the collection by dropping them in the area above. Some of them can be downloaded using the link.*"
            )
            for doi, url in not_downloaded.items():
                if url:
                    st.markdown(f"- [{doi}]({url})")
                else:
                    st.markdown(f"- {doi}")
    def delete_article(self, collection, _id):
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = [
                article for article in doc.get("articles", []) if article["_id"] != _id
            ]
            self.user_arango.db.collection("article_collections").update_match(
                filters={"_id": doc["_id"]},
                body={"articles": articles},
            )
        # Persist state after deleting an article
        self.update_session_state(page_name=self.page_name)
class BotChatPage(BaseClass):
    def __init__(self, username):
        super().__init__(username=username)
        self.collection = None
        self.project: Project = None
        self.chat = None
        self.role = "Research Assistant"  # Default persona
        self.page_name = "Bot Chat"
        # Initialize attributes from session state if available
        if self.page_name in st.session_state:
            for k, v in st.session_state[self.page_name].items():
                setattr(self, k, v)

    def run(self):
        bot = None
        self.update_current_page("Bot Chat")
        self.remove_old_unsaved_chats()
        self.sidebar_actions()
        if self.collection or self.project:
            # If no chat exists, create a new Chat instance
            if not self.chat:
                self.chat = Chat(username=self.username, role=self.role)
            self.chat.show_chat_history()
            # Create a Bot instance with the Chat object
            if self.role == "Research Assistant":
                bot = ResearchAssistantBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                )
            elif self.role == "Editor":
                bot = EditorBot(
                    username=self.username,
                    chat=self.chat,
                    collection=self.collection,
                    project=self.project,
                )
            elif self.role == "Podcast":
                st.session_state["make_podcast"] = True
                with st.sidebar:
                    with st.form("make_podcast_form"):
                        instructions = st.text_area(
                            "What should the podcast be about? Give a brief description, as if you were the producer."
                        )
                        start = st.form_submit_button("Make Podcast!")
                    if start:
                        bot = PodBot(
                            subject=self.project.name,
                            username=self.username,
                            chat=self.chat,
                            collection=self.collection,
                            project=self.project,
                            instructions=instructions,
                        )
            # Run the bot (this will display chat history and process user input)
            if bot:
                bot.run()
        # Save updated chat state to session state
        st.session_state[self.page_name] = {
            "collection": self.collection,
            "project": self.project,
            "chat": self.chat,
            "role": self.role,
        }
    def sidebar_actions(self):
        with st.sidebar:
            self.collection = self.choose_collection(
                "Article collection to use for chat:"
            )
            self.project = self.choose_project("Project to use for chat:")
            if self.collection or self.project:
                st.write("---")
            if self.project:
                self.role = st.selectbox(
                    "Choose Bot Role",
                    options=["Research Assistant", "Editor", "Podcast"],
                    index=0,
                )
            elif self.collection:
                self.role = "Research Assistant"
            # Load existing chats from the database
            if self.project:
                chat_history = list(
                    self.user_arango.db.aql.execute(
                        f'FOR doc IN chats FILTER doc["project"] == "{self.project.name}" RETURN {{"_key": doc["_key"], "name": doc["name"]}}'
                    )
                )
            elif self.collection:
                chat_history = list(
                    self.user_arango.db.aql.execute(
                        f'FOR doc IN chats FILTER doc["collection"] == "{self.collection}" RETURN {{"_key": doc["_key"], "name": doc["name"]}}'
                    )
                )
            else:
                chat_history = []
            chats = {i["name"]: i["_key"] for i in chat_history}
            selected_chat = st.selectbox(
                "Continue another chat", options=[""] + list(chats.keys()), index=0
            )
            if selected_chat:
                chat_data = self.user_arango.db.collection("chats").get(
                    chats[selected_chat]
                )
                chat_dict = chat_data.get("chat_data", {})
                self.chat = Chat.from_dict(chat_dict)
            else:
                self.chat = None
    def remove_old_unsaved_chats(self):
        two_weeks_ago = datetime.now() - timedelta(weeks=2)
        old_chats = self.user_arango.db.aql.execute(
            f'FOR doc IN chats FILTER doc.saved == false AND doc.last_updated < "{two_weeks_ago.isoformat()}" RETURN doc'
        )
        for chat in old_chats:
            self.user_arango.db.collection("chats").delete(chat["_key"])
class ProjectsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)
        self.projects = []
        self.selected_project_name = None
        self.project = None
        self.page_name = "Projects"
        # Initialize attributes from session state if available
        page_state = st.session_state.get(self.page_name, {})
        for k, v in page_state.items():
            setattr(self, k, v)

    def run(self):
        self.update_current_page(self.page_name)
        self.load_projects()
        self.display_projects()
        # Update session state
        self.update_session_state(self.page_name)

    def load_projects(self):
        projects_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects RETURN doc", count=True
        )
        self.projects = list(projects_cursor)

    def display_projects(self):
        with st.sidebar:
            self.new_project_button()
            self.selected_project_name = st.selectbox(
                "Select a project to manage",
                options=[proj["name"] for proj in self.projects],
            )
        if self.selected_project_name:
            self.project = Project(
                username=self.username,
                project_name=self.selected_project_name,
                user_arango=self.user_arango,
            )
            self.manage_project()
        # Update session state
        self.update_session_state(self.page_name)

    def new_project_button(self):
        st.session_state.setdefault("new_project", False)
        with st.sidebar:
            if st.button("New project", type="primary"):
                st.session_state["new_project"] = True
            if st.session_state["new_project"]:
                self.create_new_project()
        # Update session state
        self.update_session_state(self.page_name)

    def create_new_project(self):
        new_project_name = st.text_input("Enter the name of the new project")
        new_project_description = st.text_area(
            "Enter the description of the new project"
        )
        if st.button("Create Project"):
            if new_project_name:
                self.user_arango.db.collection("projects").insert(
                    {
                        "name": new_project_name,
                        "description": new_project_description,
                        "collections": [],
                        "notes": [],
                        "note_keys_hash": hash(""),
                        "settings": {},
                    }
                )
                st.success(f'New project "{new_project_name}" created')
                st.session_state["new_project"] = False
                self.update_settings("current_project", new_project_name)
                sleep(1)
                st.rerun()
    def show_project_notes(self):
        with st.expander("Show summarised notes"):
            st.markdown(self.project.notes_summary)
        with st.expander("Show project notes"):
            notes_cursor = self.user_arango.db.aql.execute(
                "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc",
                bind_vars={"note_ids": self.project.notes},
            )
            notes = list(notes_cursor)
            if notes:
                for note in notes:
                    st.markdown(f'_{note.get("timestamp", "")}_')
                    st.markdown(note["text"].replace("\n", " \n"))
                    st.button(
                        key=f'delete_note_{note["_id"]}',
                        label="Delete note",
                        on_click=self.project.delete_note,
                        args=(note["_id"],),
                    )
                    st.write("---")
            else:
                st.write("No notes in this project.")
    def manage_project(self):
        self.update_settings("current_project", self.selected_project_name)
        # Initialize the Project instance
        self.project = Project(
            self.username, self.selected_project_name, self.user_arango
        )
        st.write(f"## {self.project.name}")
        self.show_project_notes()
        self.relate_collections()
        self.sidebar_actions()
        self.project.update_notes_hash()
        if st.button(f"Remove project *{self.project.name}*"):
            self.user_arango.db.collection("projects").delete_match(
                {"name": self.project.name}
            )
            self.update_settings("current_project", None)
            st.success(f'Project "{self.project.name}" removed')
            st.rerun()
        # Update session state
        self.update_session_state(self.page_name)
    def relate_collections(self):
        collections = [
            col["name"]
            for col in self.user_arango.db.collection("article_collections").all()
        ]
        selected_collections = st.multiselect(
            "Relate existing collections", options=collections
        )
        if st.button("Relate Collections"):
            self.project.add_collections(selected_collections)
            st.success("Collections related to the project")
            # Update session state
            self.update_session_state(self.page_name)
        new_collection_name = st.text_input(
            "Enter the name of the new collection to create and relate"
        )
        if st.button("Create and Relate Collection"):
            if new_collection_name:
                self.user_arango.db.collection("article_collections").insert(
                    {"name": new_collection_name, "articles": []}
                )
                self.project.add_collection(new_collection_name)
                st.success(
                    f'New collection "{new_collection_name}" created and related to the project'
                )
                # Update session state
                self.update_session_state(self.page_name)
    def sidebar_actions(self):
        self.sidebar_notes()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_notes(self):
        with st.sidebar:
            st.markdown(f"### Add new notes to {self.project.name}")
            self.upload_notes_form()
            self.add_text_note()
            self.add_wikipedia_data()
        # Update session state
        self.update_session_state(self.page_name)

    def upload_notes_form(self):
        with st.form("add_notes", clear_on_submit=True):
            files = st.file_uploader(
                "Upload PDF or image",
                type=["png", "jpg", "pdf"],
                accept_multiple_files=True,
            )
            submitted = st.form_submit_button("Upload")
            if submitted:
                self.project.process_uploaded_files(files)
                # Update session state
                self.update_session_state(self.page_name)

    def add_text_note(self):
        help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
        note_text = st.text_area("Write or paste anything.", help=help_text)
        if st.button("Add Note"):
            self.project.add_note(
                {
                    "text": note_text,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                }
            )
            st.success("Note added to the project")
            # Update session state
            self.update_session_state(self.page_name)

    def add_wikipedia_data(self):
        wiki_url = st.text_input(
            "Paste the address to a Wikipedia page to add its summary as a note",
            placeholder="Paste Wikipedia URL",
        )
        if st.button("Add Wikipedia data"):
            with st.spinner("Fetching Wikipedia data..."):
                wiki_data = self.project.get_wikipedia_data(wiki_url)
            if wiki_data:
                self.project.process_wikipedia_data(wiki_data, wiki_url)
                st.success("Wikipedia data added to notes")
                # Update session state
                self.update_session_state(self.page_name)
                st.rerun()
class Project(BaseClass):
    def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
        super().__init__(username=username)
        self.name = project_name
        self.user_arango = user_arango
        self.description = ""
        self.collections = []
        self.notes = []
        self.note_keys_hash = 0
        self.settings = {}
        self.notes_summary = ""
        # Initialize attributes from the Arango document if available
        self.load_project()

    def load_project(self):
        print_blue("Project name:", self.name)
        project_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects FILTER doc.name == @name RETURN doc",
            bind_vars={"name": self.name},
        )
        project = next(project_cursor, None)
        if not project:
            raise ValueError(f"Project '{self.name}' not found.")
        self._key = project["_key"]
        self.name = project.get("name", "")
        self.description = project.get("description", "")
        self.collections = project.get("collections", [])
        self.notes = project.get("notes", [])
        self.note_keys_hash = project.get("note_keys_hash", 0)
        self.settings = project.get("settings", {})
        self.notes_summary = project.get("notes_summary", "")
    def update_project(self):
        updated_doc = {
            "_key": self._key,
            "name": self.name,
            "description": self.description,
            "collections": self.collections,
            "notes": self.notes,
            "note_keys_hash": self.note_keys_hash,
            "settings": self.settings,
            "notes_summary": self.notes_summary,
        }
        self.user_arango.db.collection("projects").update(updated_doc, check_rev=False)
        self.update_session_state()

    def add_collections(self, collections):
        self.collections.extend(collections)
        self.update_project()

    def add_collection(self, collection_name):
        self.collections.append(collection_name)
        self.update_project()

    def add_note(self, note: dict):
        assert note["text"], "Note text cannot be empty"
        note["text"] = note["text"].strip().strip("\n")
        if "timestamp" not in note:
            note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M")
        note_doc = self.user_arango.db.collection("notes").insert(note)
        if note_doc["_id"] not in self.notes:
            self.notes.append(note_doc["_id"])
        self.update_project()

    def delete_note(self, note_id):
        if note_id in self.notes:
            self.notes.remove(note_id)
            self.update_project()

    def update_notes_hash(self):
        current_hash = self.make_project_notes_hash()
        if current_hash != self.note_keys_hash:
            self.note_keys_hash = current_hash
            with st.spinner("Summarizing notes for chatbot..."):
                self.create_notes_summary()
            self.update_project()

    def make_project_notes_hash(self):
        if not self.notes:
            return hash("")
        note_keys_str = "".join(self.notes)
        return hash(note_keys_str)
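    # Note (sketch of an alternative): Python's built-in hash() of a string is
    # randomized per process (PYTHONHASHSEED), so a persisted note_keys_hash
    # will not match across restarts and re-triggers summarization. A stable
    # variant would be, e.g.,
    # int(hashlib.sha256(note_keys_str.encode()).hexdigest(), 16).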
    def create_notes_summary(self):
        notes_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text",
            bind_vars={"note_ids": self.notes},
        )
        notes = list(notes_cursor)
        notes_string = "\n---\n".join(notes)
        llm = LLM(model="small")
        query = get_note_summary_prompt(self, notes_string)
        summary = llm.generate(query)
        print_purple("New summary of notes:", summary)
        self.notes_summary = summary
        self.update_session_state()
    def analyze_image(self, image_base64, text=None):
        project_data = {"name": self.name}
        llm = LLM(system_message=get_image_system_prompt(project_data))
        prompt = (
            f'Analyze the image. The text found in it reads: "{text}"'
            if text
            else "Analyze the image."
        )
        description = llm.generate(query=prompt, images=[image_base64], stream=False)
        print_green("Image description:", description)
        return description
    def process_uploaded_files(self, files):
        with st.spinner("Processing files..."):
            for file in files:
                st.write("Processing...")
                filename = fix_key(file.name)
                image_file = self.file2img(file)
                pdf_file = self.convert_image_to_pdf(image_file)
                pdf = PDFProcessor(
                    pdf_file=pdf_file,
                    is_sci=False,
                    document_type="notes",
                    is_image=True,
                    process=False,
                )
                text = pdf.process_pdf()
                # file2img consumed the buffer above, so rewind before re-reading
                file.seek(0)
                base64_str = base64.b64encode(file.read()).decode("utf-8")
                image_caption = self.analyze_image(base64_str, text=text)
                self.add_note(
                    {
                        "_id": f"notes/{filename}",
                        "text": f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}",
                    }
                )
            st.success("Done!")
            sleep(1.5)
        self.update_session_state()
        st.rerun()
    def file2img(self, file):
        img_bytes = file.read()
        if not img_bytes:
            raise ValueError("Uploaded file is empty.")
        return Image.open(BytesIO(img_bytes))

    def convert_image_to_pdf(self, img):
        import pytesseract

        pdf_bytes = pytesseract.image_to_pdf_or_hocr(img)
        pdf_file = BytesIO(pdf_bytes)
        pdf_file.name = (
            "converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf"
        )
        return pdf_file
    def get_wikipedia_data(self, page_url: str) -> dict:
        import wikipedia
        from urllib.parse import urlparse

        parsed_url = urlparse(page_url)
        page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path)
        if page_name_match:
            page_name = page_name_match.group(0)
        else:
            st.warning("Invalid Wikipedia URL")
            return None
        try:
            page = wikipedia.page(page_name)
            data = {
                "title": page.title,
                "summary": page.summary,
                "content": page.content,
                "url": page.url,
                "references": page.references,
            }
            return data
        except Exception as e:
            st.error(f"Error fetching Wikipedia data: {e}")
            return None
    def process_wikipedia_data(self, wiki_data, wiki_url):
        llm = LLM(
            system_message="You are an assistant summarising Wikipedia data. Answer ONLY with the summary, nothing else!",
            model="small",
        )
        if wiki_data.get("summary"):
            query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.'''
            summary = llm.generate(query)
            wiki_data["text"] = (
                f"(_Summarised using AI, read the original [here]({wiki_url})_)\n{summary}"
            )
        wiki_data.pop("summary", None)
        wiki_data.pop("content", None)
        self.user_arango.db.collection("notes").insert(
            wiki_data, overwrite=True, silent=True
        )
        self.add_note(wiki_data)
        processor = PDFProcessor(process=False)
        dois = [
            processor.extract_doi(ref)
            for ref in wiki_data.get("references", [])
            if processor.extract_doi(ref)
        ]
        if dois:
            current_collection = st.session_state["settings"].get("current_collection")
            st.markdown(
                f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
            )
            if st.button("Add DOIs"):
                self.process_dois(current_collection, dois=dois)
        self.update_session_state()
    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url
            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        self.update_session_state()
class SettingsPage(BaseClass):
    def __init__(self, username: str):
        super().__init__(username=username)

    def run(self):
        self.update_current_page("Settings")
        self.set_profile_picture()

    def set_profile_picture(self):
        st.markdown("Profile picture")
        profile_picture = st.file_uploader(
            "Upload profile picture", type=["png", "jpg", "jpeg"]
        )
        if profile_picture:
            # Resize the image to fit within 64x64 pixels
            from PIL import Image

            img = Image.open(profile_picture)
            img.thumbnail((64, 64))
            img_path = f"user_data/{st.session_state['username']}/profile_picture.png"
            img.save(img_path)
            self.update_settings("avatar", img_path)
            st.success("Profile picture uploaded")
            sleep(1)
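
A minimal end-to-end sketch of the Project class above (assumes a projects document named "My project" already exists in the user's database; the names are illustrative):

from _arango import ArangoDB
from _classes import Project

arango = ArangoDB(db_name="alice")  # per-user database
project = Project("alice", "My project", arango)
project.add_note({"text": "Key takeaway from today's interview ..."})
print(project.notes_summary)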

@@ -0,0 +1,253 @@
import os
import re
import json
from typing import Literal, Optional

import requests
from requests.auth import HTTPBasicAuth
import tiktoken

from colorprinter.print_color import *
import env_manager

env_manager.set_env()
tokenizer = tiktoken.get_encoding("cl100k_base")
class LLM:
    def __init__(
        self,
        system_message="You are an assistant.",
        num_ctx=8192,
        temperature=0.01,
        model: Optional[Literal["small", "standard", "vision"]] = "standard",
        max_length_answer=4096,
        messages=None,
        chat=True,
        chosen_backend=None,
    ) -> None:
        """
        Initialize the assistant with the specified parameters.

        Args:
            system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
            num_ctx (int): The number of context tokens to use. Defaults to 8192.
            temperature (float): The temperature setting for the model's response generation. Defaults to 0.01.
            model (str): The model type to use. Defaults to "standard". Alternatives: 'small', 'standard', 'vision'.
            max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
            messages (list): An existing message history to continue from. Defaults to None.
            chat (bool): Flag to indicate if the assistant is in chat mode. Defaults to True.
            chosen_backend (str): The chosen ollama server for the request. Defaults to None.

        Returns:
            None
        """
        self.model = self.get_model(model)
        self.system_message = system_message
        self.options = {"temperature": temperature, "num_ctx": num_ctx}
        self.messages = messages or [{"role": "system", "content": self.system_message}]
        self.max_length_answer = max_length_answer
        self.chat = chat
        self.chosen_backend = chosen_backend
    def get_model(self, model_alias):
        models = {
            "standard": "LLM_MODEL",
            "small": "LLM_MODEL_SMALL",
            "vision": "LLM_MODEL_VISION",
            "standard_64k": "LLM_MODEL_64K",
        }
        return os.getenv(models.get(model_alias, "LLM_MODEL"))

    def count_tokens(self):
        num_tokens = 0
        for i in self.messages:
            for k, v in i.items():
                if k == "content":
                    if not isinstance(v, str):
                        v = str(v)
                    tokens = tokenizer.encode(v)
                    num_tokens += len(tokens)
        return int(num_tokens)
    def read_stream(self, response):
        """
        Read a stream of data from the given response object and yield the content of each message.

        Args:
            response (requests.Response): The response object to read the stream from.

        Yields:
            str: The content of each message in the stream.

        Notes:
            - The response is expected to provide data in chunks, which are decoded as UTF-8.
            - Lines are split by newline characters.
            - Each line is expected to be a JSON object containing a "message" key with a "content" field.
            - If a chunk cannot be decoded as UTF-8, it is skipped.
            - If a line cannot be parsed as JSON, it is skipped.
        """
        buffer = ""
        message = ""
        for chunk in response.iter_content(chunk_size=64):
            if chunk:
                try:
                    message_part = chunk.decode("utf-8")
                    buffer += message_part
                    message += message_part
                except UnicodeDecodeError:
                    continue
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    if line:
                        try:
                            json_data = json.loads(line)
                            yield json_data["message"]["content"]
                        except json.JSONDecodeError:
                            continue
        self.messages.append({"role": "assistant", "content": message.strip('"')})
    def generate(
        self,
        query,
        stream=False,
        tools=None,
        function_call=None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature=None,
    ):
        """
        Generate a response from the language model based on the provided query and options.

        Args:
            query (str): The input query to be processed by the language model.
            stream (bool, optional): Whether to stream the response. Defaults to False.
            tools (list, optional): A list of tools to be used by the language model. Defaults to None.
            function_call (dict, optional): A dictionary specifying a function call to be made by the language model. Defaults to None.
            images (list, optional): A list of image paths or base64-encoded images to be included in the request. Defaults to None.
            model (str, optional): The model alias to be used for generating the response. Defaults to None. Alternatives: 'small', 'standard', 'vision'.
            temperature (float, optional): Overrides the instance temperature for this request. Defaults to None.

        Returns:
            str: The generated response from the language model. If streaming is enabled, returns the streamed response.
        """
        model = self.get_model(model) if model else self.model
        temperature = temperature if temperature else self.options["temperature"]
        # Normalize whitespace and add the query to the messages
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}
        headers = {"Content-Type": "application/json"}
        if images:
            import base64

            # Images always go to the vision model
            model = self.get_model("vision")
            base64_images = []
            base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
            # Accept both base64 strings and paths to image files
            for image in images:
                if isinstance(image, str):
                    if base64_pattern.match(image):
                        base64_images.append(image)
                    else:
                        with open(image, "rb") as image_file:
                            base64_images.append(
                                base64.b64encode(image_file.read()).decode("utf-8")
                            )
            message["images"] = base64_images
            # Flag the presence of images in the Content-Type header
            headers = {"Content-Type": "application/json; images"}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend
        self.messages.append(message)
        # Size the context from the message token count plus half the maximum answer length
        if self.chat or len(self.messages) > 15000:
            num_tokens = self.count_tokens() + self.max_length_answer / 2
            if num_tokens < 8000 and "num_ctx" in self.options:
                del self.options["num_ctx"]
            else:
                # Switch to the 64k-context model for long conversations
                model = self.get_model("standard_64k")
                headers["X-Model-Type"] = "standard_64k"
        if tools:
            stream = False
        data = {
            "messages": self.messages,
            "stream": stream,
            "keep_alive": 3600 * 24 * 7,
            "model": model if model else self.model,
            "options": self.options,
        }
        # Include tools if provided
        if tools:
            data["tools"] = tools
        # Include function_call if provided
        if function_call:
            data["function_call"] = function_call
        response = requests.post(
            os.getenv("LLM_API_URL"),
            headers=headers,
            json=data,
            auth=HTTPBasicAuth(
                os.getenv("LLM_API_USER"), os.getenv("LLM_API_PWD_LASSE")
            ),
            stream=stream,
            timeout=3600,
        )
        self.chosen_backend = response.headers.get("X-Chosen-Backend")
        if response.status_code != 200:
            print_red("Error!")
            print_rainbow(response.content.decode("utf-8"))
            if response.status_code == 404:
                return "Target endpoint not found"
            if response.status_code == 504:
                return f"Gateway Timeout: {response.content.decode('utf-8')}"
        if stream:
            return self.read_stream(response)
        else:
            try:
                response_json = response.json()
                if tools and not response_json["message"].get("tool_calls"):
                    print_yellow("No tool calls in response".upper())
                if "tool_calls" in response_json.get("message", {}):
                    # The LLM wants to invoke a function (tool)
                    result = response_json["message"]
                    self.messages.append(response_json["message"])
                else:
                    result = response_json["message"]["content"].strip('"')
                    self.messages.append({"role": "assistant", "content": result})
            except requests.exceptions.JSONDecodeError:
                print_red("Error: ", response.status_code, response.text)
                return "An error occurred."
        if not self.chat:
            self.messages = [self.messages[0]]
        return result
if __name__ == "__main__":
    llm = LLM()
    images = ["th-2182728540.jpeg"]
    print(
        llm.generate(
            query="Hi there",
        )
    )
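
Because read_stream() is a generator, generate(stream=True) can be consumed incrementally. A usage sketch (assumes the LLM_API_* environment variables are configured):

llm = LLM(system_message="You are a terse assistant.")
for part in llm.generate(query="Say hello.", stream=True):
    print(part, end="", flush=True)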

@@ -0,0 +1,79 @@
import base64
import os
from time import sleep

from mailersend import emails
import dotenv


class MailSender:
    def __init__(self):
        dotenv.load_dotenv()
        self.mailersend_api_key = os.getenv("MAILERSEND_API_KEY")
        self.mailer = emails.NewEmail(mailersend_api_key=self.mailersend_api_key)
        self.mail_body = {}
        self.set_mail_to(recipients=input("Enter recipient email: "))
        self.set_mail_from()
        self.set_reply_to("No Reply", "noreply@assistant.fish")
        self.set_subject(input("Enter subject: "))
        plaintext_content = input("Enter plaintext content: ")
        self.set_plaintext_content(plaintext_content)
        html_content = input("Enter HTML content: ")
        if html_content == "":
            html_content = plaintext_content
        self.set_html_content(html_content)
        attachment_path = input("Path to image: ")
        if attachment_path != "":
            self.set_attachments(attachment_path)
        # TODO: Add support for multiple attachments and other types of attachments
        self.send_mail()

    def set_mail_from(self):
        mail_from = {"name": "SCI Fish", "email": "sci@assistant.fish"}
        self.mailer.set_mail_from(mail_from, self.mail_body)

    def set_mail_to(self, recipients):
        if isinstance(recipients, str):
            recipients = [{"name": recipients, "email": recipients}]
        elif isinstance(recipients, list):
            recipients = [{"name": i, "email": i} for i in recipients]
        self.mailer.set_mail_to(recipients, self.mail_body)

    def set_subject(self, subject):
        self.mailer.set_subject(subject, self.mail_body)

    def set_html_content(self, html_content):
        self.mailer.set_html_content(html_content, self.mail_body)

    def set_plaintext_content(self, plaintext_content):
        self.mailer.set_plaintext_content(plaintext_content, self.mail_body)

    def set_reply_to(self, name, email):
        reply_to = {"name": name, "email": email}
        self.mailer.set_reply_to(reply_to, self.mail_body)

    def set_attachments(self, file_path):
        with open(file_path, "rb") as attachment:
            att_read = attachment.read()
            att_base64 = base64.b64encode(bytes(att_read))
        attachments = [
            {
                "id": os.path.basename(file_path),
                "filename": os.path.basename(file_path),
                "content": f"{att_base64.decode('ascii')}",
                "disposition": "attachment",
            }
        ]
        self.mailer.set_attachments(attachments, self.mail_body)

    def send_mail(self):
        r = self.mailer.send(self.mail_body)
        sleep(4)  # wait for the email to be sent
        if r.split("\n")[0].strip() != "202":
            print("Error sending email")
        else:
            print("Email sent successfully")


# Example usage
if __name__ == "__main__":
    mail_sender = MailSender()

@@ -0,0 +1,781 @@
import io
import os
import re
from time import sleep
from datetime import datetime
import xml.etree.ElementTree as ET

import crossref_commons.retrieval as crossref
import pymupdf
import pymupdf4llm
import requests
from bs4 import BeautifulSoup
from semantic_text_splitter import MarkdownSplitter
from pyppeteer import launch
from arango.collection import StandardCollection as ArangoCollection
from arango.database import StandardDatabase as ArangoDatabase
from streamlit.runtime.uploaded_file_manager import UploadedFile
import streamlit as st

from _arango import ArangoDB
from _chromadb import ChromaDB
from _llm import LLM
from colorprinter.print_color import *
from utils import fix_key


class Document:
    def __init__(
        self,
        pdf_file=None,
        filename=None,
        doi=None,
        username=None,
        is_sci=None,
        is_image=False,
    ):
        self.filename = filename
        self.pdf_file = pdf_file
        self.doi = doi
        self.username = username
        self.is_sci = is_sci
        self.is_image = is_image
        self.pdf = None
        self._key = None
        self._id = None
        self.chunks = []
        self.metadata = None
        self.title = None
        self.open_access = False
        self.file_path = None
        self.download_folder = None
        self.text = ""
        self.arango_db_name = None
        self.document_type = None
        if self.pdf_file:
            self.open_pdf(self.pdf_file)
    def make_summary_in_background(self):
        data = {
            "text": self.text,
            "arango_db_name": self.arango_db_name,
            "_id": self._id,
            "is_sci": self.is_sci,
        }
        # Send the data to the FastAPI server
        url = "http://192.168.1.11:8100/summarise_document"
        requests.post(url, json=data)

    def open_pdf(self, pdf_file):
        st.write("Reading the file...")
        if isinstance(pdf_file, bytes):
            pdf_file = io.BytesIO(pdf_file)
        if isinstance(pdf_file, str):
            self.pdf: pymupdf.Document = pymupdf.open(pdf_file)
        elif isinstance(pdf_file, io.BytesIO):
            try:
                self.pdf: pymupdf.Document = pymupdf.open(stream=pdf_file, filetype="pdf")
            except Exception:
                pdf_bytes = pdf_file.read()
                pdf_stream = io.BytesIO(pdf_bytes)
                self.pdf: pymupdf.Document = pymupdf.open(stream=pdf_stream, filetype="pdf")
    def extract_text(self):
        md_pages = pymupdf4llm.to_markdown(
            self.pdf, page_chunks=True, show_progress=False
        )
        md_text = ""
        for page in md_pages:
            md_text += f"{page['text'].strip()}\n@{page['metadata']['page']}@\n"
        md_text = re.sub(r"[-]{3,}", "", md_text)
        md_text = re.sub(r"\n{3,}", "\n\n", md_text)
        md_text = re.sub(r"\s{2,}", " ", md_text)
        md_text = re.sub(r"\s*\n\s*", "\n", md_text)
        self.text = md_text
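    # The "@<page>@" markers appended above are parsed later (chunks2chroma and
    # chunks2arango) to recover page numbers, then stripped from the stored text.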
    def make_chunks(self, len_chunks=2200):
        better_chunks = []
        ts = MarkdownSplitter(len_chunks)
        chunks = ts.chunks(self.text)
        for chunk in chunks:
            # Drop tiny fragments; merge short chunks into the previous one
            if len(chunk) < 40 and len(chunks) > 1:
                continue
            elif all(
                [
                    len(chunk) < int(len_chunks / 3),
                    len(chunks[-1]) < int(len_chunks * 1.5),
                    len(better_chunks) > 0,
                ]
            ):
                better_chunks[-1] += chunk
            else:
                better_chunks.append(chunk.strip())
        self.chunks = better_chunks
    def get_title(self, only_meta=False):
        """
        Extract the title from the PDF metadata, or generate a title based on the filename.

        Args:
            only_meta (bool): If True, only attempt to retrieve the title from metadata.
                If False, generate a title from the filename if metadata is not available.

        Returns:
            str: The title of the PDF if found in metadata or generated from the filename.
                Returns None if only_meta is True and no title is found in metadata.

        Raises:
            AssertionError: If only_meta is False and no PDF file is provided to generate a title.
        """
        xml_metadata = self.pdf.get_xml_metadata()
        if not xml_metadata.strip():
            return None
        try:
            root = ET.fromstring(xml_metadata)
        except ET.ParseError:
            return None
        namespaces = {}
        for elem in root.iter():
            if elem.tag.startswith("{"):
                uri, tag = elem.tag[1:].split("}")
                prefix = uri.split("/")[-1]
                namespaces[prefix] = uri
        namespaces["rdf"] = "http://www.w3.org/1999/02/22-rdf-syntax-ns#"
        namespaces["dc"] = "http://purl.org/dc/elements/1.1/"
        title_element = root.find(
            ".//rdf:Description/dc:title/rdf:Alt/rdf:li", namespaces
        )
        if title_element is not None:
            self.title = title_element.text
            return title_element.text
        else:
            if only_meta:
                return None
            else:
                assert (
                    self.pdf_file
                ), "PDF file must be provided to generate a title if no title is in the metadata."
                try:
                    filename = self.pdf_file.split("/")[-1].replace(".pdf", "")
                except AttributeError:
                    filename = self.pdf_file.name.split("/")[-1].replace(".pdf", "")
                self.title = f"{filename}_{datetime.now().strftime('%Y%m%d%H%M%S')}"
                return self.title
    def save_pdf(self, document_type):
        assert (
            self.is_sci or self.username
        ), "To save a PDF, a username must be provided for non-sci articles."
        if self.is_sci:
            download_folder = "sci_articles"
        else:
            download_folder = f"user_data/{self.username}/{document_type}"
        if not os.path.exists(download_folder):
            os.makedirs(download_folder)
        self.download_folder = download_folder
        if self.doi and not document_type == "notes":
            self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_")
            if not os.path.exists(self.file_path):
                self.file_path = f"{self.download_folder}/{fix_key(self.doi)}.pdf"
                self.pdf.save(self.file_path)
        else:
            self.file_path = self.set_filename(self.get_title())
            if not self.file_path:
                try:
                    self.file_path = self.pdf_file.name
                except AttributeError:
                    self.file_path = self.pdf_file.split("/")[-1]
            self.pdf.save(self.file_path)
        return self.file_path
    def set_filename(self, filename=None):
        if self.is_sci and not self.document_type == "notes":
            self.file_path = f"sci_articles/{self.doi}.pdf".replace("/", "_")
            return os.path.exists(self.file_path)
        else:
            file_path = f"{self.download_folder}/{filename}"
            while os.path.exists(file_path + ".pdf"):
                if not re.search(r"(_\d+)$", file_path):
                    file_path += "_1"
                else:
                    file_path = re.sub(
                        r"(\d+)$", lambda x: str(int(x.group()) + 1), file_path
                    )
            self.file_path = file_path + ".pdf"
            return file_path
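    # Example (sketch): with "report.pdf" and "report_1.pdf" already on disk,
    # set_filename("report") stores ".../report_2.pdf" in self.file_path.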
class Processor:
    def __init__(
        self,
        document: Document,
        filename: str = None,
        chroma_db: str = "sci_articles",
        len_chunks: int = 2200,
        local_chroma_deployment: bool = False,
        process: bool = True,
        document_type: str = None,
    ):
        self.document = document
        self.chromadb = ChromaDB(local_deployment=local_chroma_deployment, db=chroma_db)
        self.len_chunks = len_chunks
        self.document_type = document_type
        self._id = None
        if process:
            self.process_document()

    def get_arango(self, db_name=None, document_type=None):
        if db_name and document_type:
            arango = ArangoDB(db_name=db_name)
            arango_collection = arango.db.collection(document_type)
        elif self.document.is_sci:
            arango = ArangoDB(db_name="base")
            arango_collection = arango.db.collection("sci_articles")
        elif self.document.open_access:
            arango = ArangoDB(db_name="base")
            arango_collection = arango.db.collection("other_documents")
        else:
            arango = ArangoDB(db_name=self.document.username)
            arango_collection: ArangoCollection = arango.db.collection(
                self.document_type
            )
        self.document.arango_db_name = arango.db.name
        self.arango_collection = arango_collection
        return arango_collection
    def extract_doi(self, text, multi=False):
        doi_pattern = r"10\.\d{4,9}/[-._;()/:A-Za-z0-9]+"
        if multi:
            dois = re.findall(doi_pattern, text)
            processed_dois = [doi.strip(".").replace(".pdf", "") for doi in dois]
            return processed_dois if processed_dois else None
        else:
            doi = re.search(doi_pattern, text)
            if doi:
                doi = doi.group()
                doi = doi.strip(".").replace(".pdf", "")
                metadata = self.get_crossref(doi)
                if metadata:
                    self.document.metadata = metadata
                    self.document.doi = doi
            else:
                for page in self.document.pdf.pages(0, 6):
                    text = page.get_text()
                    if re.search(doi_pattern, text):
                        llm = LLM(
                            temperature=0.01,
                            system_message='You are an assistant helping a user to extract the DOI from a scientific article. \
                                A DOI always starts with "10." and is followed by a series of numbers and letters, with a "/" in the middle. \
                                Sometimes the DOI is split by a line break, so be sure to check for that.',
                            max_length_answer=50,
                        )
                        prompt = f'''
                        This is the text of an article:
                        """
                        {text}
                        """
                        I want you to find the DOI of the article. Answer ONLY with the DOI, nothing else.
                        If you can't find the DOI, answer "not_found".
                        '''
                        st.write("Trying to extract DOI from text using LLM...")
                        doi = llm.generate(prompt).replace("https://doi.org/", "")
                        if doi == "not_found":
                            return None
                        else:
                            match = re.search(doi_pattern, doi)
                            doi = match.group() if match else None
                            break
            return doi
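    # Example (sketch): extract_doi("See https://doi.org/10.1234/ab-1.pdf.", multi=True)
    # returns ["10.1234/ab-1"]; trailing periods and a ".pdf" suffix are stripped.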
    def chunks2chroma(self, _id, key):
        st.write("Adding to vector database...")
        assert self.document.text, "Document must have 'text' attribute."
        ids = []
        documents = []
        metadatas = []
        last_page = 1
        for i, chunk in enumerate(self.document.chunks):
            page_numbers = re.findall(r"@(\d+)@", chunk)
            if page_numbers == []:
                page_numbers = [last_page]
            else:
                last_page = page_numbers[-1]
            id = fix_key(f"{key}_{i}")
            ids.append(id)
            metadata = {
                "_key": id,
                "file": self.document.file_path,
                "chunk_nr": i,
                "pages": ",".join([str(p) for p in page_numbers]),
                "_id": _id,
            }
            if self.document.doi:
                metadata["doi"] = self.document.doi
            metadatas.append(metadata)
            chunk = re.sub(r"@(\d+)@", "", chunk)
            documents.append(chunk)
        if self.document.is_sci:
            chroma_collection = self.chromadb.db.get_or_create_collection(
                "sci_articles"
            )
        else:
            chroma_collection = self.chromadb.db.get_or_create_collection(
                "other_documents"
            )
        chroma_collection.add(ids=ids, documents=documents, metadatas=metadatas)
    def chunks2arango(self):
        st.write("Adding to document database...")
        assert self.document.text, "Document must have 'text' attribute."
        if self.document.is_sci:
            for key in ["doi", "metadata"]:
                assert getattr(
                    self.document, key
                ), f"Document must have '{key}' attribute."
        else:
            assert (
                getattr(self.document, "_key", None) or self.document.doi
            ), "Document must have '_key' attribute or DOI."
        arango_collection = self.get_arango()
        if self.document.doi:
            key = self.document.doi
        else:
            key = self.document._key
        arango_chunks = []
        last_page = 1
        for i, chunk in enumerate(self.document.chunks):
            page_numbers = re.findall(r"@(\d+)@", chunk)
            if page_numbers == []:
                page_numbers = [last_page]
            else:
                last_page = page_numbers[-1]
            id = fix_key(key) + f"_{i}"
            chunk = re.sub(r"@(\d+)@", "", chunk)
            arango_chunks.append({"text": chunk, "pages": page_numbers, "id": id})
        if not hasattr(self.document, "_key"):
            self.document._key = fix_key(key)
        user_access = [self.document.username]
        if not self.document.open_access:
            if arango_collection.has(self.document._key):
                doc = arango_collection.get(self.document._key)
                if "user_access" in doc:
                    if doc["user_access"]:
                        if self.document.username not in doc["user_access"]:
                            user_access = doc["user_access"] + [self.document.username]
                    else:
                        user_access = [self.document.username]
        if self.document.open_access:
            user_access = None
        arango_document = {
            "_key": fix_key(self.document._key),
            "file": self.document.file_path,
            "chunks": arango_chunks,
            "text": self.document.text,
            "open_access": self.document.open_access,
            "user_access": user_access,
            "doi": self.document.doi,
            "metadata": self.document.metadata,
            "filename": self.document.filename,
        }
        if self.document.metadata and self.document.is_sci:
            if "abstract" in self.document.metadata:
                if isinstance(self.document.metadata["abstract"], str):
                    self.document.metadata["abstract"] = re.sub(
                        r"<[^>]*>", "", self.document.metadata["abstract"]
                    )
                arango_document["metadata"] = self.document.metadata
                arango_document["summary"] = {
                    "text_sum": self.document.metadata["abstract"],
                    "meta": {"model": "from_metadata"},
                }
                arango_document["crossref"] = True
        doc = arango_collection.insert(
            arango_document, overwrite=True, overwrite_mode="update", keep_none=False
        )
        self.document._id = doc["_id"]
        if "summary" not in arango_document:
            # Make a summary in the background
            self.document.make_summary_in_background()
        return doc["_id"], key
def llm2metadata(self):
st.write("Extracting metadata using LLM...")
llm = LLM(
temperature=0.01,
system_message="You are an assistant helping a user to extract metadata from a scientific article.",
model="small",
max_length_answer=500,
)
# Only request page 1 if the PDF actually has more than one page
pages = [0] if len(self.document.pdf) == 1 else [0, 1]
text = pymupdf4llm.to_markdown(
self.document.pdf, page_chunks=False, show_progress=False, pages=pages
)
prompt = f'''
Below is the beginning of an article. I want to know when it's published, the title, and the journal.
"""
{text}
"""
Answer ONLY with the information requested.
I want to know the published date in the form "YYYY-MM-DD".
I want the full title of the article and the journal.
Be sure to answer in the form "published_date;title;journal" as the answer will be used in a CSV.
If you can't find the information, answer "not_found".
'''
result = llm.generate(prompt)
print_blue(result)
if result == "not_found":
return None
else:
parts = result.split(";", 2)
if len(parts) != 3:
return None
published_date, title, journal = parts
if published_date == "not_found":
published_date = "[Unknown date]"
else:
try:
published_year = int(published_date.split("-")[0])
except:
published_year = None
if title == "not_found":
title = "[Unknown title]"
if journal == "not_found":
journal = "[Unknown publication]"
return {
"published_date": published_date,
"published_year": published_year,
"title": title,
"journal": journal,
}
def get_crossref(self, doi):
try:
print(f"Retrieving metadata for DOI {doi}...")
work = crossref.get_publication_as_json(doi)
print_green(f"Metadata retrieved for DOI {doi}.")
if "published-print" in work:
publication_date = work["published-print"]["date-parts"][0]
elif "published-online" in work:
publication_date = work["published-online"]["date-parts"][0]
elif "issued" in work:
publication_date = work["issued"]["date-parts"][0]
else:
publication_date = [None]
publication_year = publication_date[0]
metadata = {
"doi": work.get("DOI", None),
"title": work.get("title", [None])[0],
"authors": [
f"{author['given']} {author['family']}"
for author in work.get("author", [])
],
"abstract": work.get("abstract", None),
"journal": work.get("container-title", [None])[0],
"volume": work.get("volume", None),
"issue": work.get("issue", None),
"pages": work.get("page", None),
"published_date": "-".join(map(str, publication_date)),
"published_year": publication_year,
"url_doi": work.get("URL", None),
"link": (
work.get("link", [None])[0]["URL"]
if work.get("link", None)
else None
),
"language": work.get("language", None),
}
if "abstract" in metadata and isinstance(metadata["abstract"], str):
metadata["abstract"] = re.sub(r"<[^>]*>", "", metadata["abstract"])
self.document.metadata = metadata
self.document.is_sci = True
return metadata
except Exception as e:
print(f"Could not retrieve Crossref metadata for DOI {doi}: {e}")
if not self.document.is_sci:
self.document.is_sci = False
return None
def check_doaj(self, doi):
url = f"https://doaj.org/api/search/articles/{doi}"
response = requests.get(url)
if response.status_code == 200:
data = response.json()
if data.get("results", []) == []:
print(f"DOI {doi} not found in DOAJ.")
return False
else:
return data
else:
print(
f"Error fetching metadata for DOI from DOAJ: {doi}. HTTP Status Code: {response.status_code}"
)
return False
def process_document(self):
assert self.document.pdf_file or self.document.pdf, "PDF file must be provided."
if not self.document.pdf:
self.document.open_pdf(self.document.pdf_file)
if self.document.is_image:
return pymupdf4llm.to_markdown(
self.document.pdf, page_chunks=False, show_progress=False
)
self.document.title = self.document.get_title()
if not self.document.doi and self.document.filename:
self.document.doi = self.extract_doi(self.document.filename)
if not self.document.doi:
text = ""
for page in self.document.pdf.pages(0, 6):
text += page.get_text()
self.document.doi = self.extract_doi(text)
if self.document.doi:
self.document._key = fix_key(self.document.doi)
if self.check_doaj(self.document.doi):
self.document.open_access = True
self.document.is_sci = True
self.document.metadata = self.get_crossref(self.document.doi)
if not self.document.is_sci:
self.document.is_sci = bool(self.document.metadata)
arango_collection = self.get_arango()
doc = arango_collection.get(self.document._key) if self.document.doi else None
if doc:
print_green(f"Document with key {self.document._key} already in database.")
self.document.doc = doc
crossref = self.get_crossref(self.document.doi)
if crossref:
self.document.doc["metadata"] = crossref
elif "metadata" not in doc or not doc["metadata"]:
self.document.doc["metadata"] = {
"title": self.document.get_title(only_meta=True)
}
elif 'title' not in doc['metadata']:
self.document.doc["metadata"]["title"] = self.document.get_title(only_meta=True)
if "user_access" not in doc or doc['user_access'] == None:
self.document.doc["user_access"] = [self.document.username]
else:
if self.document.username not in doc['user_access']:
self.document.doc["user_access"] = doc.get("user_access", []) + [
self.document.username
]
self.metadata = self.document.doc["metadata"]
arango_collection.update(self.document.doc)
return doc["_id"], arango_collection.db_name, self.document.doi
else:
self.document.doc = (
{"doi": self.document.doi, "_key": fix_key(self.document.doi)}
if self.document.doi
else {}
)
if self.document.doi:
if not self.document.metadata:
self.document.metadata = self.get_crossref(self.document.doi)
if self.document.metadata:
self.document.doc["metadata"] = self.document.metadata
else:
self.document.doc["metadata"] = self.llm2metadata()
if self.document.get_title(only_meta=True):
self.document.doc["metadata"]["title"] = (
self.document.get_title(only_meta=True)
)
else:
self.document.doc["metadata"] = self.llm2metadata()
if self.document.get_title(only_meta=True):
self.document.doc["metadata"]["title"] = self.document.get_title(
only_meta=True
)
if "_key" not in self.document.doc:
_key = (
self.document.doi
or self.document.title
or self.document.get_title()
)
print_yellow(f"Document key: {_key}")
print(self.document.doi, self.document.title, self.document.get_title())
self.document.doc["_key"] = fix_key(_key)
self.document._key = fix_key(_key)
self.document.metadata = self.document.doc["metadata"]
if not self.document.text:
self.document.extract_text()
if self.document.doi:
self.document.doc["doi"] = self.document.doi
self.document._key = fix_key(self.document.doi)
self.document.save_pdf(self.document_type)
self.document.make_chunks()
_id, key = self.chunks2arango()
self.chunks2chroma(_id=_id, key=key)
self._id = _id
return _id, arango_collection.db_name, self.document.doi
async def dl_pyppeteer(self, doi, url):
browser = await launch(
headless=True, args=["--no-sandbox", "--disable-setuid-sandbox"]
)
page = await browser.newPage()
await page.setUserAgent(
"Mozilla/5.0 (Macintosh; Intel Mac OS X x.y; rv:10.0) Gecko/20100101 Firefox/10.0"
)
await page.goto(url)
await page.waitFor(5000)
content = await page.content()
await page.pdf({"path": f"{doi}.pdf".replace("/", "_"), "format": "A4"})
await browser.close()
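# Illustrative driver from synchronous code (hypothetical; assumes asyncio is imported
# and `processor` is a PDFProcessor instance):
#   asyncio.run(processor.dl_pyppeteer(doi, url))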
def doi2pdf(self, doi):
url = None
downloaded = False
path = None
in_db = False
sci_articles = self.get_arango(db_name="base", document_type="sci_articles")
if sci_articles.has(fix_key(doi)):
in_db = True
downloaded = True
doc = sci_articles.get(fix_key(doi))
url = doc["metadata"]["link"]
path = doc["file"]
print_green(f"Article {doi} already in database.")
return downloaded, url, doc["file"], in_db
doaj_data = self.check_doaj(doi)
sleep(0.5)
if doaj_data:
for link in doaj_data.get("bibjson", {}).get("link", []):
if "mdpi.com" in link["url"]:
r = requests.get(link["url"])
soup = BeautifulSoup(r.content, "html.parser")
pdf_link_html = soup.find("a", {"class": "UD_ArticlePDF"})
pdf_url = "https://www.mdpi.com" + pdf_link_html["href"]
pdf = requests.get(pdf_url)
path = f"sci_articles/{doi}.pdf".replace("/", "_")
with open(path, "wb") as f:
f.write(pdf.content)
self.process_document()
print(f"Downloaded PDF for {doi}")
downloaded = True
url = link["url"]
else:
downloaded = False
else:
metadata = self.get_crossref(doi)
if metadata:
url = metadata["link"]
else:
print(f"Error fetching metadata for DOI: {doi}")
return downloaded, url, path, in_db
class PDFProcessor(Processor):
def __init__(
self,
pdf_file=None,
filename=None,
chroma_db: str = "sci_articles",
document_type: str = None,
len_chunks: int = 2200,
local_chroma_deployment: bool = False,
process: bool = True,
doi=False,
username=None,
is_sci=None,
is_image=False,
):
self.document = Document(
pdf_file=pdf_file,
filename=filename,
doi=doi,
username=username,
is_sci=is_sci,
is_image=is_image,
)
super().__init__(
document=self.document,
filename=filename,
chroma_db=chroma_db,
len_chunks=len_chunks,
local_chroma_deployment=local_chroma_deployment,
process=process,
document_type=document_type,
)
if __name__ == "__main__":
doi = "10.1007/s10584-019-02646-9"
print(f"Processing article with DOI: {doi}")
ap = PDFProcessor(doi=doi, process=False)
print(f"Downloading article with DOI: {doi}")
ap.doi2pdf(doi)

@ -0,0 +1,4 @@
from PyPaperBot import __main__ as ppb
ppb.start(DOIs=["10.1016/j.jas.2021.103086", "10.1016/j.jas.2021.103087", "10.1016/j.jas.2021.103088"])

@ -0,0 +1,192 @@
country_emojis = {
"ad": "🇦🇩",
"ae": "🇦🇪",
"af": "🇦🇫",
"ag": "🇦🇬",
"ai": "🇦🇮",
"al": "🇦🇱",
"am": "🇦🇲",
"ao": "🇦🇴",
"aq": "🇦🇶",
"ar": "🇦🇷",
"as": "🇦🇸",
"at": "🇦🇹",
"au": "🇦🇺",
"aw": "🇦🇼",
"ax": "🇦🇽",
"az": "🇦🇿",
"ba": "🇧🇦",
"bb": "🇧🇧",
"bd": "🇧🇩",
"be": "🇧🇪",
"bf": "🇧🇫",
"bg": "🇧🇬",
"bh": "🇧🇭",
"bi": "🇧🇮",
"bj": "🇧🇯",
"bl": "🇧🇱",
"bm": "🇧🇲",
"bn": "🇧🇳",
"bo": "🇧🇴",
"bq": "🇧🇶",
"br": "🇧🇷",
"bs": "🇧🇸",
"bt": "🇧🇹",
"bv": "🇧🇻",
"bw": "🇧🇼",
"by": "🇧🇾",
"bz": "🇧🇿",
"ca": "🇨🇦",
"cc": "🇨🇨",
"cd": "🇨🇩",
"cf": "🇨🇫",
"cg": "🇨🇬",
"ch": "🇨🇭",
"ci": "🇨🇮",
"ck": "🇨🇰",
"cl": "🇨🇱",
"cm": "🇨🇲",
"cn": "🇨🇳",
"co": "🇨🇴",
"cr": "🇨🇷",
"cu": "🇨🇺",
"cv": "🇨🇻",
"cw": "🇨🇼",
"cx": "🇨🇽",
"cy": "🇨🇾",
"cz": "🇨🇿",
"de": "🇩🇪",
"dj": "🇩🇯",
"dk": "🇩🇰",
"dm": "🇩🇲",
"do": "🇩🇴",
"dz": "🇩🇿",
"ec": "🇪🇨",
"ee": "🇪🇪",
"eg": "🇪🇬",
"eh": "🇪🇭",
"er": "🇪🇷",
"es": "🇪🇸",
"et": "🇪🇹",
"fi": "🇫🇮",
"fj": "🇫🇯",
"fk": "🇫🇰",
"fm": "🇫🇲",
"fo": "🇫🇴",
"fr": "🇫🇷",
"ga": "🇬🇦",
"gb": "🇬🇧",
"gd": "🇬🇩",
"ge": "🇬🇪",
"gf": "🇬🇫",
"gg": "🇬🇬",
"gh": "🇬🇭",
"gi": "🇬🇮",
"gl": "🇬🇱",
"gm": "🇬🇲",
"gn": "🇬🇳",
"gp": "🇬🇵",
"gq": "🇬🇶",
"gr": "🇬🇷",
"gs": "🇬🇸",
"gt": "🇬🇹",
"gu": "🇬🇺",
"gw": "🇬🇼",
"gy": "🇬🇾",
"hk": "🇭🇰",
"hm": "🇭🇲",
"hn": "🇭🇳",
"hr": "🇭🇷",
"ht": "🇭🇹",
"hu": "🇭🇺",
"id": "🇮🇩",
"ie": "🇮🇪",
"il": "🇮🇱",
"im": "🇮🇲",
"in": "🇮🇳",
"io": "🇮🇴",
"iq": "🇮🇶",
"ir": "🇮🇷",
"is": "🇮🇸",
"it": "🇮🇹",
"je": "🇯🇪",
"jm": "🇯🇲",
"jo": "🇯🇴",
"jp": "🇯🇵",
"ke": "🇰🇪",
"kg": "🇰🇬",
"kh": "🇰🇭",
"ki": "🇰🇮",
"km": "🇰🇲",
"kn": "🇰🇳",
"kp": "🇰🇵",
"kr": "🇰🇷",
"kw": "🇰🇼",
"ky": "🇰🇾",
"kz": "🇰🇿",
"la": "🇱🇦",
"lb": "🇱🇧",
"lc": "🇱🇨",
"li": "🇱🇮",
"lk": "🇱🇰",
"lr": "🇱🇷",
"ls": "🇱🇸",
"lt": "🇱🇹",
"lu": "🇱🇺",
"lv": "🇱🇻",
"ly": "🇱🇾",
"ma": "🇲🇦",
"mc": "🇲🇨",
"md": "🇲🇩",
"me": "🇲🇪",
"mf": "🇲🇫",
"mg": "🇲🇬",
"mh": "🇲🇭",
"mk": "🇲🇰",
"ml": "🇲🇱",
"mm": "🇲🇲",
"mn": "🇲🇳",
"mo": "🇲🇴",
"mp": "🇲🇵",
"mq": "🇲🇶",
"mr": "🇲🇷",
"ms": "🇲🇸",
"mt": "🇲🇹",
"mu": "🇲🇺",
"mv": "🇲🇻",
"mw": "🇲🇼",
"mx": "🇲🇽",
"my": "🇲🇾",
"mz": "🇲🇿",
"na": "🇳🇦",
"nc": "🇳🇨",
"ne": "🇳🇪",
"nf": "🇳🇫",
"ng": "🇳🇬",
"ni": "🇳🇮",
"nl": "🇳🇱",
"no": "🇳🇴",
"np": "🇳🇵",
"nr": "🇳🇷",
"nu": "🇳🇺",
"nz": "🇳🇿",
"om": "🇴🇲",
"pa": "🇵🇦",
"pe": "🇵🇪",
"pf": "🇵🇫",
"pg": "🇵🇬",
"ph": "🇵🇭",
"pk": "🇵🇰",
"pl": "🇵🇱",
"pm": "🇵🇲",
"pn": "🇵🇳",
"pr": "🇵🇷",
"ps": "🇵🇸",
"pt": "🇵🇹",
"pw": "🇵🇼",
"py": "🇵🇾",
"qa": "🇶🇦",
"re": "🇷🇪",
"ro": "🇷🇴",
"rs": "🇷🇸",
}
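# Illustrative lookup with a neutral fallback for codes missing from the mapping:
#   country_emojis.get("fr", "🏳️")  # -> "🇫🇷"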

@ -0,0 +1,53 @@
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from typing import Optional
from prompts import get_summary_prompt
from _llm import LLM
from _arango import ArangoDB
app = FastAPI()
class DocumentData(BaseModel):
text: str
arango_db_name: str
arango_id: str
is_sci: Optional[bool] = False
@app.post("/summarise_document")
async def summarize_document(doc_data: DocumentData, background_tasks: BackgroundTasks):
background_tasks.add_task(summarise_document_task, doc_data.dict())
return {"message": "Document summarization has started."}
def summarise_document_task(doc_data: dict):
text = doc_data.get("text")
is_sci = doc_data.get("is_sci", False)
system_message = "You are summarising scientific articles. It is very important that you keep to what is written and do not add any of your own opinions or interpretations. Always answer in English."
llm = LLM(system_message=system_message)
summary = llm.generate(query=get_summary_prompt(text, is_sci))
summary_doc = {
"text_sum": summary,
"meta": {
"model": llm.model,
"system_message": system_message,
"temperature": llm.options["temperature"],
},
}
arango = ArangoDB(db_name=doc_data.get("arango_db_name"))
arango.db.update_document(
{"summary": summary_doc, "_id": doc_data.get("arango_id")},
silent=True,
check_rev=False,
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8100)
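# Illustrative client call (assumes the server is running on localhost:8100;
# the db name and document id below are hypothetical):
#   import requests
#   requests.post(
#       "http://localhost:8100/summarise_document",
#       json={"text": "...", "arango_db_name": "lasse", "arango_id": "sci_articles/10.1007_xyz"},
#   )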

@ -0,0 +1,27 @@
from typing import Callable, Dict, Any, List
class ToolRegistry:
_tools = []
@classmethod
def register(cls, name: str, description: str, parameters: Dict[str, Any] = None):
def decorator(func: Callable):
cls._tools.append({
"type": "function",
"function": {
"name": name,
"description": description,
"parameters": parameters or {}
}
})
# No need for the wrapper since we're not adding any extra logic
return func
return decorator
@classmethod
def get_tools(cls, tools: list = None) -> List[Dict[str, Any]]:
if tools:
return [tool for tool in cls._tools if tool['function']['name'] in tools]
else:
return cls._tools
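# Illustrative usage with a hypothetical tool:
#   @ToolRegistry.register(name="ping", description="Health check")
#   def ping():
#       return "pong"
#   ToolRegistry.get_tools(["ping"])  # -> the {"type": "function", ...} schema registered above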

@ -0,0 +1,97 @@
import yaml
import sys
import bcrypt
from _arango import ArangoDB
import os
import dotenv
import getpass
dotenv.load_dotenv()
def read_yaml(file_path):
with open(file_path, "r") as file:
return yaml.safe_load(file)
def write_yaml(file_path, data):
with open(file_path, "w") as file:
yaml.safe_dump(data, file)
def add_user(data, username, email, name, password):
# Check for existing username
if username in data["credentials"]["usernames"]:
print(f"Error: Username '{username}' already exists.")
sys.exit(1)
# Check for existing email
for user in data["credentials"]["usernames"].values():
if user["email"] == email:
print(f"Error: Email '{email}' already exists.")
sys.exit(1)
# Hash the password using bcrypt
hashed_password = bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode(
"utf-8"
)
# Add the new user
data["credentials"]["usernames"][username] = {
"email": email,
"name": name,
"password": hashed_password,
}
def make_arango(username):
root_user = os.getenv("ARANGO_ROOT_USER")
root_password = os.getenv("ARANGO_ROOT_PASSWORD")
arango = ArangoDB(user=root_user, password=root_password, db_name="_system")
if not arango.db.has_database(username):
arango.db.create_database(
username,
users=[
{
"username": os.getenv("ARANGO_USER"),
"password": os.getenv("ARANGO_PASSWORD"),
"active": True,
"extra": {},
}
]
)
arango = ArangoDB(user=root_user, password=root_password, db_name=username)
for collection in ["projects", "favorite_articles", "article_collections", "settings", 'chats', 'notes', 'other_documents']:
if not arango.db.has_collection(collection):
arango.db.create_collection(collection)
user_arango = ArangoDB(db_name=username)
user_arango.db.collection("settings").insert(
{"current_page": 'Bot Chat', "current_project": None}
)
def main():
yaml_file = "streamlit_users.yaml"
if len(sys.argv) == 5:
username = sys.argv[1]
email = sys.argv[2]
name = sys.argv[3]
password = sys.argv[4]
else:
username = input("Enter username: ")
email = input("Enter email: ")
name = input("Enter name: ")
password = getpass.getpass("Enter password: ")
data = read_yaml(yaml_file)
add_user(data, username, email, name, password)
make_arango(username)
write_yaml(yaml_file, data)
print(f"User {username} added successfully.")
if __name__ == "__main__":
main()
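# Illustrative invocation (assumes streamlit_users.yaml and the ArangoDB env vars exist):
#   python new_user.py alice alice@example.com "Alice Doe" s3cret
# Run without arguments to be prompted interactively instead.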

@ -0,0 +1,187 @@
import re
#from _classes import Project
def description_string(project: "Project"):
if project.description != "":
description = f'The project is about "{project.description}".'
else:
description = ''
return description
def use_tools(use: bool = False):
if use:
return 'If you need to use a tool to fetch information, you can do that as well.'
else:
return ''
def get_assistant_prompt():
"""
Returns a multi-line string that serves as a prompt for a research assistant AI.
Returns:
str: The prompt for the research assistant AI.
"""
return """You are a research assistant chatting with a researcher. Only use the information from scientific articles you are provided with to answer questions.
The articles are ordered by publication date, so the first article is the oldest one. Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account.
Be sure to reference the source of the information with the number of the article inside square brackets (e.g. "<answer based on an article> [article number]").
If you have to reference the articles in running text, e.g. in a headline or the beginning of a bullet point, use the title of the article. Do NOT write something like "Article 1" but use the actual title of the article.
You should not write a reference section as this will be added later.
Format your answers in Markdown format. """
def get_editor_prompt(project: "Project", tools: bool = False):
"""Generates a coaching prompt for an editor to assist a reporter with a specific project.
Args:
project (dict): A dictionary containing information about the project the reporter is working on. The dictionary should have the following keys:
- name (str): The name of the project.
- description (str): The description of the project.
- notes_summary (str): A summary of notes about the project and
Returns:
str: A formatted string containing the coaching prompt for the editor."""
if project.notes_summary:
notes_string = f'''Here are other important things you should know about the project and the topic:
"""
{project.notes_summary}
"""
'''
else:
notes_string = ''
return f'''You are an editor coaching a journalist who is working on the project "{project.name}". {description_string(project)}
{notes_string}
When writing with the reporter you will also get other information, like excerpts from articles and other documents. Use the notes to put the information in context and help the reporter to move forward.
The project is a journalistic piece, so it is important that you help the reporter to be critical of the sources and to provide a balanced view of the topic.
Be sure to understand what the reporter is asking and provide the information in a way that is helpful for the reporter to move forward. Try to understand if the reporter is asking for a specific piece of information or if they are looking for guidance on how to move forward, or just want to discuss the topic.
If you need more information to answer the question, try to get it.
'''
def get_chat_prompt(user_input, content_string, role):
if role == "Research Assistant":
prompt = f'''{user_input}
Below are snippets from different articles, often with title and date of publication.
ONLY use the information below to answer the question. Do not use any other information.
"""
{content_string}
"""
Remember:
- Reference the source of the information with the number of the article inside square brackets (e.g. "<answer based on an article> [article number]").
- If you have to reference the articles in running text, e.g. in a headline or the beginning of a bullet point, use the title of the article. Do NOT write something like "Article 1" but use the actual title of the article.
- The articles are ordered by publication date, so the first article is the oldest one. Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account.
{user_input}
'''
elif role == "Editor":
prompt = f'''The reporter has asked: "{user_input}". Try to answer the question or provide guidance based on the information below, and your knowledge of the project.
"""
{content_string}
"""
Remember:
- Sometimes the articles might contain conflicting information, in that case be clear about the conflict and provide both sides of the argument with publication dates taken into account.
- If you think you need more information to answer the question, ask the reporter for more context.
{user_input}
'''
elif role == "Guest":
prompt = f'''The podcast host has asked: "{user_input}". Try to answer the question based on the information below.
"""
{content_string}
"""
Remember:
- Answer in a way that is easy to understand for a general audience.
- Only answer based on the information above.
- Answer in a "spoken" way, formatting the answer as if you were speaking it out loud.
{user_input}
'''
elif role == "Host":
prompt = f'''The expert has stated: "{user_input}". Try to come up with a new question based on the information below.
"""
{content_string}
"""
Remember:
- The information above is the context for the expert's statement, so the new question should be relevant to that context, as well as the conversation as a whole.
- You are a critical journalist, so try to come up with a question that challenges the expert's statement or asks for more information.
- Make sure not to repeat yourself! Check what questions you have already asked to avoid repetition.
'''
return prompt
def get_query_builder_system_message():
system_message = """
Take the user input and write it as a sentence that could be used as a query for a vector database.
The vector database will return text snippets that semantically match the query, so you CAN'T USE NEGATIONS or other complex language constructs. If there is a negation in the user input, exclude that part from the query.
If the user input seems to be a follow-up question or comment, use the context from the chat history to make a relevant query.
Answer ONLY with the query, no explanation or reasoning!
"""
return re.sub(r"\s*\n\s*", "\n", system_message)
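# Note: re.sub(r"\s*\n\s*", "\n", s) strips the indentation these triple-quoted
# prompts pick up, e.g. "a \n   b" becomes "a\nb" (illustrative).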
def get_note_summary_prompt(project: "Project", notes_string: str):
query = f'''
Below are notes from a project called "{project.name}". {description_string(project)}.
"""
{notes_string}
"""
I want you to summarize the notes in a concise manner. The summary will be used to create a system message for chatting with LLMs about the project.
Make sure to include the most important points, and include any key terms or concepts.
Try to gather at least something from each note, but don't reference the notes directly.
Answer ONLY with the summary, nothing else.
'''
return re.sub(r"\s*\n\s*", "\n", query)
def get_image_system_prompt(project: "Project"):
system_message = f"""
You are an assistant to a journalist who is working on a project called {project.name}. Your task is to analyze and describe the images that are part of the project.
The images you get might show a graph, chart or other visual representation of data. If so, answer in a way that describes the data and the trends or patterns that can be seen in the image.
The images might also show a photo or illustration, in that case describe the content of the image in a way that could be used in a caption.
- Don't include any information that is not visible in the image.
- Don't focus too much on the different parts of the image/figure, but rather on its meaning.
- Answer ONLY with the description of the image, with no additional argumentation or explanation.
"""
return re.sub(r"\s*\n\s*", "\n", system_message)
def get_tools_prompt(user_input):
return f'''The reporter has asked: "{user_input}"
What information is needed to answer the question? Choose one or more tools in order to answer the question. Make sure to read the description of the tools carefully before choosing.
If you are sure that you can answer the question correctly without fetching data, you can do that as well.
'''
def get_summary_prompt(text, is_sci):
text = re.sub(r"\s*\n\s*", "\n", text)
if is_sci:
s = 'The text will be used as an abstract for a scientific article. Make sure to include the most important results and findings, as well as relevant information about methods, data and conclusions. Keep the summary concise.'
else:
s = 'Make sure to include the most important points and facts, and to keep the summary concise.'
prompt = f'''Summarise the text below:
"""
{text}
"""
{s}
Use ONLY the information from the text to write the summary. Do not include any additional information or your own interpretation.
Answer ONLY with the summary, nothing else like reasoning or explanation.
'''
return re.sub(r"\s*\n\s*", "\n", prompt)
def get_generate_vector_query_prompt(user_input: str, role: str):
print(role.upper())
if role in ["Research Assistant", "Editor"]:
query = f"""A user asked this question: "{user_input}". Generate a query for the vector database. Make sure to follow the instructions you got earlier!"""
elif role == "Guest":
query = f"""A podcast host has asked this question in an interview: "{user_input}". Generate a query for the vector database to answer the actial question. Make sure to follow the instructions you got earlier!"""
elif role == "Host":
query = f"""An expert has stated: "{user_input}". Generate a query for the vector database to get context for that answer in order to come up with a new question. Make sure to follow the instructions you got earlier!"""
return query

@ -0,0 +1,10 @@
from _arango import ArangoDB
from _chromadb import ChromaDB
arango = ArangoDB(db_name='lasse')
admin_arango = ArangoDB(db_name='base')
arango.db.collection('other_documents').truncate()
chroma = ChromaDB()
chroma.db.delete_collection("other_documents")
chroma.db.create_collection("other_documents")

@ -0,0 +1,87 @@
import streamlit as st
import streamlit_authenticator as stauth
import yaml
from yaml.loader import SafeLoader
from streamlit_authenticator import LoginError
from time import sleep
from colorprinter.print_color import *
from _arango import ArangoDB
def get_settings():
"""
Function to get the settings from the ArangoDB.
"""
arango = ArangoDB(db_name=st.session_state["username"])
st.session_state["settings"] = arango.db.collection("settings").get("settings")
return st.session_state["settings"]
st.set_page_config(page_title="Science Assistant Fish", page_icon="🐟")
with open("streamlit_users.yaml") as file:
config = yaml.load(file, Loader=SafeLoader)
authenticator = stauth.Authenticate(
config["credentials"],
config["cookie"]["name"],
config["cookie"]["key"],
config["cookie"]["expiry_days"],
)
try:
authenticator.login()
except LoginError as e:
st.error(e)
if st.session_state["authentication_status"]:
sleep(0.1)
# Retry mechanism for loading settings
for _ in range(3):
try:
get_settings()
break
except ImportError as e:
sleep(0.3)
print_red(e)
print("Retrying get_settings...")
# Retry mechanism for importing pages
for _ in range(3):
try:
from streamlit_pages import Article_Collections, Bot_Chat, Projects, Settings
break
except ImportError as e:
# Write the full error traceback
sleep(0.3)
print_red(e)
print("Retrying to import pages...")
get_settings()
if 'current_page' in st.session_state["settings"]:
st.session_state["current_page"] = st.session_state["settings"]["current_page"]
else:
if 'current_page' not in st.session_state:
st.session_state["current_page"] = None
if "not_downloaded" not in st.session_state:
st.session_state["not_downloaded"] = {}
# Pages
bot_chat = st.Page(Bot_Chat)
projects = st.Page(Projects)
article_collections = st.Page(Article_Collections)
settings = st.Page(Settings)
pg = st.navigation([bot_chat, projects, article_collections, settings])
pg.run()
with st.sidebar:
st.write("---")
authenticator.logout()
elif st.session_state["authentication_status"] is False:
st.error("Username/password is incorrect")
elif st.session_state["authentication_status"] is None:
st.warning("Please enter your username and password")

@ -0,0 +1,726 @@
from datetime import datetime
import streamlit as st
from _base_class import BaseClass
from _llm import LLM
from prompts import *
from colorprinter.print_color import *
from llm_tools import ToolRegistry
class Chat(BaseClass):
def __init__(self, username: str, role: str, **kwargs):
super().__init__(username=username, **kwargs)
self.name = kwargs.get("name", None)
self.chat_history = kwargs.get("chat_history", [])
self.role = role
def add_message(self, role, content):
self.chat_history.append(
{"role": role, "content": content.strip().strip('"'), "role_type": self.role}
)
def to_dict(self):
return {
"name": self.name,
"username": self.username,
"chat_history": self.chat_history,
"role": self.role,
}
def show_chat_history(self):
for message in self.chat_history:
if message["role"] not in ["user", "assistant"]:
continue
avatar = self.get_avatar(message)
with st.chat_message(message["role"], avatar=avatar):
if message["content"]:
st.markdown(message["content"].strip('"'))
def get_avatar(self, message: dict = None, role=None) -> str:
assert message or role, "Either message or role must be provided"
if message and message.get("role", None) == "user" or role == "user":
avatar = st.session_state["settings"].get("avatar", "user")
elif (
message and message.get("role", None) == "assistant" or role == "assistant"
):
role_type = message.get("role_type", self.role) if message else self.role
if role_type == "Research Assistant":
avatar = "img/avatar_researcher.png"
elif role_type == "Editor":
avatar = "img/avatar_editor.png"
elif role_type == "Host":
avatar = "img/avatar_host.png"
elif role_type == "Guest":
avatar = "img/avatar_guest.png"
else:
avatar = None
else:
avatar = None
return avatar
def set_name(self, user_input):
llm = LLM(
model="small",
max_length_answer=50,
temperature=0.4,
system_message="You are a chatbot who will be chatting with a user",
)
prompt = (
f'Give a short name to the chat based on this user input: "{user_input}" '
"No more than 30 characters. Answer ONLY with the name of the chat."
)
name = llm.generate(prompt)
print_blue("Chat Name", name)
name = f'{name} - {datetime.now().strftime("%B %d")}'
# Check if the chat name already exists
existing_chat = self.user_arango.db.aql.execute(
f'FOR doc IN chats FILTER doc.name == "{name}" RETURN doc', count=True
)
if existing_chat.count() > 0:
name = f'{name} ({datetime.now().strftime("%H:%M")})'
name += f" - [{self.role}]"
self.name = name
return name
@classmethod
def from_dict(cls, data):
return cls(
username=data.get("username"),
name=data.get("name"),
chat_history=data.get("chat_history"),
role=data.get("role", "Research Assistant"),
)
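# Illustrative round trip (hypothetical values), relying on to_dict/from_dict
# carrying the username through serialization:
#   chat = Chat(username="alice", role="Editor")
#   chat.add_message("user", "Hello")
#   restored = Chat.from_dict(chat.to_dict())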
class Bot(BaseClass):
def __init__(self, chat: Chat, username: str, **kwargs):
super().__init__(username=username, **kwargs)
self.chat = chat
self.project = kwargs.get("project", None)
self.collection: list = kwargs.get(
"collection", st.session_state["settings"]["current_collection"]
)
if not self.collection and self.project:
self.collection = self.project.collections
if not isinstance(self.collection, list):
self.collection = [self.collection]
# Load articles in the collections
self.arango_ids = []
for collection in self.collection:
for _id in self.user_arango.db.aql.execute(
'''
FOR doc IN article_collections
FILTER doc.name == @collection
FOR article IN doc.articles
RETURN article._id
''',
bind_vars={"collection": collection},
):
self.arango_ids.append(_id)
self.chosen_backend = kwargs.get("chosen_backend", None)
self.chatbot: LLM = None
self.tools: list[dict] = None
self.chatbot_memory = None
self.helperbot_memory = None
# Initialize LLM instances
self.helperbot = LLM(
temperature=0,
model="small",
max_length_answer=500,
system_message=get_query_builder_system_message(),
messages=self.helperbot_memory,
)
self.toolbot = LLM(
temperature=0,
system_message="Choose one or many tools to use in order to assist the user. Make sure to read the description of the tools carefully.",
chat=False,
model="small",
)
# self.sidebar_content()
def sidebar_content(self):
with st.sidebar:
st.write("---")
st.markdown(f'#### {self.chat.name if self.chat.name else ""}')
st.button("Delete this chat", on_click=self.delete_chat)
def delete_chat(self):
self.user_arango.db.collection("chats").delete_match(
filters={"name": self.chat.name}
)
self.chat = Chat(username=self.username, role=self.chat.role)
def get_chunks(
self,
user_input,
collections=["sci_articles", "other_documents"],
n_results=7,
n_sources=4,
filter=True,
):
if not isinstance(n_sources, int):
n_sources = int(n_sources)
if not isinstance(n_results, int):
n_results = int(n_results)
# Use self.chat.chat_history if needed
if self.chat.role == "Editor":
n_results = 4
n_sources = 3
query = self.helperbot.generate(
get_generate_vector_query_prompt(user_input, self.chat.role)
)
print_rainbow("Vector query:", query)
combined_chunks = []
for collection in collections:
where_filter = {"_id": {"$in": self.arango_ids}} if filter else {}
chunks = self.get_chromadb().query(
query=query,
collection=collection,
n_results=n_results,
n_sources=n_sources,
where=where_filter,
max_retries=3,
)
for doc, meta, dist in zip(
chunks["documents"][0],
chunks["metadatas"][0],
chunks["distances"][0],
):
combined_chunks.append(
{"document": doc, "metadata": meta, "distance": dist}
)
combined_chunks.sort(key=lambda x: x["distance"])
sources = set()
closest_chunks = []
for chunk in combined_chunks:
source_id = chunk["metadata"]["_id"]
if source_id not in sources:
sources.add(source_id)
closest_chunks.append(chunk)
if len(sources) >= n_sources:
break
if len(closest_chunks) < n_results:
remaining_chunks = [
chunk for chunk in combined_chunks if chunk not in closest_chunks
]
closest_chunks.extend(remaining_chunks[: n_results - len(closest_chunks)])
for chunk in closest_chunks:
_id = chunk["metadata"]["_id"]
if _id.split("/")[0] == "sci_articles":
arango_doc = self.get_arango(admin=True).db.document(_id)
arango_metadata = arango_doc.get("metadata", {}) if arango_doc else {}
else:
arango_doc = self.user_arango.db.document(_id)
arango_metadata = arango_doc.get("metadata", {}) if arango_doc else {}
chunk["metadata"] = arango_metadata or {
"title": "No title",
"published_date": "No published date",
"journal": "No journal",
}
for k in ["published_date", "journal", "title"]:
chunk["metadata"].setdefault(k, f"No {k}")
sorted_chunks = sorted(
closest_chunks,
key=lambda x: (
x["metadata"]["published_date"],
x["metadata"]["title"],
),
)
grouped_chunks = {}
article_number = 1
for chunk in sorted_chunks:
title = chunk["metadata"]["title"]
chunk["article_number"] = article_number
if title not in grouped_chunks:
grouped_chunks[title] = {
"article_number": article_number,
"chunks": [],
}
article_number += 1
grouped_chunks[title]["chunks"].append(chunk)
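# Illustrative shape of the returned value:
# {"<article title>": {"article_number": 1,
#                      "chunks": [{"document": str, "metadata": dict, "distance": float, "article_number": 1}, ...]}}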
return grouped_chunks
def process_user_input(self, user_input):
# Add user's message to chat history
self.chat.add_message("user", user_input)
# Generate response with tool support
prompt = get_tools_prompt(user_input)
response = self.toolbot.generate(prompt, tools=self.tools, stream=False)
print_yellow("Tool to use")
# Check if the LLM wants to use a tool
if isinstance(response, dict) and "tool_calls" in response:
bot_response = self.answer_tool_call(response, user_input)
else:
# Use the LLM's direct response
bot_response = response.strip('"')
with st.chat_message(
"assistant", avatar=self.chat.get_avatar(role="assitant")
):
st.write(bot_response)
# Add assistant's message to chat history
if self.chat.chat_history[-1]["role"] != "assistant":
self.chat.add_message("assistant", bot_response)
self.chatbot_memory = self.chatbot.messages
self.helperbot_memory = self.helperbot.messages
# Save the chat data without heavy objects
chat_data = self.chat.to_dict()
self.user_arango.db.collection("chats").update_match(
filters={"name": self.chat.name},
body={
"chat_data": chat_data,
"chatbot_memory": self.chatbot_memory,
"helperbot_memory": self.helperbot_memory,
},
)
def answer_tool_call(self, response, user_input):
bot_responses = []
tool_calls = response["tool_calls"]
for tool_call in tool_calls:
function = tool_call["function"]
function_name = function["name"]
arguments = function.get("arguments", {})
arguments["query"] = user_input
# Find and execute the tool function
with st.chat_message(
"assistant", avatar=self.chat.get_avatar(role="assistant")
):
if function_name in [
"fetch_other_documents",
"fetch_science_articles",
"fetch_science_articles_and_other_documents",
]:
chunks = getattr(self, function_name)(**arguments)
# Provide the tool's output back to the LLM
response = self.generate_from_chunks(user_input, chunks)
bot_response = st.write_stream(response)
bot_response = bot_response.strip('"')
if len(chunks) > 0:
sources = "###### Sources: \n"
for title, group in chunks.items():
sources += (
f"[{group['article_number']}] **{title}** "
f":gray[{group['chunks'][0]['metadata']['journal']} "
f"({group['chunks'][0]['metadata']['published_date']})] \n"
)
st.markdown(sources)
bot_response = f"{bot_response}\n\n{sources}"
bot_responses.append(bot_response)
elif function_name == "fetch_notes":
notes = getattr(self, function_name)()
response = self.generate_from_notes(user_input, notes)
bot_response = st.write_stream(response)
bot_responses.append(bot_response.strip('"'))
elif function_name == "conversational_response":
response = getattr(self, function_name)(user_input)
bot_response = st.write_stream(response)
bot_responses.append(bot_response.strip('"'))
return "\n\n".join(bot_responses)
def generate_from_notes(self, user_input, notes):
notes_string = ""
for note in notes:
notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"
prompt = get_chat_prompt(user_input, notes_string, role=self.chat.role)
with st.spinner("Reading project notes..."):
return self.chatbot.generate(prompt, stream=True)
def generate_from_chunks(self, user_input, chunks):
chunks_string = ""
for title, group in chunks.items():
chunks_content_string = "\n(...)\n".join(
[chunk["document"] for chunk in group["chunks"]]
)
chunks_string += (
f"\n# {title}\n"
f"## Article number: {group['article_number']}\n"
f"## {group['chunks'][0]['metadata']['published_date']} in "
f"{group['chunks'][0]['metadata']['journal']}\n"
f"{chunks_content_string}\n---\n"
)
prompt = get_chat_prompt(user_input, chunks_string, role=self.chat.role)
magazines = list(
set(
[
f"*{group['chunks'][0]['metadata']['journal']}*"
for group in chunks.values()
if "metadata" in group["chunks"][0]
]
)
)
if len(magazines) > 0:
s = f"Reading articles from {', '.join(magazines[:-1])} and {magazines[-1]}..."
else:
s = "Reading articles..."
with st.spinner(s):
return (
self.chatbot.generate(prompt, stream=True)
if self.chatbot
else self.llm.generate(prompt, stream=True)
)
def run(self):
if not hasattr(st.session_state, "bot"):
st.session_state.bot = self
# Display chat history
self.chat.show_chat_history()
if user_input := st.chat_input("Write your message here..."):
with st.chat_message("user", avatar=self.chat.get_avatar(role="user")):
st.write(user_input)
if not self.chat.name:
self.chat.set_name(user_input)
existing_chat = self.user_arango.db.aql.execute(
f'FOR doc IN chats FILTER doc.name == "{self.chat.name}" RETURN doc',
count=True,
)
if existing_chat.count() == 0:
chat_data = self.chat.to_dict()
chat_doc = self.user_arango.db.collection("chats").insert(
{
"name": self.chat.name,
"collection": self.collection,
"project": self.project.name if self.project else None,
"last_updated": datetime.now().isoformat(),
"saved": False,
"chat_data": chat_data,
}
)
self.chat_key = chat_doc["_key"]
self.process_user_input(user_input)
self.update_session_state()
def get_notes(self):
notes = self.user_arango.db.aql.execute(
f'FOR doc IN notes FILTER doc.project == "{self.project.name}" RETURN doc'
)
return list(notes)
# Register tools
@ToolRegistry.register(
name="fetch_science_articles",
description="Fetches information from scientific articles. Use this tool when the user is looking for information from scientific articles.",
parameters={
"type": "object",
"properties": {
"n_documents": {
"type": "integer",
"description": "How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.",
}
},
"required": ["n_documents"],
},
)
def fetch_science_articles(self, query: str, n_documents: int):
return self.get_chunks(
query, collections=["sci_articles"], n_results=n_documents
)
@ToolRegistry.register(
name="fetch_other_documents",
description="Fetches information from other documents based on the user's query. Other documents can include reports, news articles, and other kinds of texts. Use this tool only when it's obvious that the user is not looking for scientific articles.",
parameters={
"type": "object",
"properties": {
"n_documents": {
"type": "integer",
"description": "How many documents to fetch. A complex query may require more documents. Min: 2, Max: 10.",
}
},
"required": ["n_documents"],
},
)
def fetch_other_documents(self, query: str, n_documents: int):
return self.get_chunks(
query, collections=["other_documents"], n_results=n_documents
)
@ToolRegistry.register(
name="fetch_science_articles_and_other_documents",
description="Fetches information from both scientific articles and other documents. This is often used when the user hasn't specified what kind of sources they are interested in.",
parameters={
"type": "object",
"properties": {
"n_documents": {
"type": "integer",
"description": "How many documents to fetch. A complex query may require more documents. Min: 3, Max: 10.",
}
},
"required": ["n_documents"],
},
)
def fetch_science_articles_and_other_documents(self, query: str, n_documents: int):
return self.get_chunks(
query,
collections=["sci_articles", "other_documents"],
n_results=n_documents,
)
@ToolRegistry.register(
name="fetch_notes",
description="Fetches information from the project notes when you as an editor need context from the project notes to understand other information. ONLY use this together with other tools!",
)
def fetch_notes(self):
return self.get_notes()
@ToolRegistry.register(
name="conversational_response",
description="Generates a conversational response without fetching data. Use this ONLY if it is obvious that the user is not looking for information but only wants to chat.",
)
def conversational_response(self, query: str):
query = f'User message: "{query}". Make your answer short and conversational. Include a very brief description of the project if you think that would be helpful.'
result = (
self.chatbot.generate(query, stream=True)
if self.chatbot
else self.llm.generate(query, stream=True)
)
return result
class EditorBot(Bot):
def __init__(self, chat: Chat, username: str, **kwargs):
super().__init__(chat=chat, username=username, **kwargs)
self.role = "Editor"
self.tools = ToolRegistry.get_tools(
tools=[
"fetch_notes",
"conversational_response",
"fetch_science_articles",
"fetch_other_documents",
"fetch_science_articles_and_other_documents",
]
)
self.chatbot = LLM(
system_message=get_editor_prompt(kwargs.get("project")),
messages=self.chatbot_memory,
chosen_backend=kwargs.get("chosen_backend"),
)
class ResearchAssistantBot(Bot):
def __init__(self, chat: Chat, username: str, **kwargs):
super().__init__(chat=chat, username=username, **kwargs)
self.role = "Research Assistant"
self.chatbot = LLM(
system_message=get_assistant_prompt(),
temperature=0.1,
messages=self.chatbot_memory,
)
self.tools = ToolRegistry.get_tools(
tools=[
"fetch_science_articles",
"fetch_other_documents",
"fetch_science_articles_and_other_documents",
]
)
class PodBot(Bot):
"""Two LLM agents construct a conversation using material from science articles."""
def __init__(
self,
chat: Chat,
subject: str,
username: str,
instructions: str = None,
**kwargs,
):
super().__init__(chat=chat, username=username, **kwargs)
self.subject = subject
self.instructions = instructions
self.guest_name = kwargs.get("name_guest", "Merit")
self.hostbot = HostBot(
Chat(username=self.username, role="Host"), subject, username, instructions=instructions, **kwargs
)
self.guestbot = GuestBot(
Chat(username=self.username, role="Guest"),
subject,
username,
name_guest=self.guest_name,
**kwargs,
)
def run(self):
notes = self.get_notes()
notes_string = ""
if self.instructions:
instructions_string = f'''
These are the instructions for the podcast from the producer:
"""
{self.instructions}
"""
'''
else:
instructions_string = ""
for note in notes:
notes_string += f"\n# {note['title']}\n{note['content']}\n---\n"
a = f'''You will make a podcast interview with {self.guest_name}, an expert on "{self.subject}".
{instructions_string}
Below are notes on the subject that you can use to ask relevant questions:
"""
{notes_string}
"""
Say hello to the expert and start the interview. Remember to keep the interview to the subject of {self.subject} throughout the conversation.
'''
with st.sidebar:
stop = st.button("Stop the podcast")
if stop:
st.session_state["make_podcast"] = False
while st.session_state["make_podcast"]:
self.chat.show_chat_history()
# Stop the podcast once the chat reaches 14 messages
if len(self.chat.chat_history) >= 14:
result = self.hostbot.generate(
"The interview has ended. Say thank you to the expert and end the conversation."
)
self.chat.add_message("Host", result)
with st.chat_message(
"assistant", avatar=self.chat.get_avatar(role="assistant")
):
st.write(result.strip('"'))
st.stop()
_q = self.hostbot.toolbot.generate(
query=f"{self.guest_name} has answered: {a}. You have to choose a tool to help the host continue the interview.",
tools=self.hostbot.tools,
temperature=0.6,
stream=False,
)
if "tool_calls" in _q:
print_yellow("Tool call response (host)", _q)
print_purple("HOST", self.hostbot.chat.role)
q = self.hostbot.answer_tool_call(_q, a)
else:
q = _q
self.chat.add_message("Host", q)
_a = self.guestbot.toolbot.generate(
f'The podcast host has asked: "{q}" Choose a tool to help the expert answer with relevant facts and information.',
tools=self.guestbot.tools,
)
if "tool_calls" in _a:
print_yellow("Tool call response (guest)", _a)
print_yellow(self.guestbot.chat.role)
a = self.guestbot.answer_tool_call(_a, q)
else:
a = _a
self.chat.add_message("Guest", a)
self.update_session_state()
class HostBot(Bot):
def __init__(self, chat: Chat, subject: str, username: str, instructions: str, **kwargs):
super().__init__(chat=chat, username=username, **kwargs)
self.chat.role = kwargs.get("role", "Host")
self.tools = ToolRegistry.get_tools(
tools=[
"fetch_notes",
"conversational_response",
"fetch_other_documents",
]
)
self.instructions = instructions
self.llm = LLM(
system_message=f'''
You are the host of a podcast and an expert on {subject}. You will ask one question at a time about the subject, and then wait for the answer.
Don't ask the guest to talk about herself/himself, only about the subject.
These are the instructions for the podcast from the producer:
"""
{self.instructions}
"""
If the expert's answer is complicated, try to make a very brief summary of it for the audience to understand. You can also ask follow-up questions to clarify the answer, or ask for examples.
'''
)
self.toolbot = LLM(
temperature=0,
system_message='''
You are assisting a podcast host in asking questions to an expert.
Choose one or more tools to use in order to assist the host in asking relevant questions.
Often "conversational_response" is enough, but sometimes notes or even other documents are needed.
Make sure to read the description of the tools carefully!''',
chat=False,
model="small",
)
def generate(self, query):
return self.llm.generate(query)
class GuestBot(Bot):
def __init__(self, chat: Chat, subject: str, username: str, **kwargs):
super().__init__(chat=chat, username=username, **kwargs)
self.chat.role = kwargs.get("role", "Guest")
self.tools = ToolRegistry.get_tools(
tools=[
"fetch_notes",
"fetch_science_articles",
]
)
self.llm = LLM(
system_message=f"""
You are {kwargs.get('name', 'Merit')}, an expert on {subject}.
Today you are a guest in a podcast about {subject}. A host will ask you questions about the subject and you will answer by using scientific facts and information.
Try to be concise when answering, and remember that the podcast audience are not experts on the subject, so don't complicate things too much.
It's very important that you answer in a "spoken" way, as if you were talking to someone in a conversation. That means you should avoid using scientific jargon and complex terms, too many figures or abstract concepts.
Lists are also not recommended; instead use "for the first reason", "secondly", etc.
Use "..." to indicate a pause and "-" to indicate a break in the sentence, as if you were speaking.
"""
)
self.toolbot = LLM(
temperature=0,
system_message=f"You are an assistant to an expert on {subject}. Choose one or many tools to use in order to assist the expert in answering questions. Make sure to read the description of the tools carefully.",
chat=False,
model="small",
)
def generate(self, query):
return self.llm.generate(query)

@ -0,0 +1,44 @@
# streamlit_pages.py
import streamlit as st
from colorprinter.print_color import *
from time import sleep
def Projects():
"""
Function to handle the Projects page.
"""
from _classes import ProjectsPage
if 'Projects' not in st.session_state:
st.session_state['Projects'] = {}
projectpage = ProjectsPage(username=st.session_state["username"])
projectpage.run()
def Bot_Chat():
"""
Function to handle the Chat Bot page.
"""
from _classes import BotChatPage
if 'Bot Chat' not in st.session_state:
st.session_state['Bot Chat'] = {}
chatpage = BotChatPage(username=st.session_state["username"])
chatpage.run()
def Article_Collections():
"""
Function to handle the Article Collections page.
"""
from _classes import ArticleCollectionsPage
if 'Article Collections' not in st.session_state:
st.session_state['Article Collections'] = {}
article_collection = ArticleCollectionsPage(username=st.session_state["username"])
article_collection.run()
def Settings():
"""
Function to handle the Settings page.
"""
from _classes import SettingsPage
settings = SettingsPage(username=st.session_state["username"])
settings.run()

File diff suppressed because one or more lines are too long

@ -0,0 +1,31 @@
from TTS.api import TTS
import torch
from datetime import datetime
tts = TTS("tts_models/en/multi-dataset/tortoise-v2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts.to(device)
text="There is, therefore, an increasing need to understand BEVs from a systems perspective. This involves an in-depth consideration of the environmental impact of the product using life cycle assessment (LCA) as well as taking a broader 'circular economy' approach. On the one hand, LCA is a means of assessing the environmental impact associated with all stages of a product's life from cradle to grave: from raw material extraction and processing to the product's manufacture to its use in everyday life and finally to its end of life."
# cloning `lj` voice from `TTS/tts/utils/assets/tortoise/voices/lj`
# with custom inference settings overriding defaults.
time_now = datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"output/tortoise_{time_now}.wav"
tts.tts_to_file(text,
file_path=output_path,
voice_dir="voices",
speaker="test",
split_sentences=False, # Change to True if context is not enough
num_autoregressive_samples=20,
diffusion_iterations=50)
# # Using presets with the same voice
# tts.tts_to_file(text,
# file_path="output.wav",
# voice_dir="path/to/tortoise/voices/dir/",
# speaker="lj",
# preset="ultra_fast")
# # Random voice generation
# tts.tts_to_file(text,
# file_path="output.wav")

@ -0,0 +1,51 @@
from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub
from fairseq.models.text_to_speech.hub_interface import TTSHubInterface
from fairseq import utils
import nltk
import torch
# Download the required NLTK resource
nltk.download('averaged_perceptron_tagger')
# Model loading
models, cfg, task = load_model_ensemble_and_task_from_hf_hub(
"facebook/fastspeech2-en-ljspeech",
arg_overrides={"vocoder": "hifigan", "fp16": False}
)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move all models to the correct device
for model in models:
model.to(device)
# Update configuration and build generator after moving models
TTSHubInterface.update_cfg_with_data_cfg(cfg, task.data_cfg)
generator = task.build_generator(models, cfg)
# Ensure the vocoder is on the correct device
generator.vocoder.model.to(device)
# Define your text
text = """Hi there, thanks for having me! My interest in electric cars really started back when I was a teenager..."""
# Convert text to model input
sample = TTSHubInterface.get_model_input(task, text)
# Recursively move all tensors in sample to the correct device
sample = utils.move_to_cuda(sample) if torch.cuda.is_available() else sample
# Generate speech
wav, rate = TTSHubInterface.get_prediction(task, models[0], generator, sample)
from scipy.io.wavfile import write
# If wav is a tensor, convert it to a NumPy array
if isinstance(wav, torch.Tensor):
wav = wav.cpu().numpy()
# Save the audio to a WAV file
write('output_fair.wav', rate, wav)

@ -0,0 +1,45 @@
import torch
from TTS.api import TTS
from datetime import datetime
# Get device
from TTS.tts.utils.speakers import SpeakerManager
device = "cuda" if torch.cuda.is_available() else "cpu"
# Init TTS
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
text = """Hi there, thanks for having me! My interest in electric cars really started back when I was a teenager. I remember learning about the history of EVs and how they've been around since the late 1800s, even before gasoline cars took over. The fact that these vehicles could run on electricity instead of fossil fuels just fascinated me.
Then, in the 90s, General Motors introduced the EV1 - it was a real game-changer. It showed that electric cars could be practical and enjoyable to drive. And when Tesla came along with their Roadster in 2007, proving that EVs could have a long range, I was hooked.
But what really sealed my interest was learning about the environmental impact of EVs. They produce zero tailpipe emissions, which means they can help reduce air pollution and greenhouse gas emissions. That's something I'm really passionate about.
"""
text_se = """Antalet bilar ger dock bara en del av bilden. För att förstå bilberoendet bör vi framför allt titta på hur mycket bilarna faktiskt används.
Stockholmarnas genomsnittliga körsträcka med bil har minskat sedan millennieskiftet. Den är dock lägre i Göteborg och i Malmö.
I procent har bilanvändningen sedan år 2000 minskat lika mycket i Stockholm och Malmö, 9 procent. I Göteborg är minskningen 13 procent, i riket är minskningen 7 procent."""
# Run TTS
# ❗ Since this model is multi-lingual voice cloning model, we must set the target speaker_wav and language
# Text to speech list of amplitude values as output
#wav = tts.tts(text=text, speaker_wav="my/cloning/audio.wav", language="en")
# Text to speech to a file
time_now = datetime.now().strftime("%Y%m%d%H%M%S")
output_path = f"output/tts_{time_now}.wav"
tts.tts_to_file(text=text, speaker_wav='voices/test/test_en.wav', language="en", file_path=output_path)
# api = TTS("tts_models/se/fairseq/vits")
# api.tts_with_vc_to_file(
# text_se,
# speaker_wav="test_audio_se.wav",
# file_path="output_se.wav"
# )

@ -0,0 +1,22 @@
import requests
# Define the server URL
server_url = "http://localhost:5002/api/tts"
# Define the payload
payload = {
"text": "It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent.",
"speaker": "Ana Florence",
"language": "en",
"split_sentences": True
}
# Send the request to the TTS server
response = requests.post(server_url, json=payload)
# Save the response audio to a file
if response.status_code == 200:
with open("output.wav", "wb") as f:
f.write(response.content)
else:
print(f"Error: {response.status_code}")

@ -0,0 +1,33 @@
from TTS.tts.configs.tortoise_config import TortoiseConfig
from TTS.tts.models.tortoise import Tortoise
import torch
import os
import torchaudio
# Initialize Tortoise model
config = TortoiseConfig()
model = Tortoise.init_from_config(config)
model.load_checkpoint(config, checkpoint_dir="tts_models/en/multi-dataset/tortoise-v2", eval=True)
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)
# Define the text and voice directory
text = "There is, therefore, an increasing need to understand BEVs from a systems perspective."
voice_dir = "voices"
speaker = "test"
# Load voice samples
voice_samples = []
for file_name in os.listdir(os.path.join(voice_dir, speaker)):
file_path = os.path.join(voice_dir, speaker, file_name)
waveform, sample_rate = torchaudio.load(file_path)
voice_samples.append(waveform)
# Get conditioning latents
conditioning_latents = model.get_conditioning_latents(voice_samples)
# Save conditioning latents to a file
torch.save(conditioning_latents, "conditioning_latents.pth")

@ -0,0 +1,16 @@
# utils.py
import re
def fix_key(_key: str) -> str:
"""
Sanitize a given key by replacing all characters that are not alphanumeric,
underscore, hyphen, dot, at symbol, parentheses, plus, equals, semicolon,
dollar sign, asterisk, single quote, percent, or colon with an underscore.
Args:
_key (str): The key to be sanitized.
Returns:
str: The sanitized key with disallowed characters replaced by underscores.
"""
return re.sub(r"[^A-Za-z0-9_\-\.@()+=;$!*\'%:]", "_", _key)
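# Illustrative example: slashes and spaces are outside the allowed set, so
#   fix_key("sci_articles/10.1007/xyz 1")  # -> "sci_articles_10.1007_xyz_1"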