You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

873 lines
33 KiB

import re
import os
import streamlit as st
from streamlit.runtime.uploaded_file_manager import UploadedFile
from time import sleep
from datetime import datetime
from PIL import Image
from io import BytesIO
import base64
from article2db import PDFProcessor
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import StreamlitBaseClass
from colorprinter.print_color import *
from prompts import get_note_summary_prompt, get_image_system_prompt
import env_manager
env_manager.set_env()
class ProjectsPage(StreamlitBaseClass):
    """Streamlit page for creating, selecting and managing a user's projects.

    The sidebar holds the "New project" button, the project selector and the
    forms for adding interviews and notes. The main area shows the selected
    project's interviews, notes and collection controls.
    """

    def __init__(self, username: str):
        """Set page defaults, then overlay any attributes saved in session state."""
        super().__init__(username=username)
        self.projects = []
        self.selected_project_name = None
        self.project = self.get_settings("current_project")
        self.page_name = "Projects"
        # Initialize attributes from session state if available
        page_state = st.session_state.get(self.page_name, {})
        for k, v in page_state.items():
            setattr(self, k, v)

    def run(self):
        """Entry point: register the page, load the projects and render them."""
        self.update_current_page(self.page_name)
        self.load_projects()
        self.display_projects()
        # Update session state
        self.update_session_state(self.page_name)

    def load_projects(self):
        """Fetch the user's projects into ``self.projects``."""
        self.projects = self.user_arango.get_projects(username=self.username)

    def display_projects(self):
        """Render the sidebar project selector and, if one is chosen, the project."""
        with st.sidebar:
            self.new_project_button()
            projects = [proj["name"] for proj in self.projects]
            # ``self.project`` starts out as a project name (from settings) but is
            # replaced by a Project object in manage_project(); resolve either
            # form to a name so the default selection stays stable.
            current_name = getattr(self.project, "name", self.project)
            self.selected_project_name = st.selectbox(
                "Select a project to manage",
                options=projects,
                index=projects.index(current_name) if current_name in projects else None,
            )
        if self.selected_project_name:
            # manage_project() builds the Project instance itself, so it is not
            # constructed (and loaded from the database) a second time here.
            self.manage_project()
        # Update session state
        self.update_session_state(self.page_name)

    def new_project_button(self):
        """Show the "New project" button and, once pressed, the creation form."""
        st.session_state.setdefault("new_project", False)
        with st.sidebar:
            if st.button("New project", type="primary"):
                st.session_state["new_project"] = True
            if st.session_state["new_project"]:
                self.create_new_project()
        # Update session state
        self.update_session_state(self.page_name)

    def create_new_project(self):
        """Render the new-project inputs and create the project on submit."""
        new_project_name = st.text_input("Enter the name of the new project")
        new_project_description = st.text_area(
            "Enter the description of the new project"
        )
        if not st.button("Create Project"):
            return
        new_project_name = new_project_name.strip()
        if not new_project_name:
            st.warning("Please enter a name for the new project")
            return
        # Use the API to create a new project
        self.user_arango.create_project(
            {
                "name": new_project_name,
                "description": new_project_description,
                "username": self.username,
                "collections": [],
                "notes": [],
                # 0 is the value hash("") yields and the default Project uses
                # for a project without notes.
                "note_keys_hash": 0,
                "settings": {},
            }
        )
        st.success(f'New project "{new_project_name}" created')
        st.session_state["new_project"] = False
        self.update_settings("current_project", new_project_name)
        # Leave the success message visible briefly before the rerun clears it
        sleep(1)
        st.rerun()

    def show_project_notes(self):
        """Show the notes summary and the individual notes with delete buttons."""
        with st.expander("Show summarised notes"):
            st.markdown(self.project.notes_summary)
        with st.expander("Show project notes"):
            # Use the API to get project notes
            notes = self.user_arango.get_project_notes(
                project_name=self.project.name,
                username=self.username,
            )
            if not notes:
                st.write("No notes in this project.")
                return
            for note in notes:
                st.markdown(f'_{note.get("timestamp", "")}_')
                # Markdown needs two trailing spaces for a hard line break
                st.markdown(note["text"].replace("\n", "  \n"))
                st.button(
                    key=f'delete_note_{note["_id"]}',
                    label=":red[Delete note]",
                    on_click=self.project.delete_note,
                    args=(note["_id"],),
                )
                st.write("---")

    def show_project_interviews(self):
        """List the project's interviews with a preview, download and delete."""
        with st.expander("Show project interviews"):
            # Use the API to create collection if it doesn't exist
            if not self.user_arango.has_collection("interviews"):
                self.user_arango.create_collection("interviews")
            # Use the API to get interviews for this project
            interviews = self.user_arango.execute_aql(
                """
                FOR doc IN interviews
                FILTER doc.project == @project_name
                RETURN doc
                """,
                bind_vars={"project_name": self.project.name},
            )
            interviews_list = list(interviews)
            if not interviews_list:
                st.write("No interviews in this project.")
                return
            for interview in interviews_list:
                st.markdown(f'_{interview.get("timestamp", "")}_')
                # "intervievees" (sic) is the field name used in stored documents
                if interview.get("intervievees"):
                    st.markdown(
                        f"**Interviewees:** {', '.join(interview['intervievees'])}"
                    )
                if interview.get("interviewer"):
                    st.markdown(f"**Interviewer:** {interview['interviewer']}")

                # Preview: at most six lines, joined with Markdown hard breaks
                # (two trailing spaces) so each transcript line is shown on its own.
                transcript_lines = interview["transcript"].split("\n")
                preview = "  \n".join(transcript_lines[:6])
                if len(transcript_lines) > 6:
                    preview += "  \n(...)"
                # Grey out every [timestamp] in a single pass. Replacing match by
                # match would wrap a repeated timestamp more than once.
                preview = re.sub(r"\[(.*?)\]", r":grey[\1]", preview)
                st.markdown(preview)

                c1, c2 = st.columns(2)
                with c1:
                    st.download_button(
                        label="Download Transcript",
                        key=f"download_transcript_{interview['_key']}",
                        data=interview["transcript"],
                        file_name=interview["filename"],
                        mime="text/vtt",
                    )
                with c2:
                    st.button(
                        key=f'delete_interview_{interview["_key"]}',
                        label=":red[Delete interview]",
                        on_click=self.project.delete_interview,
                        args=(interview["_key"],),
                    )
                st.write("---")

    def manage_project(self):
        """Render the full management view for the selected project."""
        self.update_settings("current_project", self.selected_project_name)
        # Initialize the Project instance
        self.project = Project(
            self.username, self.selected_project_name, self.user_arango
        )
        st.write(f"## {self.project.name}")
        self.show_project_interviews()
        self.show_project_notes()
        self.relate_collections()
        self.sidebar_actions()
        self.project.update_notes_hash()
        if st.button(f":red[Remove project *{self.project.name}*]"):
            # Use the API to delete the project
            self.user_arango.delete_project(
                project_name=self.project.name,
                username=self.username,
            )
            self.update_settings("current_project", None)
            st.success(f'Project "{self.project.name}" removed')
            # Same pause as after project creation, so the message can be read
            # before the rerun clears it.
            sleep(1)
            st.rerun()
        # Update session state
        self.update_session_state(self.page_name)

    def relate_collections(self):
        """Let the user relate existing or newly created collections to the project."""
        # Get all collections using the API
        collections = self.user_arango.execute_aql(
            "FOR c IN article_collections RETURN c.name"
        )
        collections_list = list(collections)
        selected_collections = st.multiselect(
            "Relate existing collections", options=collections_list
        )
        if st.button("Relate Collections"):
            self.project.add_collections(selected_collections)
            st.success("Collections related to the project")
        # Update session state
        self.update_session_state(self.page_name)

        new_collection_name = st.text_input(
            "Enter the name of the new collection to create and relate"
        )
        if st.button("Create and Relate Collection") and new_collection_name:
            # Use the API to insert a new collection
            self.user_arango.insert_document(
                collection_name="article_collections",
                document={"name": new_collection_name, "articles": []},
            )
            self.project.add_collection(new_collection_name)
            st.success(
                f'New collection "{new_collection_name}" created and related to the project'
            )
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_actions(self):
        """Render the sidebar sections for adding interviews and notes."""
        self.sidebar_interview()
        self.sidebar_notes()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_notes(self):
        """Render the sidebar controls for adding notes to the project."""
        with st.sidebar:
            st.markdown(f"### Add new notes to {self.project.name}")
            self.upload_notes_form()
            self.add_text_note()
            self.add_wikipedia_data()
        # Update session state
        self.update_session_state(self.page_name)

    def sidebar_interview(self):
        """Render the sidebar controls for adding an interview to the project."""
        with st.sidebar:
            st.markdown(f"### Add new interview to {self.project.name}")
            self.upload_interview_form()
        # Update session state
        self.update_session_state(self.page_name)

    def upload_notes_form(self):
        """Render the upload form for PDF/image notes and process submissions."""
        with st.expander("Upload notes"):
            with st.form("add_notes", clear_on_submit=True):
                files = st.file_uploader(
                    "Upload PDF or image",
                    type=["png", "jpg", "pdf"],
                    accept_multiple_files=True,
                )
                submitted = st.form_submit_button("Upload")
            if submitted:
                self.project.process_uploaded_notes(files)
        # Update session state
        self.update_session_state(self.page_name)

    def upload_interview_form(self):
        """Render the interview upload form and hand submissions to the project."""
        with st.expander("Upload interview"):
            with st.form("add_interview", clear_on_submit=True):
                interview = st.file_uploader("Upload interview audio file or transcript")
                interviewees = st.text_input(
                    "Enter the names of the interviewees, separated by commas"
                )
                interviewer = st.text_input(
                    "Enter the interviewer's name",
                    help="If left blank, the current user will be used",
                )
                date_of_interveiw = st.date_input(
                    "Date of interview", value=None, format="YYYY-MM-DD"
                )
                submitted = st.form_submit_button("Upload")
            # Handled outside the form block: add_interview() may render regular
            # buttons, which Streamlit does not allow inside a form.
            if submitted:
                if interview is None:
                    # The form can be submitted without choosing a file
                    st.warning("Please choose an interview file to upload")
                else:
                    self.project.add_interview(
                        interview, interviewees, interviewer, date_of_interveiw
                    )
        # Update session state
        self.update_session_state(self.page_name)

    def add_text_note(self):
        """Render a text area for a free-text note and add it on request."""
        help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
        note_text = st.text_area("Write or paste anything.", help=help_text)
        if st.button("Add Note"):
            if not note_text.strip():
                # Project.add_note rejects empty text; warn instead of crashing
                st.warning("Please write something before adding the note")
            else:
                self.project.add_note(
                    {
                        "text": note_text,
                        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                    }
                )
                st.success("Note added to the project")
        # Update session state
        self.update_session_state(self.page_name)

    def add_wikipedia_data(self):
        """Fetch a Wikipedia page by URL and add its summary as a note."""
        wiki_url = st.text_input(
            "Paste the address to a Wikipedia page to add its summary as a note",
            placeholder="Paste Wikipedia URL",
        )
        if st.button("Add Wikipedia data"):
            with st.spinner("Fetching Wikipedia data..."):
                wiki_data = self.project.get_wikipedia_data(wiki_url)
            if wiki_data:
                self.project.process_wikipedia_data(wiki_data, wiki_url)
                st.success("Wikipedia data added to notes")
                # Update session state
                self.update_session_state(self.page_name)
                # Rerun only after a successful add; an unconditional rerun here
                # would restart the script on every render.
                st.rerun()
class Project(StreamlitBaseClass):
    """
    A class to represent a project in the Streamlit application.

    Attributes:
    -----------
    username : str
        The username of the project owner.
    user_arango : ArangoDB
        The ArangoDB instance for the user.
    name : str
        The name of the project.
    description : str
        The description of the project.
    collections : list
        A list of collections associated with the project.
    notes : list
        A list of note document ids associated with the project.
    note_keys_hash : int
        A hash value representing the keys of the notes.
    settings : dict
        A dictionary of settings for the project.
    notes_summary : str
        A summary of the notes in the project.
    """

    def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
        """Set defaults and load the stored project document.

        Raises ValueError (from load_project) if the project does not exist.
        """
        super().__init__(username=username)
        self.name = project_name
        self.user_arango = user_arango
        self.description = ""
        self.collections = []
        self.notes = []
        self.note_keys_hash = 0
        self.settings = {}
        self.notes_summary = ""
        self._key = None
        # Initialize attributes from arango doc if available
        self.load_project()

    def load_project(self):
        """Populate this instance from the stored project document."""
        print_blue("Project name:", self.name)
        # Use the API to get project details
        project = self.user_arango.get_project(
            project_name=self.name,
            username=self.username,
        )
        if not project:
            raise ValueError(f"Project '{self.name}' not found.")
        self._key = project["_key"]
        self.name = project.get("name", "")
        self.description = project.get("description", "")
        self.collections = project.get("collections", [])
        self.notes = project.get("notes", [])
        self.note_keys_hash = project.get("note_keys_hash", 0)
        self.settings = project.get("settings", {})
        self.notes_summary = project.get("notes_summary", "")

    def update_project(self):
        """Write the current attribute values back to the project document."""
        updated_doc = {
            "_id": f"projects/{self._key}",
            "_key": self._key,
            "name": self.name,
            "description": self.description,
            "collections": self.collections,
            "notes": self.notes,
            "note_keys_hash": self.note_keys_hash,
            "settings": self.settings,
            "notes_summary": self.notes_summary,
            "username": self.username,
        }
        self.user_arango.update_project(updated_doc)
        self.update_session_state()

    def add_collections(self, collections):
        """Relate several collections to the project, dropping duplicates."""
        # dict.fromkeys de-duplicates while keeping insertion order, so the
        # stored list does not get reshuffled the way a set round trip can.
        self.collections = list(dict.fromkeys([*self.collections, *collections]))
        self.update_project()

    def add_collection(self, collection_name):
        """Relate a single collection to the project."""
        self.add_collections([collection_name])

    def add_note(self, note: dict):
        """Store a note and link it to the project.

        ``note`` must contain non-blank ``"text"``; ``"timestamp"`` is filled in
        when missing. The dict is modified in place (text stripped, project and
        username added).

        Raises:
            ValueError: if the note has no text once whitespace is stripped.
        """
        # Strip before validating so whitespace-only text is rejected as well.
        # A real exception is raised because ``assert`` disappears under -O.
        text = (note.get("text") or "").strip()
        if not text:
            raise ValueError("Note text cannot be empty")
        note["text"] = text
        if "timestamp" not in note:
            note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M")
        # Use the API to add a note to the project
        note["project"] = self.name
        note["username"] = self.username
        note_doc = self.user_arango.add_note_to_project(note)
        if note_doc["_id"] not in self.notes:
            self.notes.append(note_doc["_id"])
        self.update_project()

    def add_interview(
        self,
        interview: UploadedFile,
        intervievees: str,
        interviewer: str,
        date_of_interveiw: datetime.date = None,
    ):
        """Add an uploaded interview, dispatching on the file's MIME type.

        * WAV/MPEG audio is transcribed and shown as a preview with
          "Add to project" and download buttons.
        * PDF files are passed to PDFProcessor.
        * JSON / plain-text files are parsed as a JSON transcript and stored.

        ``intervievees`` is a comma-separated string of names;
        ``date_of_interveiw`` is the value of the date input (may be None).
        """
        # TODO Implement this method
        if interview is None:
            st.warning("No interview file was uploaded")
            return

        if interview.type in ["audio/x-wav", "audio/mpeg"]:
            transcription = self.transcribe(interview)
            if not transcription:
                # transcribe() has already reported what went wrong
                return
            # Two trailing spaces give a Markdown hard line break
            transcription_preview = (
                "  \n".join(transcription.split("\n")[:4]) + "  \n(...)"
            )
            st.markdown(transcription_preview)
            transcription_filename = os.path.splitext(interview.name)[0] + ".vtt"
            c1, c2 = st.columns(2)
            with c1:
                st.button(
                    "Add to project",
                    on_click=self.add_interview_transcript,
                    args=(
                        transcription,
                        transcription_filename,
                        intervievees,
                        interviewer,
                        date_of_interveiw,
                    ),
                )
            with c2:
                st.download_button(
                    label="Download Transcription",
                    data=transcription,
                    file_name=transcription_filename,
                    mime="text/vtt",
                )
            # No rerun in this branch: it would discard the preview and the
            # buttons before the user could act on them.
            return

        if interview.type in ["application/pdf"]:
            PDFProcessor(
                pdf_file=interview,
                is_sci=False,
                document_type="interview",
                is_image=False,
            )
        elif interview.type in ["application/json", "text/plain"]:
            import json

            print_purple("JSON file processing")
            try:
                interview_content = interview.getvalue().decode("utf-8")
                print('Content:', interview_content)
                interview_json = json.loads(interview_content)
            except (UnicodeDecodeError, json.JSONDecodeError) as e:
                st.error(f"Could not read the file as a JSON transcript: {e}")
                return
            formated_transcription = self.format_json_transcription(interview_json)
            # Pass on what the user entered in the form instead of discarding it
            self.add_interview_transcript(
                formated_transcription,
                interview.name,
                intervievees=intervievees,
                interviewer=interviewer,
                date_of_interveiw=date_of_interveiw,
            )
        else:
            print(interview.type)
            st.error("Unsupported file type")
            st.stop()
        st.rerun()

    def add_interview_transcript(
        self,
        transcript,
        filename,
        intervievees: str = None,
        interviewer: str = None,
        date_of_interveiw: datetime.date = None,
    ):
        """Chunk a transcript and store it as an interview document.

        ``intervievees`` is split on commas into a list of names. An empty
        ``interviewer`` falls back to the current username. ``date_of_interveiw``
        may be a date object or a "YYYY-MM-DD" string.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        _key = fix_key(f"{filename}_{timestamp}")
        if intervievees:
            intervievees = [
                name.strip() for name in intervievees.split(",") if name.strip()
            ]
        if not interviewer:
            interviewer = self.username
        # Ensure interviews collection exists using the API
        if not self.user_arango.has_collection("interviews"):
            self.user_arango.create_collection("interviews")
        if isinstance(date_of_interveiw, str):
            date_of_interveiw = datetime.strptime(date_of_interveiw, "%Y-%m-%d")

        from article2db import Document

        document = Document(
            text=transcript,
            is_sci=False,
            _key=_key,
            filename=filename,
            arango_db_name=self.username,
            username=self.username,
            arango_collection="interviews",
        )
        document.make_chunks(len_chunks=600)
        # Use the API to insert the interview document.
        # The misspelled field names are the ones existing documents use.
        self.user_arango.insert_document(
            collection_name="interviews",
            document={
                "_key": _key,
                "transcript": transcript,
                "project": self.name,
                "filename": filename,
                "timestamp": timestamp,
                "intervievees": intervievees,
                "interviewer": interviewer,
                "date_of_interveiw": date_of_interveiw.isoformat() if date_of_interveiw else None,
                "chunks": document.chunks,
            },
            overwrite=True,
        )
        document.make_summary_in_background()

    def transcribe(self, uploaded_file: UploadedFile):
        """Convert an audio upload to mono MP3 and send it for transcription.

        The service URL is read from the TRANSCRIBE_URL environment variable.

        Returns:
            The formatted transcription, or None when the request fails (an
            error message is shown in the UI in that case).
        """
        from pydub import AudioSegment
        import requests
        import io

        file_extension = os.path.splitext(uploaded_file.name)[1].lower()
        filename = uploaded_file.name
        input_file_buffer = io.BytesIO(uploaded_file.getvalue())
        progress_bar = st.progress(0)
        status_text = st.empty()
        if file_extension not in [".m4a", ".mp3", ".wav", ".flac"]:
            st.error("Unsupported file type")
            st.stop()

        # Handle audio files
        audio = AudioSegment.from_file(
            input_file_buffer, format=file_extension.replace(".", "")
        )
        audio = audio.set_channels(1)  # Convert to mono
        file_buffer = io.BytesIO()
        audio.export(file_buffer, format="mp3", bitrate="64k")
        file_buffer.seek(0)
        progress_bar.progress(50)
        status_text.text("Audio file converted.")

        # os.getenv returns None for a missing variable rather than raising, so
        # the .env fallback has to be triggered by the missing value.
        url = os.getenv("TRANSCRIBE_URL")
        if not url:
            try:
                import dotenv

                dotenv.load_dotenv()
            except ImportError:
                # python-dotenv is only an optional fallback
                pass
            url = os.getenv("TRANSCRIBE_URL")
        if not url:
            st.error("TRANSCRIBE_URL is not configured.")
            return None

        # Send the converted audio data to the transcription service
        files = {"file": (filename, file_buffer, "audio/mp3")}
        try:
            response = requests.post(url, files=files, timeout=3600)
        except requests.exceptions.Timeout:
            st.error("The request timed out. Please try again later.")
            return None
        except requests.exceptions.RequestException as e:
            st.error(f"Failed to upload and process the file: {e}")
            return None

        # Check the status before decoding: an error response need not be JSON
        if response.status_code != 200:
            st.error("Failed to upload and process the file.")
            return None
        try:
            response_json = response.json()
        except ValueError:
            st.error("Failed to upload and process the file.")
            return None
        progress_bar.progress(100)
        status_text.text("File uploaded and processed.")
        transcription_content = response_json.get("transcription", "")
        return self.format_transcription(transcription_content)

    def format_transcription(self, transcription: str):
        """Turn cue-style text ("start --> end" lines) into "[start] text" lines.

        Only the first text line after each cue line is kept. The start time is
        cut at the fractional-seconds separator.
        """
        transcript = []
        timestamp = None
        for line in transcription.split("\n"):
            if "-->" in line:
                start = line.split("-->", 1)[0].strip()
                # Cut at "." (or ","). Slicing with str.find would silently drop
                # the last character when no separator is present (find == -1).
                timestamp = re.split(r"[.,]", start, maxsplit=1)[0]
            elif timestamp:
                transcript.append(f"[{timestamp}] {line}")
                timestamp = None
        return "\n".join(transcript)

    def format_json_transcription(self, transcription: dict):
        """Format a parsed JSON transcript as "[time] speaker: text" lines.

        Accepts either a list of entries with "timestamp"/"text" keys or a dict
        whose "lines" entries have "startTime"/"text" keys; "speaker" is
        optional in both. Any other input yields an empty string.
        """
        if isinstance(transcription, list):
            # For the JSON format in MacWhisper
            entries = [
                (entry["timestamp"], entry.get("speaker"), entry["text"])
                for entry in transcription
            ]
        elif isinstance(transcription, dict) and "lines" in transcription:
            # For the DOT format in MacWhisper
            entries = [
                (entry["startTime"], entry.get("speaker"), entry["text"])
                for entry in transcription["lines"]
            ]
        else:
            entries = []
        return "\n".join(
            f"[{stamp}] {speaker}: {text}" if speaker else f"[{stamp}] {text}"
            for stamp, speaker, text in entries
        )

    def delete_note(self, note_id):
        """Unlink a note from the project and delete its document."""
        if note_id in self.notes:
            self.notes.remove(note_id)
        # Delete the note document using the API; the key is the part of the
        # id after "notes/" (a bare key is passed through unchanged).
        self.user_arango.delete_document(
            collection_name="notes",
            document_key=note_id.split("/", 1)[-1],
        )
        self.update_project()

    def delete_interview(self, interview_id):
        """Delete an interview document by key."""
        self.user_arango.delete_document(
            collection_name="interviews",
            document_key=interview_id,
        )

    def update_notes_hash(self):
        """Regenerate the notes summary when the set of notes has changed."""
        current_hash = self.make_project_notes_hash()
        if current_hash != self.note_keys_hash:
            self.note_keys_hash = current_hash
            with st.spinner("Summarizing notes for chatbot..."):
                self.create_notes_summary()
            self.update_project()

    def make_project_notes_hash(self):
        """Return an integer fingerprint of the project's note ids.

        The built-in hash() of a str is salted per interpreter process, so a
        value persisted by one process would not match after a restart and the
        summary would be regenerated needlessly. A SHA-256 based value is the
        same in every process.
        """
        if not self.notes:
            # Same value as hash("") and as the default for a project without notes
            return 0
        import hashlib

        digest = hashlib.sha256("".join(self.notes).encode("utf-8")).hexdigest()
        # 12 hex digits = 48 bits: small enough to stay exact even if the
        # number passes through a double-precision float on its way to storage.
        return int(digest[:12], 16)

    def create_notes_summary(self):
        """Summarise all project notes with the small LLM into ``notes_summary``."""
        # Get note texts using the API
        notes_list = []
        for note_id in self.notes:
            note = self.user_arango.get_document(note_id)
            if note and "text" in note:
                notes_list.append(note["text"])
        if not notes_list:
            # Nothing to summarise; skip the LLM call
            self.notes_summary = ""
            self.update_session_state()
            return
        notes_string = "\n---\n".join(notes_list)
        llm = LLM(model="small")
        query = get_note_summary_prompt(self, notes_string)
        summary = llm.generate(query).content
        print_purple("New summary of notes:", summary)
        self.notes_summary = summary
        self.update_session_state()

    def analyze_image(self, image_base64, text=None):
        """Ask the LLM to describe an image and return the description.

        ``text`` is optional text already extracted from the image; when given
        it is included in the prompt.
        """
        llm = LLM(system_message=get_image_system_prompt(self))
        prompt = (
            f'Analyze the image. The text found in it read: "{text}"'
            if text
            else "Analyze the image."
        )
        print_blue(type(image_base64))
        description = llm.generate(query=prompt, images=[image_base64], stream=False)
        print_green("Image description:", description)
        # The caller embeds the result in the note text, so it must be returned.
        # Other generate() call sites in this class read ``.content``; fall back
        # to the raw value in case plain text comes back. TODO confirm
        return getattr(description, "content", description)

    def process_uploaded_notes(self, files):
        """OCR each uploaded file, caption it with the LLM and store it as a note."""
        with st.spinner("Processing files..."):
            for file in files:
                st.write("Processing...")
                filename = fix_key(file.name)
                # Take the bytes before file2img() consumes the stream; reading
                # again afterwards would return nothing.
                raw_bytes = file.getvalue()
                # NOTE(review): the upload form also accepts PDFs, but every file
                # goes through Image.open() here, which presumably cannot read a
                # PDF. Confirm and add a PDF path if needed.
                file.seek(0)
                image_file = self.file2img(file)
                pdf_file = self.convert_image_to_pdf(image_file)
                pdf = PDFProcessor(
                    pdf_file=pdf_file,
                    is_sci=False,
                    document_type="notes",
                    is_image=True,
                    process=False,
                )
                text = pdf.process_document()
                base64_str = base64.b64encode(raw_bytes)
                image_caption = self.analyze_image(base64_str, text=text)
                self.add_note(
                    {
                        "_id": f"notes/{filename}",
                        "text": f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}",
                    }
                )
        st.success("Done!")
        sleep(1.5)
        self.update_session_state()
        st.rerun()

    def file2img(self, file):
        """Open an uploaded file as a PIL image.

        Raises:
            ValueError: if the file has no content.
        """
        img_bytes = file.read()
        if not img_bytes:
            raise ValueError("Uploaded file is empty.")
        return Image.open(BytesIO(img_bytes))

    def convert_image_to_pdf(self, img):
        """Run Tesseract on an image and return the result as a named PDF buffer."""
        import pytesseract

        pdf_bytes = pytesseract.image_to_pdf_or_hocr(img)
        pdf_file = BytesIO(pdf_bytes)
        pdf_file.name = (
            "converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf"
        )
        return pdf_file

    def get_wikipedia_data(self, page_url: str) -> dict:
        """Fetch title, summary, content, URL and references for a Wikipedia URL.

        Returns None, after showing a warning or error, when the URL has no
        ``/wiki/<title>`` part or the page cannot be fetched.
        """
        import wikipedia
        from urllib.parse import unquote, urlparse

        parsed_url = urlparse(page_url)
        page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path)
        # Titles are percent-encoded in URLs (e.g. "%C3%85land"); decode them
        page_name = unquote(page_name_match.group(0)) if page_name_match else ""
        if not page_name:
            st.warning("Invalid Wikipedia URL")
            return None

        # Query the language edition the URL points at ("sv.wikipedia.org" ->
        # "sv"). Set on every call so an earlier choice does not leak through.
        hostname = (parsed_url.hostname or "").lower()
        lang = hostname.split(".")[0] if hostname.endswith(".wikipedia.org") else "en"
        if lang in ("www", "m"):
            lang = "en"

        try:
            wikipedia.set_lang(lang)
            page = wikipedia.page(page_name, auto_suggest=False)
            data = {
                "title": page.title,
                "summary": page.summary,
                "content": page.content,
                "url": page.url,
                "references": page.references,
            }
            return data
        except Exception as e:
            st.error(f"Error fetching Wikipedia data: {e}")
            return None

    def process_wikipedia_data(self, wiki_data, wiki_url):
        """Summarise a fetched Wikipedia page into a note and collect its DOIs.

        ``wiki_data`` is the dict returned by get_wikipedia_data(). It is
        modified in place: "summary" and "content" are removed and "text" is
        added.
        """
        llm = LLM(
            system_message="You are an assistant summarisen wikipedia data. Answer ONLY with the summary, nothing else!",
            model="small",
        )
        if wiki_data.get("summary"):
            query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.'''
            summary = llm.generate(query).content
            wiki_data["text"] = (
                f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}"
            )
        wiki_data.pop("summary", None)
        wiki_data.pop("content", None)

        if wiki_data.get("text"):
            # add_note() stores the document (adding project, username and
            # timestamp) and links it to the project. Inserting it separately
            # first would leave a second, unlinked copy in the notes collection.
            self.add_note(wiki_data)
        else:
            st.warning("The Wikipedia page has no summary to add as a note")

        processor = PDFProcessor(process=False)
        references = wiki_data.get("references", [])
        print_rainbow(references)
        dois = []
        for ref in references:
            doi = processor.extract_doi(ref)
            if doi:
                print_blue("Found DOI:", doi)
                dois.append(doi)
        if dois:
            current_collection = st.session_state.get("settings", {}).get(
                "current_collection"
            )
            st.markdown(
                f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
            )
            # NOTE(review): this button is only rendered while the caller's
            # "Add Wikipedia data" button reports a click, and the caller reruns
            # right after this method returns, so the click presumably never
            # reaches this branch. Confirm and move to session state.
            if st.button("Add DOIs"):
                self.process_dois(current_collection, dois=dois)
        self.update_session_state()

    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        """Download and import articles by DOI and add them to a collection.

        DOIs are taken from ``dois`` or, when that is empty, extracted from
        ``text``. DOIs that could not be downloaded are recorded with their URL
        in ``st.session_state["not_downloaded"]``.
        """
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}
        # ``dois`` is None when neither argument produced any; iterate nothing
        for doi in dois or []:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url
            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        self.update_session_state()

    def articles2collection(self, collection, db, _id):
        """Add an article id to a named article collection, creating it if needed.

        ``db`` is accepted for interface compatibility but is not used; the
        "base" database is always the one written to.
        """
        # Use the base/admin ArangoDB for general operations like adding to collections
        base_arango = ArangoDB(db_name="base")
        # Get the collection
        cursor = base_arango.execute_aql(
            "FOR c IN article_collections FILTER c.name == @name RETURN c",
            bind_vars={"name": collection},
        )
        collection_doc = next(cursor, None)
        if collection_doc is None:
            # Collection doesn't exist, create it
            base_arango.insert_document(
                collection_name="article_collections",
                document={
                    "name": collection,
                    "articles": [_id],
                },
            )
        elif _id not in collection_doc["articles"]:
            collection_doc["articles"].append(_id)
            # Update the collection
            base_arango.update_document(collection_doc)