before travel

main
lasseedfast 12 months ago
parent 00fd42b32d
commit 732793b79f
  1. 7
      .env
  2. 4
      _arango.py
  3. 149
      _base_class.py
  4. 46
      _chromadb.py
  5. 830
      _classes.py
  6. 440
      _llm copy.py
  7. 607
      _llm.py
  8. 35
      article2db.py
  9. 232
      bluesky_bot.py
  10. 412
      collections_page.py
  11. 88
      llm_server.py
  12. 68
      llm_tools.py
  13. 1
      manage_users.py
  14. 0
      pod_bot.py
  15. 725
      projects_page.py
  16. 4
      prompts.py
  17. 64
      streamlit_app.py
  18. 920
      streamlit_chatbot.py
  19. 12
      streamlit_pages.py
  20. 4
      streamlit_rss_old.py
  21. 50
      test_ollama_client.py
  22. 9
      test_ollama_image.py
  23. 59
      transcribe_audio.py

@ -18,4 +18,9 @@ ARANGO_PWD_ENV_MANAGER="jagskoterenv(Y)"
ARANGO_ROOT_USER='root'
ARANGO_ROOT_PASSWORD='gyhqed-kiwNac-9buhme'
# NOTE(review): MAILERSEND_API_KEY was defined twice with the same value; duplicate removed.
# WARNING: live credentials are committed here — rotate these secrets and load them from outside version control.
MAILERSEND_API_KEY="mlsn.71de3eb2dbcb733bd4ee509d1c95ccfc8939fd647cba9e3a0f631f60f900bd85"
TRANSCRIBE_URL="http://98.128.172.165:4001/upload"
BLUESKY_USERNAME="assistant-fish.bsky.social"
BLUESKY_PASSWORD="robMep-4hajgu-ceprox"

@ -2,7 +2,9 @@ import re
from arango import ArangoClient
from dotenv import load_dotenv
import os
import env_manager
if "INFO" not in os.environ:
import env_manager
env_manager.set_env()
load_dotenv() # Install with pip install python-dotenv

@ -3,10 +3,8 @@ import os
import re
import streamlit as st
from _arango import ArangoDB
from _llm import LLM
from _chromadb import ChromaDB
class BaseClass:
def __init__(self, username: str, **kwargs) -> None:
self.username: str = username
@ -14,6 +12,8 @@ class BaseClass:
self.collection: str = kwargs.get('collection_name', None)
self.user_arango: ArangoDB = self.get_arango()
self.base_arango: ArangoDB = self.get_arango(admin=True)
for key, value in kwargs.items():
setattr(self, key, value)
def get_arango(self, admin: bool = False, db_name: str = None) -> ArangoDB:
@ -22,27 +22,7 @@ class BaseClass:
elif admin:
return ArangoDB()
else:
return ArangoDB(db_name=st.session_state["username"])
def get_settings(self):
settings = self.user_arango.db.document("settings/settings")
if not settings:
self.user_arango.db.collection("settings").insert(
{"_key": "settings", "current_collection": None, "current_page": None}
)
for i in ["current_collection", "current_page"]:
if i not in settings:
settings[i] = None
st.session_state["settings"] = settings
return settings
def update_settings(self, key, value) -> None:
self.user_arango.db.collection("settings").update_match(
filters={"_key": "settings"},
body={key: value},
merge=True,
)
st.session_state["settings"][key] = value
return ArangoDB(user=self.username)
def get_article_collections(self) -> list:
article_collections = self.user_arango.db.aql.execute(
@ -56,27 +36,6 @@ class BaseClass:
)
return list(projects)
def choose_collection(self, text="Select a collection of favorite articles") -> str:
collections = self.get_article_collections()
collection = st.selectbox(text, collections, index=None)
if collection:
self.project = None
self.collection = collection
self.update_settings("current_collection", collection)
self.update_session_state()
return collection
def choose_project(self, text="Select a project") -> str:
projects = self.get_projects()
project = st.selectbox(text, projects, index=None)
if project:
from _classes import Project
self.project = Project(self.username, project, self.user_arango)
self.collection = None
self.update_settings("current_project", self.project.name)
self.update_session_state()
return self.project
def get_chromadb(self):
return ChromaDB()
@ -89,28 +48,6 @@ class BaseClass:
if doc:
return doc.next()
def update_current_page(self, page_name):
if st.session_state.get("current_page") != page_name:
st.session_state["current_page"] = page_name
self.update_settings("current_page", page_name)
def update_session_state(self, page_name=None):
"""
Updates the Streamlit session state with the attributes of the current instance.
Parameters:
page_name (str, optional): The name of the page to update in the session state.
If not provided, it defaults to the current page stored in the session state.
The method iterates over the instance's attributes and updates the session state
for the given page name with those attributes that are of type str, int, float, list, dict, or bool.
"""
if not page_name:
page_name = st.session_state.get("current_page")
for attr, value in self.__dict__.items():
if any([isinstance(value, t) for t in [str, int, float, list, dict, bool]]):
st.session_state[page_name][attr] = value
def set_filename(self, filename=None, folder="other_documents"):
"""
@ -139,3 +76,83 @@ class BaseClass:
)
self.file_path = file_path + ".pdf"
return file_path
class StreamlitBaseClass(BaseClass):
    """Base class for Streamlit pages.

    Persists per-user settings in the user's ArangoDB database and mirrors
    simple instance attributes into ``st.session_state`` so page state
    survives Streamlit reruns.
    """

    def __init__(self, username: str, **kwargs) -> None:
        super().__init__(username, **kwargs)

    def get_settings(self):
        """Load the user's settings document, creating it on first use.

        Returns:
            dict: The settings document, guaranteed to contain the keys
            ``current_collection`` and ``current_page``. Also cached in
            ``st.session_state["settings"]``.
        """
        settings = self.user_arango.db.document("settings/settings")
        if not settings:
            # First run for this user: create the document and keep the
            # defaults locally. (The previous version inserted the document
            # but then iterated over ``settings`` while it was still None,
            # which raised a TypeError.)
            settings = {
                "_key": "settings",
                "current_collection": None,
                "current_page": None,
            }
            self.user_arango.db.collection("settings").insert(settings)
        for key in ("current_collection", "current_page"):
            settings.setdefault(key, None)
        st.session_state["settings"] = settings
        return settings

    def update_settings(self, key, value) -> None:
        """Persist a single settings key to ArangoDB and the session state."""
        self.user_arango.db.collection("settings").update_match(
            filters={"_key": "settings"},
            body={key: value},
            merge=True,
        )
        st.session_state["settings"][key] = value

    def update_session_state(self, page_name=None):
        """Mirror this instance's simple attributes into the session state.

        Parameters:
            page_name (str, optional): The page whose state dict is updated.
                Defaults to ``st.session_state["current_page"]``.

        Only attributes of type str, int, float, list, dict, or bool are
        copied; complex objects (DB handles, chats, …) are skipped.
        """
        if not page_name:
            page_name = st.session_state.get("current_page")
        for attr, value in self.__dict__.items():
            if isinstance(value, (str, int, float, list, dict, bool)):
                st.session_state[page_name][attr] = value

    def update_current_page(self, page_name):
        """Record ``page_name`` as the active page (session state and DB)."""
        if st.session_state.get("current_page") != page_name:
            st.session_state["current_page"] = page_name
            self.update_settings("current_page", page_name)

    def choose_collection(self, text="Select a collection of favorite articles") -> str:
        """Render a selectbox of article collections and persist the choice.

        Selecting a collection clears any active project.
        """
        collections = self.get_article_collections()
        collection = st.selectbox(text, collections, index=None)
        if collection:
            self.project = None
            self.collection = collection
            self.update_settings("current_collection", collection)
            self.update_session_state()
        return collection

    def choose_project(self, text="Select a project") -> str:
        """Render a selectbox of projects and persist the choice.

        Selecting a project clears any active collection.
        """
        projects = self.get_projects()
        project = st.selectbox(text, projects, index=None)
        if project:
            # Local import avoids a circular dependency with _classes.
            from _classes import Project

            self.project = Project(self.username, project, self.user_arango)
            self.collection = None
            self.update_settings("current_project", self.project.name)
            self.update_session_state()
        return self.project

@ -3,6 +3,7 @@ import os
from chromadb.config import Settings
from dotenv import load_dotenv
from colorprinter.print_color import *
load_dotenv(".env")
@ -36,6 +37,21 @@ class ChromaDB:
where: dict = None,
**kwargs,
):
"""
Query the vector database for relevant documents.
Args:
query (str): The query text to search for.
collection (str): The name of the collection to search within.
n_results (int, optional): The number of results to return. Defaults to 6.
n_sources (int, optional): The number of unique sources to return. Defaults to 3.
max_retries (int, optional): The maximum number of retries for querying. Defaults to None.
where (dict, optional): Additional filtering criteria for the query. Defaults to None.
**kwargs: Additional keyword arguments to pass to the query.
Returns:
dict: A dictionary containing the query results with keys 'ids', 'metadatas', 'documents', and 'distances'.
"""
if not isinstance(n_sources, int):
n_sources = int(n_sources)
if not isinstance(n_results, int):
@ -60,7 +76,7 @@ class ChromaDB:
**kwargs,
)
if r["ids"][0] == []:
if result['ids'][0] == []:
if result["ids"][0] == []:
print_red("No results found in vector database.")
else:
print_red("No more results found in vector database.")
@ -105,6 +121,31 @@ class ChromaDB:
break
return result
def add_chunks(self, collection: str, chunks: list, _key, metadata: dict = None):
"""
Adds chunks to a specified collection in the database.
Args:
collection (str): The name of the collection to add chunks to.
chunks (list): A list of chunks to be added to the collection.
_key: A key used to generate unique IDs for the chunks.
metadata (dict, optional): Metadata to be associated with each chunk. Defaults to None.
Returns:
None
"""
col = self.db.get_or_create_collection(collection)
ids = []
metadatas = []
for number in chunks:
if metadata:
metadata["number"] = number
metadatas.append(metadata)
else:
metadatas.append({})
ids.append(f"{_key}_{number}")
col.add(ids=ids, metadatas=metadatas, documents=chunks)
if __name__ == "__main__":
from colorprinter.print_color import *
@ -119,5 +160,4 @@ if __name__ == "__main__":
n_sources=3,
max_retries=4,
)
print_rainbow(result['metadatas'][0])
print_rainbow(result["metadatas"][0])

@ -1,340 +1,16 @@
# streamlit_pages.py
import os
import feedparser
import urllib
from urllib.parse import urljoin
import requests
import re
from bs4 import BeautifulSoup
import streamlit as st
from time import sleep
from datetime import datetime, timedelta
from PIL import Image
from io import BytesIO
import base64
from colorprinter.print_color import *
from article2db import PDFProcessor
import feedparser
from streamlit_chatbot import Chat, EditorBot, ResearchAssistantBot, PodBot, Bot
from info import country_emojis
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import BaseClass
from streamlit_chatbot import StreamlitChat, EditorBot, ResearchAssistantBot, PodBot, StreamlitBot
from _base_class import StreamlitBaseClass
from _rss import RSSReader
from projects_page import Project
from prompts import get_note_summary_prompt, get_image_system_prompt
class ArticleCollectionsPage(BaseClass):
def __init__(self, username: str):
super().__init__(username=username)
self.collection = None
self.page_name = "Article Collections"
# Initialize attributes from session state if available
for k, v in st.session_state[self.page_name].items():
setattr(self, k, v)
def run(self):
if self.user_arango.db.collection("article_collections").count() == 0:
self.create_new_collection()
self.update_current_page(self.page_name)
self.choose_collection_method()
self.choose_project_method()
if self.collection:
self.display_collection()
self.sidebar_actions()
if st.session_state.get("new_collection"):
self.create_new_collection()
# Persist state to session_state
self.update_session_state(page_name=self.page_name)
def choose_collection_method(self):
self.collection = self.choose_collection()
# Persist state after choosing collection
self.update_session_state(page_name=self.page_name)
def choose_project_method(self):
# If you have a project selection similar to collection, implement here
pass # Placeholder for project-related logic
def choose_collection(self):
collections = self.get_article_collections()
current_collection = self.collection
preselected = (
collections.index(current_collection)
if current_collection in collections
else None
)
with st.sidebar:
collection = st.selectbox(
"Select a collection of favorite articles",
collections,
index=preselected,
)
if collection:
self.collection = collection
self.update_settings("current_collection", collection)
return self.collection
def create_new_collection(self):
with st.form("create_collection_form", clear_on_submit=True):
new_collection_name = st.text_input("Enter the name of the new collection")
submitted = st.form_submit_button("Create Collection")
if submitted:
if new_collection_name:
self.user_arango.db.collection("article_collections").insert(
{"name": new_collection_name, "articles": []}
)
st.success(f'New collection "{new_collection_name}" created')
self.collection = new_collection_name
self.update_settings("current_collection", new_collection_name)
# Persist state after creating a new collection
self.update_session_state(page_name=self.page_name)
sleep(1)
st.rerun()
def display_collection(self):
with st.sidebar:
col1, col2 = st.columns(2)
with col1:
if st.button("Create new collection"):
st.session_state["new_collection"] = True
with col2:
if st.button(f'Remove collection "{self.collection}"'):
self.user_arango.db.collection("article_collections").delete_match(
{"name": self.collection}
)
st.success(f'Collection "{self.collection}" removed')
self.collection = None
self.update_settings("current_collection", None)
# Persist state after removing a collection
self.update_session_state(page_name=self.page_name)
st.rerun()
self.show_articles_in_collection()
def show_articles_in_collection(self):
collection_articles_cursor = self.user_arango.db.aql.execute(
f'FOR doc IN article_collections FILTER doc["name"] == "{self.collection}" RETURN doc["articles"]'
)
collection_articles = list(collection_articles_cursor)
if collection_articles and collection_articles[0]:
articles = collection_articles[0]
st.markdown(f"#### Articles in *{self.collection}*:")
for article in articles:
if article is None:
continue
metadata = article.get("metadata")
if metadata is None:
continue
title = metadata.get("title", "No Title").strip()
journal = metadata.get("journal", "No Journal").strip()
published_year = metadata.get("published_year", "No Year")
published_date = metadata.get("published_date", None)
language = metadata.get("language", "No Language")
icon = country_emojis.get(language.upper(), "") if language else ""
expander_title = f"**{title}** *{journal}* ({published_year}) {icon}"
with st.expander(expander_title):
if not title == "No Title":
st.markdown(f"**Title:** \n{title}")
if not journal == "No Journal":
st.markdown(f"**Journal:** \n{journal}")
if published_date:
st.markdown(f"**Published Date:** \n{published_date}")
for key, value in article.items():
if key in [
"_key",
"text",
"file",
"_rev",
"chunks",
"user_access",
"_id",
"metadata",
"doi",
]:
continue
if isinstance(value, list):
value = ", ".join(value)
st.markdown(f"**{key.capitalize()}**: \n{value} ")
if "doi" in article:
st.markdown(
f"**DOI:** \n[{article['doi']}](https://doi.org/{article['doi']}) "
)
st.button(
key=f'delete_{article["_id"]}',
label="Delete article from collection",
on_click=self.delete_article,
args=(self.collection, article["_id"]),
)
else:
st.write("No articles in this collection.")
def sidebar_actions(self):
with st.sidebar:
st.markdown(f"### Add new articles to {self.collection}")
with st.form("add_articles_form", clear_on_submit=True):
pdf_files = st.file_uploader(
"Upload PDF file(s)", type=["pdf"], accept_multiple_files=True
)
is_sci = st.checkbox("All articles are from scientific journals")
submitted = st.form_submit_button("Upload")
if submitted and pdf_files:
self.add_articles(pdf_files, is_sci)
# Persist state after adding articles
self.update_session_state(page_name=self.page_name)
st.rerun()
help_text = 'Paste a text containing DOIs, e.g., the reference section of a paper, and click "Add Articles" to add them to the collection.'
new_articles = st.text_area(
"Add articles to this collection", help=help_text
)
if st.button("Add Articles"):
with st.spinner("Processing..."):
self.process_dois(
article_collection_name=self.collection, text=new_articles
)
# Persist state after processing DOIs
self.update_session_state(page_name=self.page_name)
st.rerun()
self.write_not_downloaded()
def add_articles(self, pdf_files: list, is_sci: bool) -> None:
for pdf_file in pdf_files:
status_container = st.empty()
with status_container:
is_sci = is_sci if is_sci else None
with st.status(f"Processing {pdf_file.name}..."):
processor = PDFProcessor(
pdf_file=pdf_file,
filename=pdf_file.name,
process=False,
username=st.session_state["username"],
document_type="other_documents",
is_sci=is_sci,
)
_id, db, doi = processor.process_document()
print_rainbow(_id, db, doi)
if doi in st.session_state.get("not_downloaded", {}):
st.session_state["not_downloaded"].pop(doi)
self.articles2collection(collection=self.collection, db=db, _id=_id)
st.success("Done!")
sleep(1.5)
def articles2collection(self, collection: str, db: str, _id: str = None) -> None:
info = self.get_article_info(db, _id=_id)
info = {
k: v for k, v in info.items() if k in ["_id", "doi", "title", "metadata"]
}
doc_cursor = self.user_arango.db.aql.execute(
f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
)
doc = next(doc_cursor, None)
if doc:
articles = doc.get("articles", [])
keys = [i["_id"] for i in articles]
if info["_id"] not in keys:
articles.append(info)
self.user_arango.db.collection("article_collections").update_match(
filters={"name": collection},
body={"articles": articles},
merge=True,
)
# Persist state after updating articles
self.update_session_state(page_name=self.page_name)
def get_article_info(self, db: str, _id: str = None, doi: str = None) -> dict:
assert _id or doi, "Either _id or doi must be provided."
arango = self.get_arango(db_name=db)
if _id:
query = """
RETURN {
"_id": DOCUMENT(@doc_id)._id,
"doi": DOCUMENT(@doc_id).doi,
"title": DOCUMENT(@doc_id).title,
"metadata": DOCUMENT(@doc_id).metadata,
"summary": DOCUMENT(@doc_id).summary
}
"""
info_cursor = arango.db.aql.execute(query, bind_vars={"doc_id": _id})
elif doi:
info_cursor = arango.db.aql.execute(
f'FOR doc IN sci_articles FILTER doc["doi"] == "{doi}" LIMIT 1 RETURN {{"_id": doc["_id"], "doi": doc["doi"], "title": doc["title"], "metadata": doc["metadata"], "summary": doc["summary"]}}'
)
return next(info_cursor, None)
def process_dois(
self, article_collection_name: str, text: str = None, dois: list = None
) -> None:
processor = PDFProcessor(process=False)
if not dois and text:
dois = processor.extract_doi(text, multi=True)
if "not_downloaded" not in st.session_state:
st.session_state["not_downloaded"] = {}
for doi in dois:
downloaded, url, path, in_db = processor.doi2pdf(doi)
if downloaded and not in_db:
processor.process_pdf(path)
in_db = True
elif not downloaded and not in_db:
st.session_state["not_downloaded"][doi] = url
if in_db:
st.success(f"Article with DOI {doi} added")
self.articles2collection(
collection=article_collection_name,
db="base",
_id=f"sci_articles/{fix_key(doi)}",
)
# Persist state after processing DOIs
self.update_session_state(page_name=self.page_name)
def write_not_downloaded(self):
not_downloaded = st.session_state.get("not_downloaded", {})
if not_downloaded:
st.markdown(
"*The articles below were not downloaded. Download them yourself and add them to the collection by dropping them in the area above. Some of them can be downloaded using the link.*"
)
for doi, url in not_downloaded.items():
if url:
st.markdown(f"- [{doi}]({url})")
else:
st.markdown(f"- {doi}")
def delete_article(self, collection, _id):
    """Remove the article with ``_id`` from the named collection document.

    Parameters:
        collection (str): Name of the article collection.
        _id (str): ArangoDB ``_id`` of the article to remove.

    Uses AQL bind variables instead of f-string interpolation so a
    collection name containing quotes cannot break — or inject into —
    the query.
    """
    doc_cursor = self.user_arango.db.aql.execute(
        'FOR doc IN article_collections FILTER doc["name"] == @name RETURN doc',
        bind_vars={"name": collection},
    )
    doc = next(doc_cursor, None)
    if doc:
        # Keep every article except the one being deleted.
        articles = [
            article for article in doc.get("articles", []) if article["_id"] != _id
        ]
        self.user_arango.db.collection("article_collections").update_match(
            filters={"_id": doc["_id"]},
            body={"articles": articles},
        )
        # Persist state after deleting an article
        self.update_session_state(page_name=self.page_name)
class BotChatPage(BaseClass):
class BotChatPage(StreamlitBaseClass):
def __init__(self, username):
super().__init__(username=username)
self.collection_name = None
@ -343,6 +19,7 @@ class BotChatPage(BaseClass):
self.chat = None
self.role = "Research Assistant" # Default persona
self.page_name = "Bot Chat"
self.chat_key = None
# Initialize attributes from session state if available
if self.page_name in st.session_state:
@ -357,9 +34,8 @@ class BotChatPage(BaseClass):
if self.collection_name or self.project:
# If no chat exists, create a new Chat instance
if not self.chat:
self.chat = Chat(username=self.username, role=self.role)
self.chat.show_chat_history()
self.chat = self.get_chat(role=self.role)
# Create a Bot instance with the Chat object
if self.role == "Research Assistant":
bot = ResearchAssistantBot(
@ -368,6 +44,7 @@ class BotChatPage(BaseClass):
collection=self.collection_name,
project=self.project,
)
elif self.role == "Editor":
bot = EditorBot(
username=self.username,
@ -405,13 +82,26 @@ class BotChatPage(BaseClass):
"chat": self.chat,
"role": self.role,
}
else:
bot = Bot(
else: # If no collection or project is selected, use the conversational response bot
bot = StreamlitBot(
username=self.username,
chat=Chat(username=self.username, role="Research Assistant"),
chat=self.get_chat(),
tools=["conversational_response_tool"],
)
bot.run()
def get_chat(self, role="Research Assistant"):
    """Return the active StreamlitChat, creating one on first use.

    A newly created chat's key is cached in ``st.session_state['chat_key']``;
    subsequent calls rehydrate the chat document from the ArangoDB "chats"
    collection via ``StreamlitChat.from_dict``.

    NOTE(review): for an existing chat the ``role`` argument is ignored —
    confirm that is intended.
    NOTE(review): ``.get()`` returns None when the stored key no longer
    exists (e.g. after remove_old_unsaved_chats), in which case
    ``StreamlitChat.from_dict(None)`` would presumably fail — verify.
    """
    if 'chat_key' not in st.session_state:
        # First chat this session: create it and remember its key.
        chat=StreamlitChat(username=self.username, role=role)
        st.session_state['chat_key'] = chat._key
        print_blue("Creating new chat:", st.session_state['chat_key'])
    else:
        # Reload a previously created chat from the database.
        print_blue("Old chat:", st.session_state['chat_key'])
        chat = self.user_arango.db.collection("chats").get(st.session_state['chat_key'])
        chat = StreamlitChat.from_dict(chat)
    return chat
def sidebar_actions(self):
with st.sidebar:
self.collection = self.choose_collection(
@ -450,13 +140,8 @@ class BotChatPage(BaseClass):
"Continue another chat", options=[""] + list(chats.keys()), index=0
)
if selected_chat:
chat_data = self.user_arango.db.collection("chats").get(
chats[selected_chat]
)
chat_dict = chat_data.get("chat_data", {})
self.chat = Chat.from_dict(chat_dict)
else:
self.chat = None
st.session_state["chat_key"] = chats[selected_chat]
self.chat = self.get_chat()
def remove_old_unsaved_chats(self):
two_weeks_ago = datetime.now() - timedelta(weeks=2)
@ -467,454 +152,7 @@ class BotChatPage(BaseClass):
self.user_arango.db.collection("chats").delete(chat["_key"])
class ProjectsPage(BaseClass):
def __init__(self, username: str):
super().__init__(username=username)
self.projects = []
self.selected_project_name = None
self.project = None
self.page_name = "Projects"
# Initialize attributes from session state if available
page_state = st.session_state.get(self.page_name, {})
for k, v in page_state.items():
setattr(self, k, v)
def run(self):
self.update_current_page(self.page_name)
self.load_projects()
self.display_projects()
# Update session state
self.update_session_state(self.page_name)
def load_projects(self):
projects_cursor = self.user_arango.db.aql.execute(
"FOR doc IN projects RETURN doc", count=True
)
self.projects = list(projects_cursor)
def display_projects(self):
with st.sidebar:
self.new_project_button()
self.selected_project_name = st.selectbox(
"Select a project to manage",
options=[proj["name"] for proj in self.projects],
)
if self.selected_project_name:
self.project = Project(
username=self.username,
project_name=self.selected_project_name,
user_arango=self.user_arango,
)
self.manage_project()
# Update session state
self.update_session_state(self.page_name)
def new_project_button(self):
st.session_state.setdefault("new_project", False)
with st.sidebar:
if st.button("New project", type="primary"):
st.session_state["new_project"] = True
if st.session_state["new_project"]:
self.create_new_project()
# Update session state
self.update_session_state(self.page_name)
def create_new_project(self):
new_project_name = st.text_input("Enter the name of the new project")
new_project_description = st.text_area(
"Enter the description of the new project"
)
if st.button("Create Project"):
if new_project_name:
self.user_arango.db.collection("projects").insert(
{
"name": new_project_name,
"description": new_project_description,
"collections": [],
"notes": [],
"note_keys_hash": hash(""),
"settings": {},
}
)
st.success(f'New project "{new_project_name}" created')
st.session_state["new_project"] = False
self.update_settings("current_project", new_project_name)
sleep(1)
st.rerun()
def show_project_notes(self):
with st.expander("Show summarised notes"):
st.markdown(self.project.notes_summary)
with st.expander("Show project notes"):
notes_cursor = self.user_arango.db.aql.execute(
"FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc",
bind_vars={"note_ids": self.project.notes},
)
notes = list(notes_cursor)
if notes:
for note in notes:
st.markdown(f'_{note.get("timestamp", "")}_')
st.markdown(note["text"].replace("\n", " \n"))
st.button(
key=f'delete_note_{note["_id"]}',
label="Delete note",
on_click=self.project.delete_note,
args=(note["_id"],),
)
st.write("---")
else:
st.write("No notes in this project.")
def manage_project(self):
self.update_settings("current_project", self.selected_project_name)
# Initialize the Project instance
self.project = Project(
self.username, self.selected_project_name, self.user_arango
)
st.write(f"## {self.project.name}")
self.show_project_notes()
self.relate_collections()
self.sidebar_actions()
self.project.update_notes_hash()
if st.button(f"Remove project *{self.project.name}*"):
self.user_arango.db.collection("projects").delete_match(
{"name": self.project.name}
)
self.update_settings("current_project", None)
st.success(f'Project "{self.project.name}" removed')
st.rerun()
# Update session state
self.update_session_state(self.page_name)
def relate_collections(self):
collections = [
col["name"]
for col in self.user_arango.db.collection("article_collections").all()
]
selected_collections = st.multiselect(
"Relate existing collections", options=collections
)
if st.button("Relate Collections"):
self.project.add_collections(selected_collections)
st.success("Collections related to the project")
# Update session state
self.update_session_state(self.page_name)
new_collection_name = st.text_input(
"Enter the name of the new collection to create and relate"
)
if st.button("Create and Relate Collection"):
if new_collection_name:
self.user_arango.db.collection("article_collections").insert(
{"name": new_collection_name, "articles": []}
)
self.project.add_collection(new_collection_name)
st.success(
f'New collection "{new_collection_name}" created and related to the project'
)
# Update session state
self.update_session_state(self.page_name)
def sidebar_actions(self):
self.sidebar_notes()
# Update session state
self.update_session_state(self.page_name)
def sidebar_notes(self):
with st.sidebar:
st.markdown(f"### Add new notes to {self.project.name}")
self.upload_notes_form()
self.add_text_note()
self.add_wikipedia_data()
# Update session state
self.update_session_state(self.page_name)
def upload_notes_form(self):
with st.form("add_notes", clear_on_submit=True):
files = st.file_uploader(
"Upload PDF or image",
type=["png", "jpg", "pdf"],
accept_multiple_files=True,
)
submitted = st.form_submit_button("Upload")
if submitted:
self.project.process_uploaded_files(files)
# Update session state
self.update_session_state(self.page_name)
def add_text_note(self):
help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
note_text = st.text_area("Write or paste anything.", help=help_text)
if st.button("Add Note"):
self.project.add_note(
{
"text": note_text,
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
}
)
st.success("Note added to the project")
# Update session state
self.update_session_state(self.page_name)
def add_wikipedia_data(self):
wiki_url = st.text_input(
"Paste the address to a Wikipedia page to add its summary as a note",
placeholder="Paste Wikipedia URL",
)
if st.button("Add Wikipedia data"):
with st.spinner("Fetching Wikipedia data..."):
wiki_data = self.project.get_wikipedia_data(wiki_url)
if wiki_data:
self.project.process_wikipedia_data(wiki_data, wiki_url)
st.success("Wikipedia data added to notes")
# Update session state
self.update_session_state(self.page_name)
st.rerun()
class Project(BaseClass):
def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
super().__init__(username=username)
self.name = project_name
self.user_arango = user_arango
self.description = ""
self.collections = []
self.notes = []
self.note_keys_hash = 0
self.settings = {}
self.notes_summary = ""
# Initialize attributes from arango doc if available
self.load_project()
def load_project(self):
print_blue("Project name:", self.name)
project_cursor = self.user_arango.db.aql.execute(
"FOR doc IN projects FILTER doc.name == @name RETURN doc",
bind_vars={"name": self.name},
)
project = next(project_cursor, None)
if not project:
raise ValueError(f"Project '{self.name}' not found.")
self._key = project["_key"]
self.name = project.get("name", "")
self.description = project.get("description", "")
self.collections = project.get("collections", [])
self.notes = project.get("notes", [])
self.note_keys_hash = project.get("note_keys_hash", 0)
self.settings = project.get("settings", {})
self.notes_summary = project.get("notes_summary", "")
def update_project(self):
    """Write this instance's full state back to the project document.

    Rebuilds the document from the instance attributes and updates it in
    the "projects" collection. ``check_rev=False`` means the last writer
    wins regardless of the document revision, then mirrors the new state
    into the Streamlit session.
    """
    updated_doc = {
        "_key": self._key,
        "name": self.name,
        "description": self.description,
        "collections": self.collections,
        "notes": self.notes,
        "note_keys_hash": self.note_keys_hash,
        "settings": self.settings,
        "notes_summary": self.notes_summary,
    }
    self.user_arango.db.collection("projects").update(updated_doc, check_rev=False)
    self.update_session_state()
def add_collections(self, collections):
self.collections.extend(collections)
self.update_project()
def add_collection(self, collection_name):
self.collections.append(collection_name)
self.update_project()
def add_note(self, note: dict):
assert note["text"], "Note text cannot be empty"
note["text"] = note["text"].strip().strip("\n")
if "timestamp" not in note:
note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M")
note_doc = self.user_arango.db.collection("notes").insert(note)
if note_doc["_id"] not in self.notes:
self.notes.append(note_doc["_id"])
self.update_project()
def delete_note(self, note_id):
if note_id in self.notes:
self.notes.remove(note_id)
self.update_project()
def update_notes_hash(self):
current_hash = self.make_project_notes_hash()
if current_hash != self.note_keys_hash:
self.note_keys_hash = current_hash
with st.spinner("Summarizing notes for chatbot..."):
self.create_notes_summary()
self.update_project()
def make_project_notes_hash(self):
    """Return a deterministic fingerprint of the project's note ids.

    Returns:
        int: A stable integer derived from SHA-256 of the concatenated
        note ids (a fixed constant for an empty note list).

    The previous implementation used the built-in ``hash()``, which is
    randomized per process for strings (PYTHONHASHSEED), so the value
    persisted in ArangoDB never matched on the next run and
    ``update_notes_hash`` re-summarized the notes every session.
    """
    import hashlib  # local import keeps this fix self-contained

    note_keys_str = "".join(self.notes) if self.notes else ""
    digest = hashlib.sha256(note_keys_str.encode("utf-8")).hexdigest()
    # Return an int so the stored field keeps its existing type.
    return int(digest, 16)
def create_notes_summary(self):
notes_cursor = self.user_arango.db.aql.execute(
"FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text",
bind_vars={"note_ids": self.notes},
)
notes = list(notes_cursor)
notes_string = "\n---\n".join(notes)
llm = LLM(model="small")
query = get_note_summary_prompt(self, notes_string)
summary = llm.generate(query)
print_purple("New summary of notes:", summary)
self.notes_summary = summary
self.update_session_state()
def analyze_image(self, image_base64, text=None):
    """Describe an image with the LLM and return the description.

    Parameters:
        image_base64: Base64-encoded image payload passed to the LLM.
        text (str, optional): Text extracted from the image (OCR);
            included in the prompt when present.

    Returns:
        str: The generated description. (The previous version printed the
        description but never returned it, so the caller in
        ``process_uploaded_files`` always received None.)
    """
    project_data = {"name": self.name}
    llm = LLM(system_message=get_image_system_prompt(project_data))
    prompt = (
        f'Analyze the image. The text found in it read: "{text}"'
        if text
        else "Analyze the image."
    )
    description = llm.generate(query=prompt, images=[image_base64], stream=False)
    print_green("Image description:", description)
    return description
def process_uploaded_files(self, files):
    """OCR uploaded images and store caption + extracted text as notes.

    For each uploaded file: load it as an image, OCR it into a searchable
    PDF, extract the text, ask the LLM for an image caption, and save the
    caption plus text as a project note. Reruns the Streamlit page when done.
    """
    with st.spinner("Processing files..."):
        for file in files:
            st.write("Processing...")
            filename = fix_key(file.name)
            image_file = self.file2img(file)
            pdf_file = self.convert_image_to_pdf(image_file)
            pdf = PDFProcessor(
                pdf_file=pdf_file,
                is_sci=False,
                document_type="notes",
                is_image=True,
                process=False,
            )
            text = pdf.process_pdf()
            # file2img() already consumed the upload stream; rewind before
            # reading again, otherwise b64encode() encodes an empty string
            # and the LLM gets no image data.
            file.seek(0)
            base64_str = base64.b64encode(file.read())
            image_caption = self.analyze_image(base64_str, text=text)
            self.add_note(
                {
                    # TODO(review): "(unknown)" looks like a lost template
                    # placeholder — probably meant f"notes/{filename}"; confirm.
                    "_id": f"notes/(unknown)",
                    "text": f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}",
                }
            )
        st.success("Done!")
        sleep(1.5)
        self.update_session_state()
        st.rerun()
def file2img(self, file):
    """Load an uploaded file into a PIL image, rejecting empty uploads.

    Raises:
        ValueError: If the uploaded file contains no bytes.
    """
    raw = file.read()
    if not raw:
        raise ValueError("Uploaded file is empty.")
    return Image.open(BytesIO(raw))
def convert_image_to_pdf(self, img):
    """OCR a PIL image into a searchable PDF held in a named BytesIO buffer.

    The buffer gets a timestamped ``.name`` attribute because downstream
    consumers (PDFProcessor) expect a file-like object with a filename.
    """
    import pytesseract

    buffer = BytesIO(pytesseract.image_to_pdf_or_hocr(img))
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    buffer.name = f"converted_image_{timestamp}.pdf"
    return buffer
def get_wikipedia_data(self, page_url: str) -> dict:
    """Fetch title, summary, content, url and references for a Wikipedia URL.

    Args:
        page_url: Full Wikipedia article URL (".../wiki/<PageName>").

    Returns:
        dict | None: Page data, or None (with a UI warning/error shown) when
        the URL is not a Wikipedia article link or the fetch fails.
    """
    import wikipedia
    from urllib.parse import urlparse

    match = re.search(r"(?<=/wiki/)[^?#]*", urlparse(page_url).path)
    if match is None:
        st.warning("Invalid Wikipedia URL")
        return None
    try:
        # Attribute access may trigger lazy loading, so keep it inside the try.
        page = wikipedia.page(match.group(0))
        return {
            "title": page.title,
            "summary": page.summary,
            "content": page.content,
            "url": page.url,
            "references": page.references,
        }
    except Exception as e:
        st.error(f"Error fetching Wikipedia data: {e}")
        return None
def process_wikipedia_data(self, wiki_data, wiki_url):
    """Summarize fetched Wikipedia data, store it as a note, and offer to add
    any referenced DOIs to the current collection.

    Args:
        wiki_data (dict): Output of get_wikipedia_data() (mutated in place:
            "summary"/"content" are replaced by a condensed "text" field).
        wiki_url (str): Original article URL, linked from the stored note.
    """
    llm = LLM(
        # BUGFIX: the prompt contained the garbled word "summarisen".
        system_message="You are an assistant summarising wikipedia data. Answer ONLY with the summary, nothing else!",
        model="small",
    )
    if wiki_data.get("summary"):
        query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.'''
        summary = llm.generate(query)
        wiki_data["text"] = (
            f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}"
        )
    # Drop the bulky raw fields before persisting the note document.
    wiki_data.pop("summary", None)
    wiki_data.pop("content", None)
    self.user_arango.db.collection("notes").insert(
        wiki_data, overwrite=True, silent=True
    )
    self.add_note(wiki_data)
    # Collect DOIs from the article's reference list.
    processor = PDFProcessor(process=False)
    dois = [
        processor.extract_doi(ref)
        for ref in wiki_data.get("references", [])
        if processor.extract_doi(ref)
    ]
    if dois:
        current_collection = st.session_state["settings"].get("current_collection")
        st.markdown(
            f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
        )
        if st.button("Add DOIs"):
            self.process_dois(current_collection, dois=dois)
    self.update_session_state()
def process_dois(
    self, article_collection_name: str, text: str = None, dois: list = None
) -> None:
    """Download and register articles for DOIs found in *text* or given directly.

    Args:
        article_collection_name: Collection the resolved articles are added to.
        text: Free text to scan for DOI numbers (used when *dois* is empty).
        dois: Explicit list of DOI strings to process.
    """
    processor = PDFProcessor(process=False)
    if not dois and text:
        dois = processor.extract_doi(text, multi=True)
    if "not_downloaded" not in st.session_state:
        st.session_state["not_downloaded"] = {}
    for doi in dois:
        downloaded, url, path, in_db = processor.doi2pdf(doi)
        if not in_db:
            if downloaded:
                # Fresh download: ingest the PDF into the database.
                processor.process_pdf(path)
                in_db = True
            else:
                # Remember failures so the UI can offer manual download links.
                st.session_state["not_downloaded"][doi] = url
        if in_db:
            st.success(f"Article with DOI {doi} added")
            self.articles2collection(
                collection=article_collection_name,
                db="base",
                _id=f"sci_articles/{fix_key(doi)}",
            )
    self.update_session_state()
class SettingsPage(BaseClass):
class SettingsPage(StreamlitBaseClass):
def __init__(self, username: str):
super().__init__(username=username)
@ -940,7 +178,7 @@ class SettingsPage(BaseClass):
sleep(1)
class RSSFeedsPage(BaseClass):
class RSSFeedsPage(StreamlitBaseClass):
def __init__(self, username: str):
super().__init__(username=username)
self.page_name = "RSS Feeds"
@ -983,10 +221,10 @@ class RSSFeedsPage(BaseClass):
st.session_state["discovered_feeds"] = feeds
else:
st.error("No RSS feeds found at the provided URL.")
def sidebar_actions(self):
if 'discovered_feeds' not in st.session_state:
st.session_state['discovered_feeds'] = None
if "discovered_feeds" not in st.session_state:
st.session_state["discovered_feeds"] = None
with st.sidebar:
self.select_rss_feeds()
@ -997,7 +235,7 @@ class RSSFeedsPage(BaseClass):
if submitted:
print_green(rss_url)
feeds = self.reader.discover_feeds(rss_url)
st.session_state['discovered_feeds'] = feeds
st.session_state["discovered_feeds"] = feeds
if st.session_state["discovered_feeds"]:
st.subheader("Select a Feed to Add")
@ -1037,4 +275,4 @@ class RSSFeedsPage(BaseClass):
summary = entry.get("summary", "No summary available")
markdown_summary = self.reader.html_to_markdown(summary)
st.markdown(markdown_summary)
st.markdown(f"[Read more]({entry['link']})")
st.markdown(f"[Read more]({entry['link']})")

@ -0,0 +1,440 @@
import os
import base64
import re
import json
from typing import Any, Callable, Iterator, Literal, Mapping, Optional, Sequence, Union
import tiktoken
from ollama import Client, AsyncClient, ResponseError, ChatResponse, Message, Tool, Options
from ollama._types import JsonSchemaValue, ChatRequest
import env_manager
from colorprinter.print_color import *
env_manager.set_env()
tokenizer = tiktoken.get_encoding("cl100k_base")
# Define a base class for common functionality
class BaseClient:
    """Mixin providing a synchronous chat() built on the subclass's _request().

    NOTE(review): only MyClient actually uses this implementation;
    MyAsyncClient overrides chat() with an async variant below.
    """

    def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        """POST the chat history to /api/chat.

        Returns a ChatResponse, or an iterator of partial ChatResponses when
        stream=True. The extra ``stream`` kwarg is consumed by _request().
        """
        return self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            json=ChatRequest(
                model=model,
                messages=[message for message in messages or []],
                tools=[tool for tool in tools or []],
                stream=stream,
                format=format,
                options=options,
                keep_alive=keep_alive,
            ).model_dump(exclude_none=True),  # exclude_none keeps the payload minimal
            stream=stream,
        )
# Define your custom MyAsyncClient class
class MyAsyncClient(AsyncClient, BaseClient):
    """Async Ollama client that merges per-call headers and records the last
    HTTP response on ``self.last_response`` (used by LLM to read routing
    headers such as X-Chosen-Backend)."""

    async def _request(self, response_type, method, path, headers=None, **kwargs):
        """Send a request with merged headers; stream or validate the response.

        NOTE(review): assumes every caller passes a 'stream' kwarg (chat()
        always does) — a call without it raises KeyError on pop below.
        NOTE(review): the streaming branch returns the iterator from inside
        the ``async with`` block, which closes the response on return —
        confirm this works with the installed ollama/httpx versions.
        NOTE(review): relies on the private ``self._request_raw`` of the
        ollama package — verify it exists in the pinned version.
        """
        # Merge default headers with per-call headers
        all_headers = {**self._client.headers, **(headers or {})}
        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            async with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    await response.aread()  # body must be read before .text is valid
                    raise ResponseError(response.text, response.status_code)
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            kwargs.pop('stream')
            response = await self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object
            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)

    async def chat(
        self,
        model: str = '',
        messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None,
        *,
        tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
        stream: bool = False,
        format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
        options: Optional[Union[Mapping[str, Any], Options]] = None,
        keep_alive: Optional[Union[float, str]] = None,
    ) -> Union[ChatResponse, Iterator[ChatResponse]]:
        """Async counterpart of BaseClient.chat(); same payload shape."""
        return await self._request(
            ChatResponse,
            'POST',
            '/api/chat',
            json=ChatRequest(
                model=model,
                messages=[message for message in messages or []],
                tools=[tool for tool in tools or []],
                stream=stream,
                format=format,
                options=options,
                keep_alive=keep_alive,
            ).model_dump(exclude_none=True),
            stream=stream,
        )
# Define your custom MyClient class
class MyClient(Client, BaseClient):
    """Synchronous Ollama client that merges per-call headers and records the
    last HTTP response on ``self.last_response`` (LLM reads routing headers
    like X-Chosen-Backend from it after each call)."""

    def _request(self, response_type, method, path, headers=None, **kwargs):
        """Send a request with merged headers; stream or validate the response.

        NOTE(review): assumes callers always pass a 'stream' kwarg (chat()
        does) — pop() below raises KeyError otherwise.
        NOTE(review): in the streaming error branch, ``response.text`` is
        accessed without reading the body first (the async twin calls
        aread()); with httpx this can raise ResponseNotRead — verify.
        NOTE(review): relies on the private ``self._request_raw`` of the
        ollama package — verify it exists in the pinned version.
        """
        # Merge default headers with per-call headers
        all_headers = {**self._client.headers, **(headers or {})}
        # Handle streaming separately
        if kwargs.get('stream'):
            kwargs.pop('stream')
            with self._client.stream(method, path, headers=all_headers, **kwargs) as response:
                self.last_response = response  # Store the response object
                if response.status_code >= 400:
                    raise ResponseError(response.text, response.status_code)
                return self._stream(response_type, response)
        else:
            # Make the HTTP request with the combined headers
            kwargs.pop('stream')
            response = self._request_raw(method, path, headers=all_headers, **kwargs)
            self.last_response = response  # Store the response object
            if response.status_code >= 400:
                raise ResponseError(response.text, response.status_code)
            return response_type.model_validate_json(response.content)
class LLM:
    """
    Wrapper for chatting with a remote Ollama-compatible LLM API.

    Maintains the conversation history in ``self.messages``, resolves model
    aliases ("small"/"standard"/"vision"/"standard_64k") to concrete model
    names via environment variables, and performs requests through
    ``MyClient``/``MyAsyncClient`` so routing headers and the last HTTP
    response stay accessible.
    """

    def __init__(
        self,
        system_message="You are an assistant.",
        temperature=0.01,
        model: Optional[Literal["small", "standard", "vision"]] = "standard",
        max_length_answer=4096,
        messages=None,
        chat=True,
        chosen_backend=None,
    ) -> None:
        """
        Args:
            system_message (str): Initial system prompt for the conversation.
            temperature (float): Sampling temperature passed to the backend.
            model: Model alias resolved by get_model(). Defaults to "standard".
            max_length_answer (int): Token budget reserved for the reply.
            messages (list[dict], optional): Pre-existing chat history; when
                omitted a fresh history seeded with the system message is used.
            chat (bool): Keep history between generate() calls when True.
            chosen_backend (str, optional): Sticky backend server identifier,
                forwarded as the X-Chosen-Backend header.
        """
        self.model = self.get_model(model)
        self.system_message = system_message
        self.options = {"temperature": temperature}
        self.messages = messages or [{"role": "system", "content": self.system_message}]
        self.max_length_answer = max_length_answer
        self.chat = chat
        self.chosen_backend = chosen_backend

        # Basic-auth credentials for the gateway in front of the Ollama servers.
        credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}"
        encoded_credentials = base64.b64encode(credentials.encode()).decode()
        default_headers = {
            "Authorization": f"Basic {encoded_credentials}",
        }
        # BUGFIX: str.rstrip() removes a trailing *character set*, not a
        # suffix, so rstrip("/api/chat/") could also eat legitimate trailing
        # characters of the host (any of "/", "a", "p", "i", "c", "h", "t").
        # Strip the literal endpoint suffix instead.
        host_url = re.sub(r"/api/chat/?$", "", os.getenv("LLM_API_URL", ""))
        self.client = MyClient(host=host_url, headers=default_headers)
        self.async_client = MyAsyncClient(host=host_url, headers=default_headers)

    def get_model(self, model_alias):
        """Resolve a model alias to the concrete model name from the environment.

        Unknown aliases fall back to the standard model (LLM_MODEL).
        """
        models = {
            "standard": "LLM_MODEL",
            "small": "LLM_MODEL_SMALL",
            "vision": "LLM_MODEL_VISION",
            "standard_64k": "LLM_MODEL_64K",
        }
        return os.getenv(models.get(model_alias, "LLM_MODEL"))

    def count_tokens(self):
        """Count tokens over all message contents with the cl100k_base tokenizer.

        Returns:
            int: Total token count across the chat history.
        """
        num_tokens = 0
        for message in self.messages:
            content = message.get("content")
            if content is None:
                continue
            if not isinstance(content, str):
                content = str(content)  # tolerate non-string payloads
            num_tokens += len(tokenizer.encode(content))
        return int(num_tokens)

    def generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Generate a response from the language model.

        Args:
            query: The prompt sent to the model (appended to the history).
            user_input: When given, replaces the last history entry after the
                request so the stored history holds the user's original text
                (plus a summary of *context*, if provided).
            context: Supporting text; summarized first when longer than 2000
                characters.
            stream: Return a generator of content chunks instead of a string.
            tools: Tool specs (dicts are converted to ollama Tool objects).
            function_call: Unused; kept for interface compatibility.
            images: Image paths, base64 strings, or raw bytes; switches to the
                vision model.
            model: Per-call model alias override.
            temperature: Per-call temperature override.

        Returns:
            str | generator: The reply text, a chunk generator when
            stream=True, or "An error occurred." on failure.
        """
        # Prepare the model and temperature.
        model = self.get_model(model) if model else self.model
        # BUGFIX: compare against None so an explicit temperature of 0.0 is
        # honored instead of silently falling back to the default.
        temperature = temperature if temperature is not None else self.options["temperature"]

        # Normalize whitespace and add the query to the messages.
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any: accept base64 strings, file paths, or bytes.
        if images:
            base64_images = []
            base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
            for image in images:
                if isinstance(image, str):
                    if base64_pattern.match(image):
                        base64_images.append(image)
                    else:
                        with open(image, "rb") as image_file:
                            base64_images.append(
                                base64.b64encode(image_file.read()).decode("utf-8")
                            )
                elif isinstance(image, bytes):
                    base64_images.append(base64.b64encode(image).decode("utf-8"))
                else:
                    print_red("Invalid image type")
            message["images"] = base64_images
            # Images require the vision model.
            model = self.get_model("vision")

        self.messages.append(message)

        # Prepare routing headers.
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend
        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options.
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any.
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Switch to the long-context model when the history is large.
        # NOTE(review): len(self.messages) > 15000 counts messages, not
        # characters/tokens — with chat=True (the default) the branch runs
        # anyway, so behavior is kept as-is.
        if self.chat or len(self.messages) > 15000:
            num_tokens = self.count_tokens() + self.max_length_answer // 2
            if num_tokens > 8000:
                model = self.get_model("standard_64k")
                headers["X-Model-Type"] = "large"

        # Call the client.chat method.
        try:
            response = self.client.chat(
                model=model,
                messages=self.messages,
                headers=headers,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # If user_input is provided, rewrite the last history entry so the
        # stored conversation reflects the user's original input (done after
        # the request so the model saw the full query).
        if user_input:
            if context:
                if len(context) > 2000:
                    context = self.make_summary(context)
                user_input = (
                    f"{user_input}\n\nUse the information below to answer the question.\n"
                    f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
                )
                system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
                if system_message_info not in self.messages[0]["content"]:
                    self.messages[0]["content"] += system_message_info
            self.messages[-1] = {"role": "user", "content": user_input}

        # Remember which backend served us for sticky routing.
        self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")

        # Handle streaming response.
        if stream:
            return self.read_stream(response)
        else:
            # Process the response.
            if isinstance(response, ChatResponse):
                result = response.message.content.strip('"')
                self.messages.append({"role": "assistant", "content": result.strip('"')})
                if tools and not response.message.get("tool_calls"):
                    print_yellow("No tool calls in response".upper())
                if not self.chat:
                    # Non-chat mode: keep only the system message.
                    self.messages = [self.messages[0]]
                return result
            else:
                print_red("Unexpected response type")
                return "An error occurred."

    def make_summary(self, text):
        """Summarize *text* with the small model; returns the summary or a
        failure message (never raises)."""
        summary_message = {
            "role": "user",
            "content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
        }
        messages = [
            {"role": "system", "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information."},
            summary_message,
        ]
        try:
            response = self.client.chat(
                model=self.get_model("small"),
                messages=messages,
                options=Options(temperature=0.01),
                keep_alive=3600 * 24 * 7,
            )
            summary = response.message.content.strip()
            print_blue("Summary:", summary)
            return summary
        except ResponseError as e:
            print_red("Error generating summary:", e)
            return "Summary generation failed."

    def read_stream(self, response):
        """Yield content chunks from a streaming response.

        Strips a quote that wraps the whole reply (leading quote on the first
        chunk, trailing quote on the final one) and appends the complete reply
        to the history when the stream ends.
        """
        message = ""
        first_chunk = True
        prev_content = None
        for chunk in response:
            if not chunk:
                continue
            content = chunk.message.content
            # BUGFIX: accumulate the reply; the original never updated
            # `message`, so the stored assistant turn was always empty.
            message += content
            if first_chunk and content.startswith('"'):
                content = content[1:]
            first_chunk = False
            if chunk.done:
                # Drop a trailing wrapping quote before yielding the tail.
                if prev_content and prev_content.endswith('"'):
                    prev_content = prev_content[:-1]
                if prev_content:
                    yield prev_content
                break
            else:
                # Yield with one chunk of lookahead so the final chunk can be
                # trimmed above.
                if prev_content:
                    yield prev_content
                prev_content = content
        self.messages.append({"role": "assistant", "content": message.strip('"')})

    async def async_generate(
        self,
        query: str = None,
        user_input: str = None,
        context: str = None,
        stream: bool = False,
        tools: list = None,
        function_call: dict = None,
        images: list = None,
        model: Optional[Literal["small", "standard", "vision"]] = None,
        temperature: float = None,
    ):
        """
        Asynchronous variant of generate().

        NOTE(review): this version is a partial port — the image handling and
        long-message adjustment are stubs (`...`), and user_input/context and
        streaming are not applied. Returns the reply text or
        "An error occurred.".
        """
        # Prepare the model and temperature.
        model = self.get_model(model) if model else self.model
        # Same None-aware fallback as generate() so temperature=0.0 is honored.
        temperature = temperature if temperature is not None else self.options["temperature"]

        # Normalize whitespace and add the query to the messages.
        query = re.sub(r"\s*\n\s*", "\n", query)
        message = {"role": "user", "content": query}

        # Handle images if any
        if images:
            # (Image handling not ported yet — see generate())
            ...

        self.messages.append(message)

        # Prepare headers
        headers = {}
        if self.chosen_backend:
            headers["X-Chosen-Backend"] = self.chosen_backend
        if model == self.get_model("small"):
            headers["X-Model-Type"] = "small"

        # Prepare options
        options = Options(**self.options)
        options.temperature = temperature

        # Prepare tools if any
        if tools:
            tools = [
                Tool(**tool) if isinstance(tool, dict) else tool
                for tool in tools
            ]

        # Adjust options for long messages
        # (Not ported yet — see generate())
        ...

        # Call the async client's chat method
        try:
            response = await self.async_client.chat(
                model=model,
                messages=self.messages,
                tools=tools,
                stream=stream,
                options=options,
                keep_alive=3600 * 24 * 7,
            )
        except ResponseError as e:
            print_red("Error!")
            print(e)
            return "An error occurred."

        # Process the response
        if isinstance(response, ChatResponse):
            result = response.message.content.strip('"')
            self.messages.append({"role": "assistant", "content": result.strip('"')})
            return result
        else:
            print_red("Unexpected response type")
            return "An error occurred."
# Usage example
# Usage example: fire a single async request and print the reply.
if __name__ == "__main__":
    import asyncio

    llm = LLM()

    async def main():
        reply = await llm.async_generate(query="Hello, how are you?")
        print(reply)

    asyncio.run(main())

@ -1,52 +1,117 @@
import re
import os
import base64
import re
from typing import Literal, Optional
import requests
from requests.auth import HTTPBasicAuth
import tiktoken
import json
from colorprinter.print_color import *
import asyncio
from ollama import (
Client,
AsyncClient,
ResponseError,
ChatResponse,
Tool,
Options,
)
import env_manager
from colorprinter.print_color import *
env_manager.set_env()
tokenizer = tiktoken.get_encoding("cl100k_base")
class LLM:
"""
LLM class for interacting with an instance of Ollama.
Attributes:
model (str): The model to be used for response generation.
system_message (str): The system message to be used in the chat.
options (dict): Options for the model, such as temperature.
messages (list): List of messages in the chat.
max_length_answer (int): Maximum length of the generated answer.
chat (bool): Whether the chat mode is enabled.
chosen_backend (str): The chosen backend server for the API.
client (Client): The client for synchronous API calls.
async_client (AsyncClient): The client for asynchronous API calls.
Methods:
__init__(self, system_message, temperature, model, max_length_answer, messages, chat, chosen_backend):
Initializes the LLM class with the provided parameters.
get_model(self, model_alias):
Retrieves the model name based on the provided alias.
count_tokens(self):
Counts the number of tokens in the messages.
get_least_conn_server(self):
Retrieves the least connected server from the backend.
generate(self, query, user_input, context, stream, tools, images, model, temperature):
Generates a response based on the provided query and options.
make_summary(self, text):
Generates a summary of the provided text.
read_stream(self, response):
Handles streaming responses.
async_generate(self, query, user_input, context, stream, tools, images, model, temperature):
Asynchronously generates a response based on the provided query and options.
prepare_images(self, images, message):
"""
def __init__(
self,
system_message="You are an assistant.",
temperature=0.01,
system_message: str = "You are an assistant.",
temperature: float = 0.01,
model: Optional[Literal["small", "standard", "vision"]] = "standard",
max_length_answer=4096,
messages=None,
chat=True,
chosen_backend=None,
max_length_answer: int = 4096,
messages: list[dict] = None,
chat: bool = True,
chosen_backend: str = None,
) -> None:
"""
Initialize the assistant with specified parameters.
Initialize the assistant with the given parameters.
Args:
system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
temperature (float): The temperature setting for the model's response generation. Defaults to 0.01.
chat (bool): Flag to indicate if the assistant is in chat mode. Defaults to True.
model (str): The model type to use. Defaults to "standard". Alternatives: 'small', 'standard', 'vision'.
max_length_answer (int): The maximum length of the generated answer. Defaults to 4000.
chosen_backend (str): The chosen ollama server for the request. Defaults to None.
temperature (float): The temperature setting for the model, affecting randomness. Defaults to 0.01.
model (Optional[Literal["small", "standard", "vision"]]): The model type to use. Defaults to "standard".
max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
messages (list[dict], optional): A list of initial messages. Defaults to None.
chat (bool): Whether the assistant is in chat mode. Defaults to True.
chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
Returns:
None
"""
self.model = self.get_model(model)
self.system_message = system_message
self.options = {"temperature": temperature}
self.messages = messages or [{"role": "system", "content": self.system_message}]
self.max_length_answer = max_length_answer
self.chat = chat
if not chosen_backend:
chosen_backend = self.get_least_conn_server()
self.chosen_backend = chosen_backend
# Initialize the client with the host and default headers
credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers = {
"Authorization": f"Basic {encoded_credentials}",
"X-Chosen-Backend": self.chosen_backend,
}
host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
self.client = Client(host=host_url, headers=headers)
self.async_client = AsyncClient()
def get_model(self, model_alias):
models = {
"standard": "LLM_MODEL",
@ -66,76 +131,17 @@ class LLM:
tokens = tokenizer.encode(v)
num_tokens += len(tokens)
return int(num_tokens)
def read_stream(self, response):
buffer = ""
message = ""
first_chunk = True
prev_content = None # Store the previous content chunk
for chunk in response.iter_content(chunk_size=64):
if chunk:
try:
message_part = chunk.decode("utf-8")
buffer += message_part
message += message_part
except UnicodeDecodeError:
continue
while "\n" in buffer:
line, buffer = buffer.split("\n", 1)
if line.strip():
try:
json_data = json.loads(line)
content = json_data["message"]["content"]
done = json_data.get("done", False)
# Remove leading '"' from the first content
if first_chunk and content.startswith('"'):
content = content[1:]
first_chunk = False
if done:
# If the last content ends with '"', remove it
if prev_content and prev_content.endswith('"'):
prev_content = prev_content[:-1]
# Yield the last content
if prev_content:
yield prev_content
break
else:
# Yield the previous content before storing the current
if prev_content:
yield prev_content
prev_content = content
except json.JSONDecodeError:
continue
# Append the full message without leading/trailing quotes
self.messages.append({"role": "assistant", "content": message.strip('"')})
def make_summary(self, text):
data = {
"messages": [
{
"role": "system",
"content": """You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information.""",
},
{
"role": "user",
"content": f'Summarise the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
},
],
"stream": False,
"keep_alive": 3600 * 24 * 7,
"model": self.get_model("small"),
"options": {"temperature": 0.01},
}
response = requests.post(
os.getenv("LLM_API_URL"),
json=data,
auth=HTTPBasicAuth(
os.getenv("LLM_API_USER"), os.getenv("LLM_API_PWD_LASSE")
),
)
print_blue("Summary:", response.json()["message"]["content"])
return response.json()["message"]["content"]
def get_least_conn_server(self):
try:
response = requests.get("http://192.168.1.12:5000/least_conn")
response.raise_for_status()
# Extract the least connected server from the response
least_conn_server = response.headers.get("X-Upstream-Address")
return least_conn_server
except requests.RequestException as e:
print_red("Error getting least connected server:", e)
return None
def generate(
self,
@ -144,165 +150,176 @@ class LLM:
context: str = None,
stream: bool = False,
tools: list = None,
function_call: dict = None,
images: list = None,
model: Optional[Literal["small", "standard", "vision"]] = None,
temperature: float = None,
messages: list[dict] = None,
):
"""
Generates a response from the language model based on the provided inputs.
If user_input is provided, it is included in the message history instead of the query.
If context is provided, it is summaried if len() > 2000 and included in the message history.
Generate a response based on the provided query and options.
Args:
query (str, optional): The main query string to be processed by the model.
user_input (str, optional): User input to be included in the message history.
context (str, optional): Contextual information to be included in the message history.
query (str, optional): The query string to generate a response for.
user_input (str, optional): Additional user input to update the last message.
context (str, optional): Context information to be used in the response.
stream (bool, optional): Whether to stream the response. Defaults to False.
tools (list, optional): List of tools to be included in the request.
function_call (dict, optional): Dictionary specifying a function call to be made.
images (list, optional): List of image paths or base64-encoded images to be included.
model (Optional[Literal["small", "standard", "vision"]], optional): The model type to be used. Defaults to None.
temperature (float, optional): The temperature setting for the model. Defaults to None.
tools (list, optional): List of tools to be used in the response generation.
images (list, optional): List of images to be included in the response.
model (Optional[Literal["small", "standard", "vision"]], optional): The model to be used for response generation.
temperature (float, optional): The temperature setting for the model.
messages (list[dict], optional): A list of messages formated as dictionaries (eg. {'role': 'user', 'content': 'message'}).
Returns:
str: The generated response from the language model. If streaming is enabled, returns the streamed response.
str: The generated response or an error message if an exception occurs.
Raises:
ResponseError: If an error occurs during the response generation.
"""
# Add custom header if large model is chosen
# Prepare the model and temperature
model = self.get_model(model) if model else self.model
temperature = temperature if temperature else self.options["temperature"]
# Normalize whitespace and add the query to the messages
query = re.sub(r"\s*\n\s*", "\n", query)
message = {"role": "user", "content": query}
headers = {"Content-Type": "application/json"}
from colorprinter.print_color import print_rainbow
if messages:
messages = [{'role': i['role'], 'content': re.sub(r"\s*\n\s*", "\n", i['content'])} for i in messages]
message = messages.pop(-1)
query = message['content']
self.messages = messages
else:
# Normalize whitespace and add the query to the messages
query = re.sub(r"\s*\n\s*", "\n", query)
message = {"role": "user", "content": query}
# Handle images if any
if images:
import base64
message = self.prepare_images(images, message)
model = self.get_model("vision")
# Convert image paths to base64
base64_images = []
base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
# Handle images
for image in images:
if isinstance(image, str):
if base64_pattern.match(image):
base64_images.append(image)
else:
with open(image, "rb") as image_file:
base64_images.append(
base64.b64encode(image_file.read()).decode("utf-8")
)
elif isinstance(image, bytes):
base64_images.append(base64.b64encode(image).decode("utf-8"))
else:
print_red("Invalid image type")
message["images"] = base64_images
# Set the Content-Type header based on the presence of images
headers = {"Content-Type": "application/json; images"}
# Set the model type to the vision model
if self.chosen_backend:
headers["X-Chosen-Backend"] = self.chosen_backend
self.messages.append(message)
# Set the number of tokens to be the sum of the tokens in the messages and half of the max length of the answer
if self.chat or len(self.messages) > 15000:
num_tokens = self.count_tokens() + self.max_length_answer / 2
if num_tokens > 8000:
model = self.get_model("large")
headers["X-Model-Type"] = "large"
# Prepare headers
headers = {}
if self.chosen_backend:
headers["X-Chosen-Backend"] = self.chosen_backend
if tools:
stream = False
data = {
"messages": self.messages,
"stream": stream,
"keep_alive": 3600 * 24 * 7,
"model": model if model else self.model,
"options": self.options,
}
if model == self.get_model("small"):
headers["X-Model-Type"] = "small"
# Include tools if provided
if tools:
data["tools"] = tools
# Prepare options
options = Options(**self.options)
options.temperature = temperature
# Include function_call if provided
if function_call:
data["function_call"] = function_call
if tools:
print_yellow("Tools:", tools)
if data['model'] == 'small':
headers["X-Model-Type"] = "small"
response = requests.post(
os.getenv("LLM_API_URL"),
headers=headers,
json=data,
auth=HTTPBasicAuth(
os.getenv("LLM_API_USER"), os.getenv("LLM_API_PWD_LASSE")
),
stream=stream,
timeout=3600,
)
# Adjust the options for long messages
if self.chat or len(self.messages) > 15000:
num_tokens = self.count_tokens() + self.max_length_answer // 2
if num_tokens > 8000:
model = self.get_model("standard_64k")
headers["X-Model-Type"] = "large"
# If user_input is provided, change the last message to user_input and a summary of the context (if provided)
# This needs to be done after the request to LLM for the LLM to have the original message
# Call the client.chat method
try:
response = self.client.chat(
model=model,
messages=self.messages,
tools=tools,
stream=stream,
options=options,
keep_alive=3600 * 24 * 7,
)
except ResponseError as e:
print_red("Error!")
print(e)
return "An error occurred."
# print_rainbow(response.__dict__)
# If user_input is provided, update the last message
if user_input:
if context:
if len(context) > 2000:
context = self.make_summary(context)
user_input = f'''{user_input}\n\nUse the information below to answer the question.\n"""{context}"""\n[This is a summary of the context provided in the original message.]'''
user_input = (
f"{user_input}\n\nUse the information below to answer the question.\n"
f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
)
system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
if system_message_info not in self.messages[0]["content"]:
self.messages[0]["content"] = (
self.messages[0]["content"] + system_message_info
)
self.messages[0]["content"] += system_message_info
self.messages[-1] = {"role": "user", "content": user_input}
self.chosen_backend = response.headers.get("X-Chosen-Backend")
if response.status_code != 200:
print_red("Error!")
print_rainbow(response.content.decode("utf-8"))
if response.status_code == 404:
return "Target endpoint not found"
if response.status_code == 504:
return f"Gateway Timeout: {response.content.decode('utf-8')}"
# self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend")
# Handle streaming response
if stream:
return self.read_stream(response)
else:
try:
response_json = response.json()
if tools and not response_json["message"].get("tool_calls"):
# Process the response
if isinstance(response, ChatResponse):
result = response.message.content.strip('"')
self.messages.append(
{"role": "assistant", "content": result.strip('"')}
)
if tools and not response.message.get("tool_calls"):
print_yellow("No tool calls in response".upper())
if "tool_calls" in response_json.get("message", {}):
# The LLM wants to invoke a function (tool)
result = response_json["message"]
else:
result = response_json["message"]["content"].strip('"')
self.messages.append(
{"role": "assistant", "content": result.strip('"')}
)
except requests.exceptions.JSONDecodeError:
print_red("Error: ", response.status_code, response.text)
if not self.chat:
self.messages = [self.messages[0]]
return response.message
else:
print_red("Unexpected response type")
return "An error occurred."
if not self.chat:
self.messages = [self.messages[0]]
return result
def make_summary(self, text):
# Implement your summary logic using self.client.chat()
summary_message = {
"role": "user",
"content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.',
}
messages = [
{
"role": "system",
"content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information.",
},
summary_message,
]
try:
response = self.client.chat(
model=self.get_model("small"),
messages=messages,
options=Options(temperature=0.01),
keep_alive=3600 * 24 * 7,
)
summary = response.message.content.strip()
print_blue("Summary:", summary)
return summary
except ResponseError as e:
print_red("Error generating summary:", e)
return "Summary generation failed."
def read_stream(self, response):
    """Yield streamed response fragments while trimming wrapping quotes.

    Fragments are yielded lagged by one chunk so that a trailing ``"`` on the
    final fragment can be removed before it is emitted. When the stream is
    exhausted, the accumulated full text (quotes stripped) is appended to
    ``self.messages`` as an assistant message.

    Args:
        response: Iterable of chunk objects exposing ``.message.content`` and
            ``.done`` (ollama-style streaming chunks).

    Yields:
        str: Successive content fragments.
    """
    message = ""
    first_chunk = True
    prev_content = None
    for chunk in response:
        if chunk:
            content = chunk.message.content
            # Strip a leading quote from the very first fragment.
            if first_chunk and content.startswith('"'):
                content = content[1:]
            first_chunk = False
            # Fix: accumulate the full text; previously ``message`` was never
            # updated, so an empty assistant message was stored in the history.
            message += content
            if chunk.done:
                # Final chunk: trim a trailing quote from the lagged fragment.
                # NOTE(review): content carried by the done-chunk itself is
                # stored but never yielded — mirrors the original behavior.
                if prev_content and prev_content.endswith('"'):
                    prev_content = prev_content[:-1]
                if prev_content:
                    yield prev_content
                break
            else:
                if prev_content:
                    yield prev_content
                prev_content = content
    self.messages.append({"role": "assistant", "content": message.strip('"')})
async def async_generate(
    self,
    query: str = None,
    user_input: str = None,
    context: str = None,
    stream: bool = False,
    tools: list = None,
    function_call: dict = None,
    images: list = None,
    model: Optional[Literal["small", "standard", "vision"]] = None,
    temperature: float = None,
):
    """
    Asynchronously generates a response based on the provided query and other parameters.

    Args:
        query (str, optional): The query string to generate a response for.
        user_input (str, optional): If given, replaces the last message after generation
            (with a summarised *context* attached when the context is long).
        context (str, optional): Context information used when rewriting the last message.
        stream (bool, optional): Whether to stream the response. Defaults to False.
        tools (list, optional): List of tools to be used in generating the response.
        function_call (dict, optional): Unused; kept for backward compatibility with callers.
        images (list, optional): List of images to include (switches to the vision model).
        model (Optional[Literal["small", "standard", "vision"]], optional): Model alias to use.
        temperature (float, optional): Sampling temperature for the model.

    Returns:
        str: The generated response, a stream generator when ``stream`` is True,
        or an error message if an exception occurs.

    Raises:
        Nothing is propagated; ``ResponseError`` is caught and reported.
    """
    # Resolve the model alias and the temperature to use.
    model = self.get_model(model) if model else self.model
    temperature = temperature if temperature else self.options["temperature"]

    # Normalize whitespace and add the query to the messages.
    query = re.sub(r"\s*\n\s*", "\n", query)
    message = {"role": "user", "content": query}

    # Attach base64-encoded images and switch to the vision model if needed.
    if images:
        message = self.prepare_images(images, message)
        model = self.get_model("vision")
    self.messages.append(message)

    # Routing headers for the backend load balancer.
    headers = {}
    if self.chosen_backend:
        headers["X-Chosen-Backend"] = self.chosen_backend
    if model == self.get_model("small"):
        headers["X-Model-Type"] = "small"

    # Per-call options (copy so shared defaults are not mutated).
    options = Options(**self.options)
    options.temperature = temperature

    # Wrap plain dict tools in Tool objects.
    if tools:
        tools = [Tool(**tool) if isinstance(tool, dict) else tool for tool in tools]

    # Switch to the long-context model when the conversation grows large.
    # NOTE(review): ``len(self.messages) > 15000`` counts messages, not tokens —
    # confirm this threshold is intended (the sync path uses the same check).
    if self.chat or len(self.messages) > 15000:
        num_tokens = self.count_tokens() + self.max_length_answer // 2
        if num_tokens > 8000:
            model = self.get_model("standard_64k")
            headers["X-Model-Type"] = "large"

    # Call the async client's chat method.
    try:
        response = await self.async_client.chat(
            model=model,
            messages=self.messages,
            headers=headers,
            tools=tools,
            stream=stream,
            options=options,
            keep_alive=3600 * 24 * 7,
        )
    except ResponseError as e:
        print_red("Error!")
        print(e)
        return "An error occurred."

    # If user_input is provided, rewrite the last stored message so the history
    # keeps the user's wording (with a summarised context) rather than the
    # expanded prompt. NOTE(review): this happens after the chat call, mirroring
    # the original flow.
    if user_input:
        if context:
            if len(context) > 2000:
                context = self.make_summary(context)
            user_input = (
                f"{user_input}\n\nUse the information below to answer the question.\n"
                f'"""{context}"""\n[This is a summary of the context provided in the original message.]'
            )
            system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message."
            if system_message_info not in self.messages[0]["content"]:
                self.messages[0]["content"] += system_message_info
        self.messages[-1] = {"role": "user", "content": user_input}

    # Remember which backend served us so follow-ups hit the same one.
    self.chosen_backend = self.async_client.last_response.headers.get(
        "X-Chosen-Backend"
    )

    # Handle streaming response
    if stream:
        return self.read_stream(response)
    else:
        # Process the response
        if isinstance(response, ChatResponse):
            result = response.message.content.strip('"')
            self.messages.append(
                {"role": "assistant", "content": result.strip('"')}
            )
            if tools and not response.message.get("tool_calls"):
                print_yellow("No tool calls in response".upper())
            if not self.chat:
                # Stateless mode: keep only the system message.
                self.messages = [self.messages[0]]
            return result
        else:
            print_red("Unexpected response type")
            return "An error occurred."
def prepare_images(self, images, message):
    """Attach *images* to *message* as base64-encoded strings.

    Args:
        images (list): Each item may be a file path (str), an already
            base64-encoded string (str), or raw bytes.
        message (dict): Message dict that receives the encodings under "images".

    Returns:
        dict: The same message dict, with the "images" key populated.
    """
    import base64

    encoded = []
    # Heuristic: a string consisting only of base64 characters is treated as
    # an already-encoded image rather than a file path.
    b64_re = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
    for item in images:
        if isinstance(item, bytes):
            encoded.append(base64.b64encode(item).decode("utf-8"))
        elif isinstance(item, str):
            if b64_re.match(item):
                encoded.append(item)
            else:
                with open(item, "rb") as fh:
                    encoded.append(
                        base64.b64encode(fh.read()).decode("utf-8")
                    )
        else:
            print_red("Invalid image type")
    message["images"] = encoded
    # Use the vision model
    return message
if __name__ == "__main__":
    # Manual smoke test for the LLM wrapper.
    # Fix: merged-diff residue left an unclosed print( call here.
    llm = LLM()
    result = llm.generate(
        query="I want to add 2 and 2",
    )
    # generate() returns the response text (a str) in the non-tool path.
    print(result)

@ -29,11 +29,15 @@ class Document:
def __init__(
    self,
    pdf_file=None,
    filename: str = None,
    doi: str = None,
    username: str = None,
    is_sci: bool = None,
    is_image: bool = False,
    text: str = None,
    _key: str = None,
    arango_db_name: str = None,
    arango_collection: str = None,
):
    """Initialise a Document.

    Fix: merged-diff residue duplicated several parameters (a SyntaxError) and
    re-assigned ``text``/``_key``/``arango_db_name`` with stale defaults after
    the constructor arguments had already been stored.

    Args:
        pdf_file: Optional PDF file/handle; when given, ``open_pdf`` is called.
        filename: Original file name of the document.
        doi: DOI of the document, if known.
        username: Owner of the document.
        is_sci: Whether the document is a scientific article (None = unknown).
        is_image: Whether the source is an image rather than a text PDF.
        text: Pre-extracted text, if available.
        _key: ArangoDB document key, if already stored.
        arango_db_name: Name of the ArangoDB database holding the document.
        arango_collection: Name of the ArangoDB collection holding the document.
    """
    self.filename = filename
    self.pdf_file = pdf_file
    # NOTE(review): reconstructed from the doi parameter — confirm upstream.
    self.doi = doi
    self.username = username
    self.is_sci = is_sci
    self.is_image = is_image
    self._key = _key
    self.arango_db_name = arango_db_name
    self.arango_collection = arango_collection
    self.text = text
    self.chunks = []
    # Derived/lazy state, filled in by later processing steps.
    self.pdf = None
    self._id = None
    self.metadata = None
    self.title = None
    self.open_access = False
    self.file_path = None
    self.download_folder = None
    self.document_type = None

    if self.pdf_file:
        self.open_pdf(self.pdf_file)
def make_summary_in_background(self):
if not self._id and all([self.arango_collection, self._key]):
self._id = f"{self.arango_collection}/{self._key}"
if not self._id:
return
data = {
"text": self.text,
"arango_db_name": self.arango_db_name,
"_id": self._id,
"arango_id": self._id,
"is_sci": self.is_sci,
}
@ -104,7 +113,7 @@ class Document:
self.text = md_text
def make_chunks(self, len_chunks=2200):
def make_chunks(self, len_chunks=1500):
better_chunks = []
ts = MarkdownSplitter(len_chunks)

@ -0,0 +1,232 @@
from _llm import LLM
import os, re
from atproto import (
CAR,
AtUri,
Client,
FirehoseSubscribeReposClient,
firehose_models,
models,
parse_subscribe_repos_message,
)
from colorprinter.print_color import *
from datetime import datetime
from env_manager import set_env
set_env()
class Chat:
    """State holder for one Bluesky conversation between the bot and a poster."""

    def __init__(self, bot_username, poster_username):
        self.bot_username = bot_username        # handle of this bot
        self.poster_username = poster_username  # handle of the user being answered
        self.messages = []                      # LLM-formatted chat messages
        self.thread_posts = []                  # raw posts collected from the thread
class Bot:
    """Bluesky bot that listens to the AT Protocol firehose and replies to
    posts mentioning its handle, using an LLM to draft the answers.

    NOTE(review): the constructor starts the firehose subscription, so
    ``Bot()`` blocks until the firehose stops — confirm this is intended.
    """

    def __init__(self):
        # Create a client instance to interact with Bluesky
        self.username = os.getenv("BLUESKY_USERNAME")

        system_message = '''
        You are a research assistant bot chatting with a user on Bluesky, a social media platform similar to Twitter.
        Your speciality is electric cars, and you will use facts in articles to answer the questions
        Use ONLY the information in the articles to answer the questions. Do not add any additional information or speculation.
        IF you don't know the answer, you can say "I don't know" or "I'm not sure". You can also ask the user to specify the question.
        Your answers should be concise and not exceed 250 characters to fit the character limit on Bluesky.
        Answer in English.
        '''
        self.llm: LLM = LLM(system_message=system_message, max_length_answer=200)
        self.client = Client()
        self.client.login(self.username, os.getenv("BLUESKY_PASSWORD"))
        # Current conversation; replaced each time the bot is mentioned.
        self.chat = None

        print("🐟 Bot is listening")

        # Create a firehose client to subscribe to repository events
        self.firehose = FirehoseSubscribeReposClient()
        # Start the firehose to listen for repository events (blocks here).
        self.firehose.start(self.on_message_handler)

    def answer_message(self, message):
        # Generate a reply with the LLM and publish it.
        # NOTE(review): posted as a standalone post, not as a reply in a thread.
        response = self.llm.generate(message).content
        self.client.send_post(response)

    def get_file_extension(self, file_path):
        # Utility function to get the file extension from a file path
        return os.path.splitext(file_path)[1]

    def bot_mentioned(self, text: str) -> bool:
        # True when this bot's handle appears anywhere in the post text
        # (case-insensitive substring match).
        return self.username.lower() in text.lower()

    def parse_thread(self, author_did, thread):
        """Walk up the thread and collect the text of every post authored by
        *author_did*, in traversal order."""
        # Traverse the thread to collect prompts from posts by the author
        entries = []
        stack = [thread]
        while stack:
            current_thread = stack.pop()
            if current_thread is None:
                continue
            if current_thread.post.author.did == author_did:
                print(current_thread.post.record.text)
                # Extract prompt from the current post
                entries.append(current_thread.post.record.text)
            # Add parent thread to the stack for further traversal
            stack.append(current_thread.parent)
        return entries

    def process_operation(
        self,
        op: models.ComAtprotoSyncSubscribeRepos.RepoOp,
        car: CAR,
        commit: models.ComAtprotoSyncSubscribeRepos.Commit,
    ) -> None:
        """Handle one repo operation from a firehose commit.

        For a newly created feed post that mentions the bot, rebuild the whole
        thread, turn it into an LLM conversation and post the answer as a reply
        mention. Delete/update operations are ignored.
        """
        # Construct the URI for the operation
        uri = AtUri.from_str(f"at://{commit.repo}/{op.path}")

        if op.action == "create":
            if not op.cid:
                return
            # Retrieve the record from the CAR file using the content ID (CID)
            record = car.blocks.get(op.cid)
            if not record:
                return
            # Build the record with additional metadata
            record = {
                "uri": str(uri),
                "cid": str(op.cid),
                "author": commit.repo,
                **record,
            }
            # Check if the operation is a post in the feed
            if uri.collection == models.ids.AppBskyFeedPost:
                if self.bot_mentioned(record["text"]):
                    poster_username = self.client.get_profile(actor=record["author"]).handle
                    self.chat = Chat(self.username, poster_username)
                    posts_in_thread = self.client.get_post_thread(uri=record["uri"])
                    self.traverse_thread(posts_in_thread.thread)
                    # Order posts chronologically before building the LLM chat.
                    self.chat.thread_posts.sort(key=lambda x: x["timestamp"])
                    self.make_llm_messages()
                    answer = self.llm.generate(messages=self.chat.messages)
                    self.client.send_post(f'@{poster_username} {answer.content} ')

        if op.action == "delete":
            # Handle delete operations (not implemented)
            return
        if op.action == "update":
            # Handle update operations (not implemented)
            return
        return

    def traverse_thread(self, thread_view_post):
        """Recursively collect the current post, its parents and its replies
        into ``self.chat.thread_posts`` (newlines flattened to spaces)."""
        # Process the current post
        post = thread_view_post.post
        author_handle = post.author.handle
        post_text = post.record.text
        timestamp = int(
            datetime.fromisoformat(post.indexed_at.replace("Z", "+00:00")).timestamp()
        )
        self.chat.thread_posts.append(
            {
                "user": author_handle,
                "text": post_text.replace("\n", " "),
                "timestamp": timestamp,
            }
        )
        # If there's a parent, process it
        if thread_view_post.parent:
            self.traverse_thread(thread_view_post.parent)
        # If there are replies, process them
        if getattr(thread_view_post, "replies", None):
            for reply in thread_view_post.replies:
                self.traverse_thread(reply)

    def make_llm_messages(self):
        """
        Processes the chat thread posts and compiles them into a list of messages
        formatted for a language model (LLM).

        The function performs the following steps:
        1. Iterates through the chat thread posts.
        2. Starts processing messages only after encountering a message mentioning the bot.
        3. Adds messages from the bot and the poster to the `self.chat.messages` list in the
           appropriate format for the LLM.

        The messages are formatted as follows:
        - Messages from the bot are added with the role "assistant".
        - Messages from the poster are added with the role "user".
        - Consecutive messages from the same user are concatenated.

        Returns:
            None
        """
        print_rainbow(self.chat.thread_posts)
        start = False
        for i in self.chat.thread_posts:
            # Make the messages start with a message mentioning the bot
            if self.chat.bot_username in i["text"]:
                start = True
            elif self.chat.bot_username not in i["text"] and not start:
                continue
            # Compile the messages into a list for the LLM.
            # NOTE(review): ``self.chat.messages[-1] != self.chat.bot_username``
            # compares a dict to a string, so it is always True — confirm the
            # intent (probably meant to check the last message's role).
            if (
                i["user"] == self.chat.bot_username
                and len(self.chat.messages) > 0
                and self.chat.messages[-1] != self.chat.bot_username
            ):
                i['text'] = i['text'].replace(f"@{self.chat.poster_username}", "").strip()
                self.chat.messages.append({"role": "assistant", "content": i["text"]})
            elif i["user"] == self.chat.poster_username:
                i['text'] = i['text'].replace(f"@{self.chat.bot_username}", "").strip()
                # Merge consecutive user messages into one.
                if len(self.chat.messages) > 0 and self.chat.messages[-1]['role'] == 'user':
                    self.chat.messages[-1]['content'] += f"\n\n{i['text']}"
                else:
                    self.chat.messages.append({"role": "user", "content": i["text"]})

    def on_message_handler(self, message: firehose_models.MessageFrame) -> None:
        # Callback function that handles incoming messages from the firehose subscription
        # Parse the incoming message to extract the commit information
        commit = parse_subscribe_repos_message(message)

        # Check if the parsed message is a Commit and if the commit contains blocks of data
        if not isinstance(
            commit, models.ComAtprotoSyncSubscribeRepos.Commit
        ) or not isinstance(commit.blocks, bytes):
            # If the message is not a valid commit or blocks are missing, exit early
            return

        # Parse the CAR (Content Addressable aRchive) file from the commit's blocks
        # The CAR file contains the data blocks referenced in the commit operations
        car = CAR.from_bytes(commit.blocks)

        # Iterate over each operation (e.g., create, delete, update) in the commit
        for op in commit.ops:
            # Process each operation using the process_operation method
            # This method handles the logic based on the type of operation
            self.process_operation(op, car, commit)
def main() -> None:
    """Entry point: build the bot and send a greeting post.

    NOTE(review): Bot() starts the firehose in its constructor and blocks, so
    the greeting is only sent once the firehose stops — confirm the intent.
    """
    bluesky_bot = Bot()
    bluesky_bot.answer_message("Hello, world!")


if __name__ == "__main__":
    main()

@ -0,0 +1,412 @@
import streamlit as st
from time import sleep
from article2db import PDFProcessor
from info import country_emojis
from utils import fix_key
from _base_class import StreamlitBaseClass
from colorprinter.print_color import *
class ArticleCollectionsPage(StreamlitBaseClass):
    """Streamlit page for managing named collections of favourite articles.

    Collections live in the user's ArangoDB (``article_collections``); the
    articles themselves are stored either in the shared ``sci_articles``
    collection or in the user's ``other_documents``. Page state is mirrored
    into ``st.session_state[self.page_name]`` between reruns.
    """

    def __init__(self, username: str):
        super().__init__(username=username)
        self.collection = None
        self.page_name = "Article Collections"

        # Initialize attributes from session state if available
        # NOTE(review): indexes st.session_state[self.page_name] directly —
        # raises KeyError if the page state was never initialised; confirm a
        # caller guarantees the key exists.
        for k, v in st.session_state[self.page_name].items():
            setattr(self, k, v)

    def run(self):
        """Render the page: collection picker, article list, sidebar actions."""
        if self.user_arango.db.collection("article_collections").count() == 0:
            self.create_new_collection()
        self.update_current_page(self.page_name)
        self.choose_collection_method()
        self.choose_project_method()

        if self.collection:
            self.display_collection()
            self.sidebar_actions()
        if st.session_state.get("new_collection"):
            self.create_new_collection()

        # Persist state to session_state
        self.update_session_state(page_name=self.page_name)

    def choose_collection_method(self):
        """Let the user pick a collection and persist the choice."""
        self.collection = self.choose_collection()
        # Persist state after choosing collection
        self.update_session_state(self.page_name)

    def choose_project_method(self):
        # If you have a project selection similar to collection, implement here
        pass  # Placeholder for project-related logic

    def choose_collection(self):
        """Show a sidebar selectbox of collections; returns the chosen name."""
        collections = self.get_article_collections()
        current_collection = self.collection
        # Preselect the previously chosen collection, if it still exists.
        preselected = (
            collections.index(current_collection)
            if current_collection in collections
            else None
        )
        with st.sidebar:
            collection = st.selectbox(
                "Select a collection of favorite articles",
                collections,
                index=preselected,
            )
            if collection:
                self.collection = collection
                self.update_settings("current_collection", collection)
        return self.collection

    def create_new_collection(self):
        """Render a form to create a new, empty article collection."""
        with st.form("create_collection_form", clear_on_submit=True):
            new_collection_name = st.text_input("Enter the name of the new collection")
            submitted = st.form_submit_button("Create Collection")
            if submitted:
                if new_collection_name:
                    self.user_arango.db.collection("article_collections").insert(
                        {"name": new_collection_name, "articles": []}
                    )
                    st.success(f'New collection "{new_collection_name}" created')
                    self.collection = new_collection_name
                    self.update_settings("current_collection", new_collection_name)
                    # Persist state after creating a new collection
                    self.update_session_state(page_name=self.page_name)
                    sleep(1)
                    st.rerun()

    def display_collection(self):
        """Sidebar create/remove buttons plus the list of articles."""
        with st.sidebar:
            col1, col2 = st.columns(2)
            with col1:
                if st.button("Create new collection"):
                    st.session_state["new_collection"] = True
            with col2:
                if st.button(f':red[Remove collection "{self.collection}"]'):
                    self.user_arango.db.collection("article_collections").delete_match(
                        {"name": self.collection}
                    )
                    st.success(f'Collection "{self.collection}" removed')
                    self.collection = None
                    self.update_settings("current_collection", None)
                    # Persist state after removing a collection
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()
        self.show_articles_in_collection()

    def show_articles_in_collection(self):
        """Fetch and render every article in the current collection.

        Scientific articles come from the shared DB, everything else from the
        user DB; each article gets an expander with metadata, notes and
        per-article actions.
        """
        collection_articles_cursor = self.user_arango.db.aql.execute(
            f"""
            FOR doc IN article_collections
            FILTER doc["name"] == @collection
            FOR article IN doc["articles"]
            RETURN article["_id"]
            """,
            bind_vars={"collection": self.collection},
        )
        collection_article_ids = list(collection_articles_cursor)

        # Split IDs by which database holds the full document.
        sci_articles = [
            _id for _id in collection_article_ids if _id.startswith("sci_articles")
        ]
        other_articles = [
            _id for _id in collection_article_ids if not _id.startswith("sci_articles")
        ]

        collection_articles = []
        if sci_articles:
            cursor = self.base_arango.db.aql.execute(
                """
                FOR doc IN sci_articles
                FILTER doc["_id"] IN @article_ids
                RETURN doc
                """,
                bind_vars={"article_ids": sci_articles},
            )
            collection_articles += list(cursor)

        if other_articles:
            cursor = self.user_arango.db.aql.execute(
                """
                FOR doc IN other_documents
                FILTER doc["_id"] IN @article_ids
                RETURN doc
                """,
                bind_vars={"article_ids": other_articles},
            )
            collection_articles += list(cursor)

        # Sort articles by title
        collection_articles = sorted(
            collection_articles,
            key=lambda x: x.get("metadata", {}).get("title", "No Title"),
        )

        if collection_articles:
            st.markdown(f"#### Articles in *{self.collection}*:")
            for article in collection_articles:
                if article is None:
                    continue
                metadata = article.get("metadata")
                if metadata is None:
                    continue
                title = metadata.get("title", "No Title").strip()
                journal = metadata.get("journal", "No Journal").strip()
                published_year = metadata.get("published_year", "No Year")
                published_date = metadata.get("published_date", None)
                language = metadata.get("language", "No Language")
                # Show a flag emoji for the article's language, if known.
                icon = country_emojis.get(language.upper(), "") if language else ""
                expander_title = f"**{title}** *{journal}* ({published_year}) {icon}"
                with st.expander(expander_title):
                    if not title == "No Title":
                        st.markdown(f"**Title:** \n{title}")
                    if not journal == "No Journal":
                        st.markdown(f"**Journal:** \n{journal}")
                    if published_date:
                        st.markdown(f"**Published Date:** \n{published_date}")
                    # Render remaining fields, skipping internal/bulky keys.
                    for key, value in article.items():
                        if key in [
                            "_key",
                            "text",
                            "file",
                            "_rev",
                            "chunks",
                            "user_access",
                            "_id",
                            "metadata",
                            "doi",
                            "title",
                            "user_notes",
                        ]:
                            continue
                        if isinstance(value, list):
                            value = ", ".join(value)
                        st.markdown(f"**{key.capitalize()}**: \n{value} ")
                    if "doi" in article:
                        if article["doi"]:
                            st.markdown(
                                f"**DOI:** \n[{article['doi']}](https://doi.org/{article['doi']}) "
                            )

                    # Let the user add notes to the article, if it's not a scientific article
                    # if not article._id.startswith("sci_articles"):
                    if "user_notes" in article and article["user_notes"]:
                        st.markdown(
                            f":blue[**Your notes:**]"
                        )
                        note_number = 0
                        for note in article["user_notes"]:
                            note_number += 1
                            c1, c2 = st.columns([4, 1])
                            with c1:
                                st.markdown(f":blue[{note}]")
                            with c2:
                                st.button(key=f'{article["_key"]}_{note_number}',
                                    label=f":red[Delete note]",
                                    on_click=self.delete_article_note,
                                    args=(article, note),
                                )

                    with st.form(f"add_info_form_{article['_id']}", clear_on_submit=True):
                        new_info = st.text_area(
                            ":blue[Add a note about the article]",
                            key=f'new_info_{article["_id"]}',
                            help="Add information such as what kind of article it is, what it's about, who's the author, etc.",
                        )
                        submitted = st.form_submit_button(":blue[Add note]")
                        if submitted:
                            self.update_article(article, "user_notes", new_info)

                    st.button(
                        key=f'delete_{article["_id"]}',
                        label=":red[Delete article from collection]",
                        on_click=self.delete_article,
                        args=(self.collection, article["_id"]),
                    )
        else:
            st.write("No articles in this collection.")

    def sidebar_actions(self):
        """Sidebar widgets for adding articles (PDF upload or DOI paste)."""
        with st.sidebar:
            st.markdown(f"### Add new articles to {self.collection}")
            with st.form("add_articles_form", clear_on_submit=True):
                pdf_files = st.file_uploader(
                    "Upload PDF file(s)", type=["pdf"], accept_multiple_files=True
                )
                is_sci = st.checkbox("All articles are from scientific journals")
                submitted = st.form_submit_button("Upload")
                if submitted and pdf_files:
                    self.add_articles(pdf_files, is_sci)
                    # Persist state after adding articles
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()

            help_text = 'Paste a text containing DOIs, e.g., the reference section of a paper, and click "Add Articles" to add them to the collection.'
            new_articles = st.text_area(
                "Add articles to this collection", help=help_text
            )
            if st.button("Add Articles"):
                with st.spinner("Processing..."):
                    self.process_dois(
                        article_collection_name=self.collection, text=new_articles
                    )
                    # Persist state after processing DOIs
                    self.update_session_state(page_name=self.page_name)
                    st.rerun()
            self.write_not_downloaded()

    def add_articles(self, pdf_files: list, is_sci: bool) -> None:
        """Process uploaded PDFs and attach them to the current collection."""
        for pdf_file in pdf_files:
            status_container = st.empty()
            with status_container:
                # Checkbox False means "unknown", not "definitely not scientific".
                is_sci = is_sci if is_sci else None
                with st.status(f"Processing {pdf_file.name}..."):
                    processor = PDFProcessor(
                        pdf_file=pdf_file,
                        filename=pdf_file.name,
                        process=False,
                        username=st.session_state["username"],
                        document_type="other_documents",
                        is_sci=is_sci,
                    )
                    _id, db, doi = processor.process_document()
                    print_rainbow(_id, db, doi)
                    # The article no longer needs a manual download.
                    if doi in st.session_state.get("not_downloaded", {}):
                        st.session_state["not_downloaded"].pop(doi)
                    self.articles2collection(collection=self.collection, db=db, _id=_id)
                st.success("Done!")
                sleep(1.5)

    def articles2collection(self, collection: str, db: str, _id: str = None) -> None:
        """Append the article identified by *_id* to *collection* (no duplicates)."""
        info = self.get_article_info(db, _id=_id)
        # Keep only the lightweight fields in the collection document.
        info = {
            k: v for k, v in info.items() if k in ["_id", "doi", "title", "metadata"]
        }
        # NOTE(review): collection name interpolated into AQL — prefer
        # bind_vars to avoid query injection / quoting issues.
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = doc.get("articles", [])
            keys = [i["_id"] for i in articles]
            if info["_id"] not in keys:
                articles.append(info)
                self.user_arango.db.collection("article_collections").update_match(
                    filters={"name": collection},
                    body={"articles": articles},
                    merge=True,
                )
        # Persist state after updating articles
        self.update_session_state(page_name=self.page_name)

    def get_article_info(self, db: str, _id: str = None, doi: str = None) -> dict:
        """Return the light metadata of an article, looked up by _id or DOI."""
        assert _id or doi, "Either _id or doi must be provided."
        arango = self.get_arango(db_name=db)
        if _id:
            query = """
            RETURN {
                "_id": DOCUMENT(@doc_id)._id,
                "doi": DOCUMENT(@doc_id).doi,
                "title": DOCUMENT(@doc_id).title,
                "metadata": DOCUMENT(@doc_id).metadata,
                "summary": DOCUMENT(@doc_id).summary
            }
            """
            info_cursor = arango.db.aql.execute(query, bind_vars={"doc_id": _id})
        elif doi:
            # NOTE(review): DOI interpolated into AQL — prefer bind_vars.
            info_cursor = arango.db.aql.execute(
                f'FOR doc IN sci_articles FILTER doc["doi"] == "{doi}" LIMIT 1 RETURN {{"_id": doc["_id"], "doi": doc["doi"], "title": doc["title"], "metadata": doc["metadata"], "summary": doc["summary"]}}'
            )
        return next(info_cursor, None)

    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        """Extract DOIs from *text* (or use *dois*), download/ingest each and
        add it to the collection; failures are tracked for manual download."""
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                # Remember the DOI (with a candidate URL) for manual download.
                st.session_state["not_downloaded"][doi] = url
            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        # Persist state after processing DOIs
        self.update_session_state(page_name=self.page_name)

    def write_not_downloaded(self):
        """List DOIs that could not be auto-downloaded, with links when known."""
        not_downloaded = st.session_state.get("not_downloaded", {})
        if not_downloaded:
            st.markdown(
                "*The articles below were not downloaded. Download them yourself and add them to the collection by dropping them in the area above. Some of them can be downloaded using the link.*"
            )
            for doi, url in not_downloaded.items():
                if url:
                    st.markdown(f"- [{doi}]({url})")
                else:
                    st.markdown(f"- {doi}")

    def delete_article(self, collection, _id):
        """Remove the article with *_id* from *collection* (membership only)."""
        # NOTE(review): collection name interpolated into AQL — prefer bind_vars.
        doc_cursor = self.user_arango.db.aql.execute(
            f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc'
        )
        doc = next(doc_cursor, None)
        if doc:
            articles = [
                article for article in doc.get("articles", []) if article["_id"] != _id
            ]
            self.user_arango.db.collection("article_collections").update_match(
                filters={"_id": doc["_id"]},
                body={"articles": articles},
            )
        # Persist state after deleting an article
        self.update_session_state(page_name=self.page_name)

    def update_article(self, article, field, value):
        "Update a field in an article document"
        value = str(value.strip())
        # NOTE(review): leftover debug prints.
        print(value)
        print(type(value))
        # Existing scalar values are promoted to a list before appending.
        if field in article:
            if isinstance(article[field], list):
                article[field].append(value)
            else:
                article[field] = [article[field], value]
        else:
            article[field] = [value]
        self.user_arango.db.update_document(article, check_rev=False, silent=True)
        sleep(0.2)
        st.rerun()

    def delete_article_note(self, article: dict, note: str):
        "Delete a note from a list of notes in an article document."
        if "user_notes" in article and note in article["user_notes"]:
            article["user_notes"].remove(note)
            self.user_arango.db.update_document(article, check_rev=False, silent=True)
            sleep(0.1)

@ -1,6 +1,6 @@
from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
from typing import Optional
from fastapi import FastAPI, BackgroundTasks, Request
from fastapi.responses import JSONResponse
import logging
from prompts import get_summary_prompt
from _llm import LLM
@ -8,46 +8,66 @@ from _arango import ArangoDB
app = FastAPI()
class DocumentData(BaseModel):
    """Request model for document summarisation.

    NOTE(review): legacy model from the previous typed endpoint; the current
    /summarise_document handler reads the raw request JSON instead, so this
    appears unused — confirm before removing.
    """
    text: str  # full document text to summarise
    arango_db_name: str  # target ArangoDB database name
    arango_id: str  # _id of the document to update with the summary
    is_sci: Optional[bool] = False  # True when the document is a scientific article
# Configure root logging once at import time; module-level logger for this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.post("/summarise_document")
async def summarize_document(request: Request, background_tasks: BackgroundTasks):
    """Accept a raw JSON payload and schedule background summarisation.

    Fix: merged-diff residue left the old ``summarise_document_task`` function
    header interleaved inside this handler's try block, breaking the module.

    Expected JSON keys: ``text``, ``arango_db_name``, ``arango_id`` and
    optionally ``is_sci``. Returns immediately; the summary is written to
    ArangoDB by the background task.
    """
    try:
        data = await request.json()
        logger.info(f"Received data: {data}")

        # Clean the data
        data['text'] = data.get('text', '').strip()
        data['arango_db_name'] = data.get('arango_db_name', '').strip()
        data['arango_id'] = data.get('arango_id', '').strip()
        data['is_sci'] = data.get('is_sci', False)

        background_tasks.add_task(summarise_document_task, data)
        return {"message": "Document summarization has started."}
    except Exception as e:
        logger.error(f"Error in summarize_document: {e}")
        return JSONResponse(
            status_code=500,
            content={"detail": "An unexpected error occurred."},
        )
def summarise_document_task(doc_data: dict):
    """Generate a summary for the document described by *doc_data* and store it.

    Picks a system prompt based on the document type (interview transcript,
    scientific article, or generic document), generates the summary with the
    LLM and writes it into the document's ``summary`` field in ArangoDB.

    Fixes: removed the stale module-level ``system_message``/``llm`` residue
    from the old version, and the error log used ``f'_id: _{id}'`` which
    interpolated the *builtin* ``id`` instead of the document id.

    Args:
        doc_data: Dict with keys ``text``, ``arango_db_name``, ``arango_id``
            and optionally ``is_sci``.
    """
    _id = None  # defined before the try so the error handler can always log it
    try:
        _id = doc_data.get("arango_id")
        text = doc_data.get("text")
        is_sci = doc_data.get("is_sci", False)

        # Pick a system prompt matching the document type.
        if _id.split('/')[0] == 'interviews':
            system_message = "You are summarising interview transcripts. It is very important that you keep to what is written and do not add any of your own opinions or interpretations. Always answer in English."
        elif is_sci or _id.split('/')[0] == 'sci_articles':
            system_message = "You are summarising scientific articles. It is very important that you keep to what is written and do not add any of your own opinions or interpretations. Always answer in English."
        else:
            system_message = "You are summarising a document. It is very important that you keep to what is written and do not add any of your own opinions or interpretations. Always answer in English."
        llm = LLM(system_message=system_message)

        prompt = get_summary_prompt(text, is_sci)
        summary = llm.generate(query=prompt)

        summary_doc = {
            "text_sum": summary,
            "meta": {
                "model": llm.model,
                "temperature": llm.options["temperature"],
            },
        }
        arango = ArangoDB(db_name=doc_data.get("arango_db_name"))
        arango.db.update_document(
            {"summary": summary_doc, "_id": _id},
            silent=True,
            check_rev=False,
        )
    except Exception as e:
        logger.error(f'_id: {_id}')
        logger.error(f"Error in summarise_document_task: {e}")
if __name__ == "__main__":
    import uvicorn

    # Serve the summarisation API on all interfaces.
    # Fix: merged-diff residue duplicated the uvicorn.run line.
    uvicorn.run(app, host="0.0.0.0", port=8100)

@ -1,27 +1,59 @@
from typing import Callable, Dict, Any, List
class ToolRegistry:
    """
    A registry for managing and accessing tools (functions).

    This class provides methods to register functions as tools and retrieve
    them by name.

    Fixes: merged-diff residue left the old decorator-factory ``register`` and
    a stray ``_tools = []`` fused with the new dict-based implementation
    (invalid module), and debug ``print`` calls were left in library code.

    Attributes:
        _tools (Dict[str, Callable]): Maps tool names to their functions.
    """

    _tools: Dict[str, Callable] = {}

    @classmethod
    def register(cls, func: Callable):
        """
        Register *func* as a tool under its function name.

        Args:
            func (Callable): The function to be registered.

        Returns:
            Callable: The same function, allowing decorator usage.
        """
        cls._tools[func.__name__] = func
        return func

    @classmethod
    def get_tools(cls, tools: List[str] = None) -> List[Callable]:
        """
        Retrieve registered tool callables.

        Args:
            tools (List[str], optional): Names of tools to retrieve; unknown
                names are silently skipped. When None, all registered tools
                are returned.

        Returns:
            List[Callable]: The matching tool callables.
        """
        if tools:
            return [cls._tools[name] for name in tools if name in cls._tools]
        return list(cls._tools.values())

@ -76,6 +76,7 @@ def make_arango(username):
"notes",
"other_documents",
"rss_feeds",
"interviews",
]:
if not arango.db.has_collection(collection):
arango.db.create_collection(collection)

@ -0,0 +1,725 @@
import re
import os
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
from time import sleep
from datetime import datetime
from PIL import Image
from io import BytesIO
import base64
from article2db import PDFProcessor
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import StreamlitBaseClass
from colorprinter.print_color import *
from prompts import get_note_summary_prompt, get_image_system_prompt
import env_manager
env_manager.set_env()
print_green("Environment variables set.")
class ProjectsPage(StreamlitBaseClass):
    """Streamlit page for listing, creating and managing projects.

    Attribute values are mirrored into ``st.session_state[self.page_name]``
    (via ``update_session_state``) so they survive Streamlit reruns.
    """

    def __init__(self, username: str):
        super().__init__(username=username)
        # Defaults; may be overwritten from session state below.
        self.projects = []
        self.selected_project_name = None
        self.project = None
        self.page_name = "Projects"
        # Initialize attributes from session state if available
        page_state = st.session_state.get(self.page_name, {})
        for k, v in page_state.items():
            setattr(self, k, v)

    def run(self):
        """Entry point: render the page and persist state back to the session."""
        self.update_current_page(self.page_name)
        self.load_projects()
        self.display_projects()
        # Update session state
        self.update_session_state(self.page_name)

    def load_projects(self):
        """Load all project documents from the user's ArangoDB into self.projects."""
        projects_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects RETURN doc", count=True
        )
        self.projects = list(projects_cursor)

    def display_projects(self):
        """Render the sidebar project selector; manage the selected project."""
        with st.sidebar:
            self.new_project_button()
            self.selected_project_name = st.selectbox(
                "Select a project to manage",
                options=[proj["name"] for proj in self.projects],
            )
        if self.selected_project_name:
            self.project = Project(
                username=self.username,
                project_name=self.selected_project_name,
                user_arango=self.user_arango,
            )
            self.manage_project()
        # Update session state
        self.update_session_state(self.page_name)

    def new_project_button(self):
        """Show a 'New project' button; open the creation form once clicked."""
        # The flag keeps the form open across the rerun triggered by the click.
        st.session_state.setdefault("new_project", False)
        with st.sidebar:
            if st.button("New project", type="primary"):
                st.session_state["new_project"] = True
            if st.session_state["new_project"]:
                self.create_new_project()
        # Update session state
        self.update_session_state(self.page_name)

    def create_new_project(self):
        """Render name/description inputs and insert the new project document."""
        new_project_name = st.text_input("Enter the name of the new project")
        new_project_description = st.text_area(
            "Enter the description of the new project"
        )
        if st.button("Create Project"):
            if new_project_name:
                self.user_arango.db.collection("projects").insert(
                    {
                        "name": new_project_name,
                        "description": new_project_description,
                        "collections": [],
                        "notes": [],
                        # Fingerprint of an empty note list; compared by
                        # Project.update_notes_hash to detect changes.
                        "note_keys_hash": hash(""),
                        "settings": {},
                    }
                )
                st.success(f'New project "{new_project_name}" created')
                st.session_state["new_project"] = False
                self.update_settings("current_project", new_project_name)
                sleep(1)
                st.rerun()

    def show_project_notes(self):
        """Render the AI summary and the individual notes of the current project."""
        with st.expander("Show summarised notes"):
            st.markdown(self.project.notes_summary)
        with st.expander("Show project notes"):
            notes_cursor = self.user_arango.db.aql.execute(
                "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc",
                bind_vars={"note_ids": self.project.notes},
            )
            notes = list(notes_cursor)
            if notes:
                for note in notes:
                    st.markdown(f'_{note.get("timestamp", "")}_')
                    st.markdown(note["text"].replace("\n", " \n"))
                    st.button(
                        key=f'delete_note_{note["_id"]}',
                        label=":red[Delete note]",
                        on_click=self.project.delete_note,
                        args=(note["_id"],),
                    )
                    st.write("---")
            else:
                st.write("No notes in this project.")

    def show_project_interviews(self):
        """Render the project's interviews with preview, download and delete."""
        with st.expander("Show project interviews"):
            if not self.user_arango.db.has_collection("interviews"):
                self.user_arango.db.create_collection("interviews")
            interviews_cursor = self.user_arango.db.aql.execute(
                "FOR doc IN interviews FILTER doc.project == @project_name RETURN doc",
                bind_vars={"project_name": self.project.name},
            )
            interviews = list(interviews_cursor)
            if interviews:
                for interview in interviews:
                    st.markdown(f'_{interview.get("timestamp", "")}_')
                    st.markdown(
                        f"**Interviewees:** {', '.join(interview['intervievees'])}"
                    )
                    st.markdown(f"**Interviewer:** {interview['interviewer']}")
                    # Show at most the first 6 transcript lines as a preview.
                    if len(interview["transcript"].split("\n")) > 6:
                        preview = (
                            " \n".join(interview["transcript"].split("\n")[:6])
                            + " \n(...)"
                        )
                    else:
                        preview = interview["transcript"]
                    # Grey out [hh:mm:ss]-style timestamps in the preview.
                    timestamps = re.findall(r"\[(.*?)\]", preview)
                    for ts in timestamps:
                        preview = preview.replace(f"[{ts}]", f":grey[{ts}]")
                    st.markdown(preview)
                    c1, c2 = st.columns(2)
                    with c1:
                        st.download_button(
                            label="Download Transcript",
                            key=f"download_transcript_{interview['_key']}",
                            data=interview["transcript"],
                            file_name=interview["filename"],
                            mime="text/vtt",
                        )
                    with c2:
                        st.button(
                            key=f'delete_interview_{interview["_key"]}',
                            label=":red[Delete interview]",
                            on_click=self.project.delete_interview,
                            args=(interview["_key"],),
                        )
                    st.write("---")
            else:
                st.write("No interviews in this project.")

    def manage_project(self):
        """Render the full management view for the selected project."""
        self.update_settings("current_project", self.selected_project_name)
        # Initialize the Project instance
        self.project = Project(
            self.username, self.selected_project_name, self.user_arango
        )
        st.write(f"## {self.project.name}")
        self.show_project_interviews()
        self.show_project_notes()
        self.relate_collections()
        self.sidebar_actions()
        # Re-summarize notes if the note set changed this run.
        self.project.update_notes_hash()
        if st.button(f":red[Remove project *{self.project.name}*]"):
            self.user_arango.db.collection("projects").delete_match(
                {"name": self.project.name}
            )
            self.update_settings("current_project", None)
            st.success(f'Project "{self.project.name}" removed')
            st.rerun()
        # Update session state
        self.update_session_state(self.page_name)

    def relate_collections(self):
        """Relate existing article collections, or create-and-relate a new one."""
        collections = [
            col["name"]
            for col in self.user_arango.db.collection("article_collections").all()
        ]
        selected_collections = st.multiselect(
            "Relate existing collections", options=collections
        )
        if st.button("Relate Collections"):
            self.project.add_collections(selected_collections)
            st.success("Collections related to the project")
            # Update session state
            self.update_session_state(self.page_name)
        new_collection_name = st.text_input(
            "Enter the name of the new collection to create and relate"
        )
        if st.button("Create and Relate Collection"):
            if new_collection_name:
                self.user_arango.db.collection("article_collections").insert(
                    {"name": new_collection_name, "articles": []}
                )
                self.project.add_collection(new_collection_name)
                st.success(
                    f'New collection "{new_collection_name}" created and related to the project'
                )
                # Update session state
                self.update_session_state(self.page_name)

    def sidebar_actions(self):
        """Render all sidebar input sections (interview upload, notes)."""
        self.sidebar_interview()
        self.sidebar_notes()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_notes(self):
        """Sidebar section for adding notes (upload, free text, Wikipedia)."""
        with st.sidebar:
            st.markdown(f"### Add new notes to {self.project.name}")
            self.upload_notes_form()
            self.add_text_note()
            self.add_wikipedia_data()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_interview(self):
        """Sidebar section for uploading an interview."""
        with st.sidebar:
            st.markdown(f"### Add new interview to {self.project.name}")
            self.upload_interview_form()
        # Update session state
        self.update_session_state(self.page_name)

    def upload_notes_form(self):
        """Form for uploading PDF/image notes; delegates processing to Project."""
        with st.expander("Upload notes"):
            with st.form("add_notes", clear_on_submit=True):
                files = st.file_uploader(
                    "Upload PDF or image",
                    type=["png", "jpg", "pdf"],
                    accept_multiple_files=True,
                )
                submitted = st.form_submit_button("Upload")
            if submitted:
                self.project.process_uploaded_notes(files)
        # Update session state
        self.update_session_state(self.page_name)

    def upload_interview_form(self):
        """Form for uploading an interview file plus its metadata."""
        with st.expander("Upload interview"):
            with st.form("add_interview", clear_on_submit=True):
                interview = st.file_uploader("Upload interview audio file")
                interviewees = st.text_input(
                    "Enter the names of the interviewees, separated by commas"
                )
                interviewer = st.text_input(
                    "Enter the interviewer's name",
                    help="If left blank, the current user will be used",
                )
                date_of_interveiw = st.date_input(
                    "Date of interview", value=None, format="YYYY-MM-DD"
                )
                submitted = st.form_submit_button("Upload")
            if submitted:
                self.project.add_interview(
                    interview, interviewees, interviewer, date_of_interveiw
                )
        # Update session state
        self.update_session_state(self.page_name)

    def add_text_note(self):
        """Free-text note input for the current project."""
        help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
        note_text = st.text_area("Write or paste anything.", help=help_text)
        if st.button("Add Note"):
            self.project.add_note(
                {
                    "text": note_text,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                }
            )
            st.success("Note added to the project")
            # Update session state
            self.update_session_state(self.page_name)

    def add_wikipedia_data(self):
        """Fetch a Wikipedia page and store its AI summary as a project note."""
        wiki_url = st.text_input(
            "Paste the address to a Wikipedia page to add its summary as a note",
            placeholder="Paste Wikipedia URL",
        )
        if st.button("Add Wikipedia data"):
            with st.spinner("Fetching Wikipedia data..."):
                wiki_data = self.project.get_wikipedia_data(wiki_url)
            if wiki_data:
                self.project.process_wikipedia_data(wiki_data, wiki_url)
                st.success("Wikipedia data added to notes")
                # Update session state
                self.update_session_state(self.page_name)
                st.rerun()
class Project(StreamlitBaseClass):
    """One research project: its notes, interviews and related article collections.

    Wraps a single document in the user's ``projects`` collection and keeps
    the ArangoDB document in sync via :meth:`update_project`.
    """

    def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
        super().__init__(username=username)
        self.name = project_name
        self.user_arango = user_arango
        # Defaults, overwritten by load_project() below.
        self.description = ""
        self.collections = []
        self.notes = []
        self.note_keys_hash = 0
        self.settings = {}
        self.notes_summary = ""
        # Initialize attributes from arango doc if available
        self.load_project()

    def load_project(self):
        """Fetch the project document by name and populate the attributes.

        Raises:
            ValueError: If no project with ``self.name`` exists.
        """
        print_blue("Project name:", self.name)
        project_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN projects FILTER doc.name == @name RETURN doc",
            bind_vars={"name": self.name},
        )
        project = next(project_cursor, None)
        if not project:
            raise ValueError(f"Project '{self.name}' not found.")
        self._key = project["_key"]
        self.name = project.get("name", "")
        self.description = project.get("description", "")
        self.collections = project.get("collections", [])
        self.notes = project.get("notes", [])
        self.note_keys_hash = project.get("note_keys_hash", 0)
        self.settings = project.get("settings", {})
        self.notes_summary = project.get("notes_summary", "")

    def update_project(self):
        """Write the current attribute values back to the ArangoDB document."""
        updated_doc = {
            "_key": self._key,
            "name": self.name,
            "description": self.description,
            "collections": self.collections,
            "notes": self.notes,
            "note_keys_hash": self.note_keys_hash,
            "settings": self.settings,
            "notes_summary": self.notes_summary,
        }
        self.user_arango.db.collection("projects").update(updated_doc, check_rev=False)
        self.update_session_state()

    def add_collections(self, collections):
        """Relate several collections to the project, skipping duplicates."""
        # Previously duplicates were appended each time the button was pressed.
        self.collections.extend(c for c in collections if c not in self.collections)
        self.update_project()

    def add_collection(self, collection_name):
        """Relate a single collection to the project (idempotent)."""
        if collection_name not in self.collections:
            self.collections.append(collection_name)
        self.update_project()

    def add_note(self, note: dict):
        """Insert a note document and register its id on the project."""
        assert note["text"], "Note text cannot be empty"
        note["text"] = note["text"].strip().strip("\n")
        if "timestamp" not in note:
            note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M")
        note_doc = self.user_arango.db.collection("notes").insert(note)
        if note_doc["_id"] not in self.notes:
            self.notes.append(note_doc["_id"])
        self.update_project()

    def add_interview(
        self,
        interview: UploadedFile,
        intervievees: str,
        interviewer: str,
        date_of_interveiw: datetime.date = None,
    ):
        """Handle an uploaded interview file (audio or document).

        Audio files are transcribed and previewed with add/download buttons;
        PDFs are handed to PDFProcessor; plain text is not yet supported.
        """
        # TODO Implement this method
        # Check if interview is a sound (WAV, Mp3, AAC, etc) file or a text file (PDF, DOCX, TXT, etc)
        if interview.type in ["audio/x-wav", "audio/mpeg"]:
            transcription = self.transcribe(interview)
            transcription_preview = (
                " \n".join(transcription.split("\n")[:4]) + " \n(...)"
            )
            st.markdown(transcription_preview)
            transcription_filename = os.path.splitext(interview.name)[0] + ".vtt"
            c1, c2 = st.columns(2)
            with c1:
                st.button(
                    "Add to project",
                    on_click=self.add_interview_transcript,
                    args=(
                        transcription,
                        transcription_filename,
                        intervievees,
                        interviewer,
                        date_of_interveiw,
                    ),
                )
            with c2:
                st.download_button(
                    label="Download Transcription",
                    data=transcription,
                    file_name=transcription_filename,
                    mime="text/vtt",
                )
        elif interview.type in ["application/pdf"]:
            PDFProcessor(
                pdf_file=interview,
                is_sci=False,
                document_type="interview",
                is_image=False,
            )
        # The correct MIME type is "text/plain"; keep the original
        # (malformed) value as well for backward compatibility.
        elif interview.type in ["text/plain", "plain/text"]:
            # TODO Implement text file processing
            pass

    def add_interview_transcript(
        self,
        transcript,
        filename,
        intervievees: str = None,
        interviewer: str = None,
        date_of_interveiw: datetime.date = None,
    ):
        """Store a transcript as an ``interviews`` document and chunk it for search."""
        print_yellow(transcript)
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        # Key the document on the transcript filename (was a placeholder).
        _key = fix_key(f"{filename}_{timestamp}")
        if intervievees:
            intervievees = [
                i.strip() for i in intervievees.split(",") if len(i.strip()) > 0
            ]
        if not interviewer:
            interviewer = self.username
        if not self.user_arango.db.has_collection("interviews"):
            self.user_arango.db.create_collection("interviews")
        if date_of_interveiw:
            # st.date_input returns a datetime.date; the old code called
            # strptime() on it unconditionally, which raises TypeError.
            if isinstance(date_of_interveiw, str):
                date_of_interveiw = datetime.strptime(
                    date_of_interveiw, "%Y-%m-%d"
                ).date()
            # Store ISO format so the document stays JSON-serializable.
            date_of_interveiw = date_of_interveiw.isoformat()
        from article2db import Document

        document = Document(
            text=transcript,
            is_sci=False,
            _key=_key,
            filename=filename,
            arango_db_name=self.username,
            username=self.username,
            arango_collection="interviews",
        )
        print_rainbow(document.__dict__)
        print(document.text)
        document.make_chunks(len_chunks=600)
        self.user_arango.db.collection("interviews").insert(
            {
                "_key": _key,
                "transcript": transcript,
                "project": self.name,
                "filename": filename,
                "timestamp": timestamp,
                "intervievees": intervievees,
                "interviewer": interviewer,
                "date_of_interveiw": date_of_interveiw,
                "chunks": document.chunks,
            },
            overwrite=True,
            silent=True,
        )
        document.make_summary_in_background()

    def transcribe(self, uploaded_file: UploadedFile):
        """Convert an uploaded audio file to mono MP3 and send it for transcription.

        Returns:
            str | None: The formatted transcription, or None on failure.
        """
        from pydub import AudioSegment
        import requests
        import io

        file_extension = os.path.splitext(uploaded_file.name)[1].lower()
        filename = uploaded_file.name
        input_file_buffer = io.BytesIO(uploaded_file.getvalue())
        progress_bar = st.progress(0)
        status_text = st.empty()
        if file_extension in [".m4a", ".mp3", ".wav", ".flac"]:
            # Handle audio files
            audio = AudioSegment.from_file(
                input_file_buffer, format=file_extension.replace(".", "")
            )
            audio = audio.set_channels(1)  # Convert to mono
            file_buffer = io.BytesIO()
            audio.export(file_buffer, format="mp3", bitrate="64k")
            file_buffer.seek(0)
            progress_bar.progress(50)
            status_text.text("Audio file converted.")
        else:
            st.error("Unsupported file type")
            st.stop()
        # Send the converted audio data to the transcription service
        try:
            # os.getenv never raises; the old try/except around it was dead
            # code. Fall back to loading .env when the variable is unset.
            url = os.getenv("TRANSCRIBE_URL")
            if not url:
                import dotenv

                dotenv.load_dotenv()
                url = os.getenv("TRANSCRIBE_URL")
            # Prepare the files dictionary for the POST request
            files = {"file": (filename, file_buffer, "audio/mp3")}
            # Send the POST request with the file buffer
            response = requests.post(url, files=files, timeout=3600)
            response_json = response.json()
            progress_bar.progress(100)
            status_text.text("File uploaded and processed.")
            if response.status_code == 200:
                transcription_content = response_json.get("transcription", "")
                transcription_content = self.format_transcription(transcription_content)
                return transcription_content
            else:
                st.error("Failed to upload and process the file.")
        except requests.exceptions.Timeout:
            st.error("The request timed out. Please try again later.")

    def format_transcription(self, transcription: str):
        """Collapse VTT cue blocks into '[start] text' lines.

        A '-->' line supplies the timestamp for the following text line;
        the fractional seconds are stripped.
        """
        lines = transcription.split("\n")
        transcript = []
        timestamp = None
        for line in lines:
            if "-->" in line:
                timestamp = line[: line.find(".")]
            elif timestamp:
                line = f"[{timestamp}] {line}"
                transcript.append(line)
                timestamp = None
        return "\n".join(transcript)

    def delete_note(self, note_id):
        """Detach a note from the project.

        NOTE(review): the note document itself is kept in the `notes`
        collection — confirm whether it may be referenced elsewhere.
        """
        if note_id in self.notes:
            self.notes.remove(note_id)
        self.update_project()

    def delete_interview(self, interview_id):
        """Delete an interview document by its _key."""
        self.user_arango.db.collection("interviews").delete_match(
            {"_key": interview_id}
        )

    def update_notes_hash(self):
        """Re-summarize the notes when the set of note ids has changed."""
        current_hash = self.make_project_notes_hash()
        if current_hash != self.note_keys_hash:
            self.note_keys_hash = current_hash
            with st.spinner("Summarizing notes for chatbot..."):
                self.create_notes_summary()
            self.update_project()

    def make_project_notes_hash(self):
        """Return a stable fingerprint of the project's note ids.

        Uses a SHA-256 hex digest instead of the builtin ``hash`` because the
        latter is salted per process (PYTHONHASHSEED), which made the stored
        hash differ on every app restart and triggered a needless LLM
        re-summary each time.
        """
        import hashlib

        note_keys_str = "".join(self.notes)
        return hashlib.sha256(note_keys_str.encode("utf-8")).hexdigest()

    def create_notes_summary(self):
        """Generate and store an LLM summary of all project notes."""
        notes_cursor = self.user_arango.db.aql.execute(
            "FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text",
            bind_vars={"note_ids": self.notes},
        )
        notes = list(notes_cursor)
        notes_string = "\n---\n".join(notes)
        llm = LLM(model="small")
        query = get_note_summary_prompt(self, notes_string)
        summary = llm.generate(query)
        print_purple("New summary of notes:", summary)
        self.notes_summary = summary
        self.update_session_state()

    def analyze_image(self, image_base64, text=None):
        """Ask the LLM to describe an image.

        Args:
            image_base64: Base64-encoded image bytes.
            text: Optional text extracted from the image (OCR), included in
                the prompt when present.

        Returns:
            The LLM-generated description.
        """
        llm = LLM(system_message=get_image_system_prompt(self))
        prompt = (
            f'Analyze the image. The text found in it read: "{text}"'
            if text
            else "Analyze the image."
        )
        print_blue(type(image_base64))
        description = llm.generate(query=prompt, images=[image_base64], stream=False)
        print_green("Image description:", description)
        # Bug fix: the description was computed but never returned, so
        # process_uploaded_notes always stored None as the image caption.
        return description

    def process_uploaded_notes(self, files):
        """OCR uploaded images, caption them with the LLM and save as notes."""
        with st.spinner("Processing files..."):
            for file in files:
                st.write("Processing...")
                filename = fix_key(file.name)
                image_file = self.file2img(file)
                pdf_file = self.convert_image_to_pdf(image_file)
                pdf = PDFProcessor(
                    pdf_file=pdf_file,
                    is_sci=False,
                    document_type="notes",
                    is_image=True,
                    process=False,
                )
                text = pdf.process_document()
                # file2img() consumed the upload stream; rewind before
                # re-reading, otherwise we would base64-encode zero bytes.
                file.seek(0)
                base64_str = base64.b64encode(file.read())
                image_caption = self.analyze_image(base64_str, text=text)
                self.add_note(
                    {
                        # Key the note on the sanitised filename (the computed
                        # `filename` was previously unused).
                        "_id": f"notes/{filename}",
                        "text": f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}",
                    }
                )
        st.success("Done!")
        sleep(1.5)
        self.update_session_state()
        st.rerun()

    def file2img(self, file):
        """Read an uploaded file into a PIL Image.

        Raises:
            ValueError: If the uploaded file contains no bytes.
        """
        img_bytes = file.read()
        if not img_bytes:
            raise ValueError("Uploaded file is empty.")
        return Image.open(BytesIO(img_bytes))

    def convert_image_to_pdf(self, img):
        """OCR a PIL image into a searchable PDF (BytesIO) via tesseract."""
        import pytesseract

        pdf_bytes = pytesseract.image_to_pdf_or_hocr(img)
        pdf_file = BytesIO(pdf_bytes)
        pdf_file.name = (
            "converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf"
        )
        return pdf_file

    def get_wikipedia_data(self, page_url: str) -> dict:
        """Fetch title/summary/content/references for a Wikipedia page URL.

        Returns:
            dict | None: Page data, or None on an invalid URL or fetch error.
        """
        import wikipedia
        from urllib.parse import urlparse

        parsed_url = urlparse(page_url)
        page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path)
        if page_name_match:
            page_name = page_name_match.group(0)
        else:
            st.warning("Invalid Wikipedia URL")
            return None
        try:
            page = wikipedia.page(page_name)
            data = {
                "title": page.title,
                "summary": page.summary,
                "content": page.content,
                "url": page.url,
                "references": page.references,
            }
            return data
        except Exception as e:
            st.error(f"Error fetching Wikipedia data: {e}")
            return None

    def process_wikipedia_data(self, wiki_data, wiki_url):
        """Summarize Wikipedia data into a note and offer to import DOI refs."""
        llm = LLM(
            system_message="You are an assistant summarisen wikipedia data. Answer ONLY with the summary, nothing else!",
            model="small",
        )
        if wiki_data.get("summary"):
            query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.'''
            summary = llm.generate(query)
            wiki_data["text"] = (
                f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}"
            )
        # Drop the bulky raw fields before storing the note.
        wiki_data.pop("summary", None)
        wiki_data.pop("content", None)
        self.user_arango.db.collection("notes").insert(
            wiki_data, overwrite=True, silent=True
        )
        self.add_note(wiki_data)
        processor = PDFProcessor(process=False)
        dois = [
            processor.extract_doi(ref)
            for ref in wiki_data.get("references", [])
            if processor.extract_doi(ref)
        ]
        if dois:
            current_collection = st.session_state["settings"].get("current_collection")
            st.markdown(
                f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
            )
            if st.button("Add DOIs"):
                self.process_dois(current_collection, dois=dois)
        self.update_session_state()

    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        """Download articles for the given DOIs and add them to a collection.

        DOIs that cannot be downloaded are recorded in
        ``st.session_state["not_downloaded"]``.
        """
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        for doi in dois:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url
            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        self.update_session_state()

@ -14,6 +14,7 @@ def use_tools(use: bool = False):
return 'If you need to use a tool to fetch information, you can do that as well.'
else:
return ''
def get_assistant_prompt():
"""
Returns a multi-line string that serves as a prompt for a research assistant AI.
@ -156,7 +157,7 @@ def get_note_summary_prompt(project: "Project", notes_string: str):
return re.sub(r"\s*\n\s*", "\n", query)
def get_image_system_prompt(project: "Project"):
def get_image_system_prompt(project):
system_message = f"""
You are an assistant to a journalist who is working on a project called {project.name}. Your task is to analyze and describe the images that are part of the project.
@ -172,6 +173,7 @@ def get_tools_prompt(user_input):
return f'''User message: "{user_input}"
Choose one or many tools in order to answer the message. It's important that you think of what information (if any) is needed to make a good answer.
Make sure to read the description of the tools carefully before choosing!
You can ONLY chose a tool you are provided with, don't make up a tool!
'''

@ -85,38 +85,40 @@ if st.session_state["authentication_status"]:
article_collections = st.Page(Article_Collections)
settings = st.Page(Settings)
rss_feeds = st.Page(RSS_Feeds)
sleep(0.1)
pg = st.navigation([bot_chat, projects, article_collections, rss_feeds, settings])
try:
pg.run()
except Exception as e:
print_red(e)
st.error("An error occurred. The site will be reloaded.")
import traceback
from datetime import datetime
from time import sleep
traceback_string = traceback.format_exc()
traceback.print_exc()
arango = ArangoDB(db_name="base")
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
print_rainbow(st.session_state.to_dict())
session_state = st.session_state.to_dict()
if 'bot' in session_state:
del session_state['bot']
arango.db.collection("error_logs").insert(
{
"error": traceback_string,
"_key": timestamp,
"session_state": session_state,
},
overwrite=True,
)
with st.status(":red[An error occurred. The site will be reloaded.]"):
for i in range(5):
sleep(1)
st.write(f"Reloading in {5-i} seconds...")
st.rerun()
sleep(0.1)
pg.run()
# try: #TODO Use this when in production
# pg.run()
# except Exception as e:
# print_red(e)
# st.error("An error occurred. The site will be reloaded.")
# import traceback
# from datetime import datetime
# from time import sleep
# traceback_string = traceback.format_exc()
# traceback.print_exc()
# arango = ArangoDB(db_name="base")
# timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
# print_rainbow(st.session_state.to_dict())
# session_state = st.session_state.to_dict()
# if 'bot' in session_state:
# del session_state['bot']
# arango.db.collection("error_logs").insert(
# {
# "error": traceback_string,
# "_key": timestamp,
# "session_state": session_state,
# },
# overwrite=True,
# )
# with st.status(":red[An error occurred. The site will be reloaded.]"):
# for i in range(5):
# sleep(1)
# st.write(f"Reloading in {5-i} seconds...")
# st.rerun()
with st.sidebar:
st.write("---")
authenticator.logout()

File diff suppressed because it is too large Load Diff

@ -1,13 +1,11 @@
# streamlit_pages.py
import streamlit as st
from colorprinter.print_color import *
from time import sleep
from colorprinter.print_color import *
def Projects():
"""
Function to handle the Projects page.
"""
from _classes import ProjectsPage
from projects_page import ProjectsPage
if 'Projects' not in st.session_state:
st.session_state['Projects'] = {}
projectpage = ProjectsPage(username=st.session_state["username"])
@ -18,6 +16,7 @@ def Bot_Chat():
Function to handle the Chat Bot page.
"""
from _classes import BotChatPage
sleep(0.1)
if 'Bot Chat' not in st.session_state:
st.session_state['Bot Chat'] = {}
chatpage = BotChatPage(username=st.session_state["username"])
@ -27,7 +26,8 @@ def Article_Collections():
"""
Function to handle the Article Collections page.
"""
from _classes import ArticleCollectionsPage
from collections_page import ArticleCollectionsPage
sleep(0.1)
if 'Article Collections' not in st.session_state:
st.session_state['Article Collections'] = {}
@ -41,6 +41,7 @@ def Settings():
"""
from _classes import SettingsPage
settings = SettingsPage(username=st.session_state["username"])
sleep(0.1)
settings.run()
@ -53,4 +54,5 @@ def RSS_Feeds():
st.session_state['RSS Feeds'] = {}
rss_feeds_page = RSSFeedsPage(username=st.session_state["username"])
sleep(0.1)
rss_feeds_page.run()

@ -1,7 +1,7 @@
import os
import urllib
import streamlit as st
from _base_class import BaseClass
from _base_class import StreamlitBaseClass
import feedparser
import requests
from bs4 import BeautifulSoup
@ -11,7 +11,7 @@ from colorprinter.print_color import *
from datetime import datetime, timedelta
class RSSFeedsPage(BaseClass):
class RSSFeedsPage(StreamlitBaseClass):
def __init__(self, username: str):
super().__init__(username=username)
self.page_name = "RSS Feeds"

@ -1,32 +1,38 @@
# Manual smoke test for the Ollama client behind the basic-auth proxy.
# NOTE(review): this span was a corrupted diff merge containing two
# conflicting client setups and two chat calls; reconstructed as the
# coherent post-change script.
import os
import base64
from ollama import Client, ChatResponse

import env_manager
from colorprinter.print_color import *
import httpx

env_manager.set_env()

# Basic-auth credentials for the LLM proxy.
auth = httpx.BasicAuth(
    username='lasse', password=os.getenv("LLM_API_PWD_LASSE")
)

client = Client(
    host="http://localhost:11434",
    headers={
        "X-Chosen-Backend": "backend_ollama"  # Add this header to specify the chosen backend
    },
    auth=auth
)

response = client.chat(
    model=os.getenv("LLM_MODEL"),
    messages=[
        {
            "role": "user",
            "content": "Why is the sky blue?",
        },
    ],
)

# Print the chosen backend from the headers
print("Chosen Backend:", response.headers.get("X-Chosen-Backend"))

# Print the response content
print(response)

@ -0,0 +1,9 @@
from _llm import LLM
# Quick manual check that the LLM client accepts raw image bytes.
llm = LLM()

image = '/home/lasse/sci/test_image.png'
with open(image, 'rb') as image_file:
    image_bytes = image_file.read()
print(type(image_bytes))

response = llm.generate('What is this?', images=[image_bytes])
print(response)

@ -0,0 +1,59 @@
import io
import os
import requests
from pydub import AudioSegment
import streamlit as st
def streamlit_audio(uploaded_file):
    """Streamlit flow: convert an uploaded audio file to mono MP3,
    send it to the transcription service and offer the result for download.

    Args:
        uploaded_file: A Streamlit UploadedFile (or None, in which case
            nothing is rendered).
    """
    if uploaded_file is not None:
        # Read the uploaded file into a BytesIO buffer
        file_extension = os.path.splitext(uploaded_file.name)[1].lower()
        filename = uploaded_file.name
        input_file_buffer = io.BytesIO(uploaded_file.getvalue())
        progress_bar = st.progress(0)
        status_text = st.empty()
        if file_extension in ['.m4a', '.mp3', '.wav', '.flac']:
            # Handle audio files
            audio = AudioSegment.from_file(input_file_buffer, format=file_extension.replace('.', ''))
            audio = audio.set_channels(1)  # Convert to mono
            # Re-encode at a low bitrate to keep the upload small.
            file_buffer = io.BytesIO()
            audio.export(file_buffer, format="mp3", bitrate="64k")
            file_buffer.seek(0)
            progress_bar.progress(50)
            status_text.text("Audio file converted.")
        else:
            st.error("Unsupported file type")
            st.stop()
        # Send the converted audio data to the transcription service
        try:
            response = transcribe(file_buffer, filename)
            response_json = response.json()
            progress_bar.progress(100)
            status_text.text("File uploaded and processed.")
            if response.status_code == 200:
                transcription_content = response_json.get("transcription", "")
                st.subheader("Transcription")
                st.text_area("Transcription Content", transcription_content, height=300)
                # Offer the transcript as a .vtt download named after the upload.
                transcription_filename = os.path.splitext(filename)[0] + '.vtt'
                st.download_button(
                    label="Download Transcription",
                    data=transcription_content,
                    file_name=transcription_filename,
                    mime='text/vtt'
                )
            else:
                st.error("Failed to upload and process the file.")
        except requests.exceptions.Timeout:
            st.error("The request timed out. Please try again later.")
def transcribe(file_buffer, filename, url=None):
    """POST an MP3 buffer to the transcription service.

    Args:
        file_buffer: File-like object positioned at the start of the MP3 data.
        filename: Name reported to the service for the uploaded file.
        url: Optional service endpoint. Defaults to the TRANSCRIBE_URL
            environment variable, falling back to the previously
            hard-coded host for backward compatibility.

    Returns:
        requests.Response: The raw HTTP response from the service.
    """
    if url is None:
        # The endpoint now lives in .env (TRANSCRIBE_URL) instead of
        # being hard-coded; keep the old address as a last resort.
        url = os.getenv("TRANSCRIBE_URL", "http://98.128.172.165:4001/upload")
    # Prepare the files dictionary for the POST request
    files = {'file': (filename, file_buffer, 'audio/mp3')}
    # Send the POST request with the file buffer
    response = requests.post(url, files=files, timeout=3600)
    return response
Loading…
Cancel
Save