You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
873 lines
33 KiB
873 lines
33 KiB
import base64
import hashlib
import json
import os
import re
from datetime import date, datetime
from io import BytesIO
from time import sleep

import streamlit as st
from PIL import Image
from streamlit.runtime.uploaded_file_manager import UploadedFile

from article2db import PDFProcessor
from utils import fix_key
from _arango import ArangoDB
from _llm import LLM
from _base_class import StreamlitBaseClass
from colorprinter.print_color import *
from prompts import get_note_summary_prompt, get_image_system_prompt
import env_manager
|
|
|
# Configure the process environment once at import time, before any code in
# this module reads os.getenv (Project.transcribe reads TRANSCRIBE_URL).
# NOTE(review): set_env() is project-local; presumably it loads variables from
# a .env/secrets source — confirm in env_manager.
env_manager.set_env()
|
|
|
class ProjectsPage(StreamlitBaseClass):
    """Streamlit page for creating, selecting and managing a user's projects.

    Attributes are restored from ``st.session_state[self.page_name]`` when the
    page is constructed; ``update_session_state(self.page_name)`` presumably
    writes them back so they survive Streamlit reruns (defined in the base
    class — confirm there).
    """

    def __init__(self, username: str):
        super().__init__(username=username)
        self.projects = []
        self.selected_project_name = None
        # Initially the *name* of the last-used project (from user settings);
        # manage_project() later replaces it with a Project instance.
        self.project = self.get_settings("current_project")
        self.page_name = "Projects"

        # Initialize attributes from session state if available
        page_state = st.session_state.get(self.page_name, {})
        for k, v in page_state.items():
            setattr(self, k, v)

    def run(self):
        """Render the page: sidebar project picker plus the selected project's tools."""
        self.update_current_page(self.page_name)
        self.load_projects()
        self.display_projects()
        self.update_session_state(self.page_name)

    def load_projects(self):
        """Fetch the user's project documents (dicts with at least a "name" key)."""
        self.projects = self.user_arango.get_projects(username=self.username) or []

    def _current_project_name(self):
        """Return the current project's name.

        ``self.project`` holds either a plain name (from settings) or a Project
        instance restored from session state; an instance would never match
        the list of project names in the selectbox.
        """
        return getattr(self.project, "name", self.project)

    def display_projects(self):
        """Show the project picker in the sidebar and manage the chosen project."""
        with st.sidebar:
            self.new_project_button()
            project_names = [proj["name"] for proj in self.projects]
            current_name = self._current_project_name()
            self.selected_project_name = st.selectbox(
                "Select a project to manage",
                options=project_names,
                index=(
                    project_names.index(current_name)
                    if current_name in project_names
                    else None
                ),
            )
        if self.selected_project_name:
            # manage_project() builds the Project instance itself; building it
            # here as well would load the project from the database twice.
            self.manage_project()
        self.update_session_state(self.page_name)

    def new_project_button(self):
        """Sidebar button that reveals the new-project form until a project is created."""
        st.session_state.setdefault("new_project", False)
        with st.sidebar:
            if st.button("New project", type="primary"):
                st.session_state["new_project"] = True
            if st.session_state["new_project"]:
                self.create_new_project()
        self.update_session_state(self.page_name)

    def create_new_project(self):
        """Form for creating a project; validates the name and selects the new project."""
        new_project_name = st.text_input("Enter the name of the new project")
        new_project_description = st.text_area(
            "Enter the description of the new project"
        )
        if not st.button("Create Project"):
            return

        new_project_name = new_project_name.strip()
        if not new_project_name:
            st.warning("Please enter a name for the new project.")
            return
        if any(proj.get("name") == new_project_name for proj in self.projects):
            st.error(f'A project named "{new_project_name}" already exists')
            return

        self.user_arango.create_project(
            {
                "name": new_project_name,
                "description": new_project_description,
                "username": self.username,
                "collections": [],
                "notes": [],
                # Deterministic digest of the empty note list. The built-in
                # hash() is salted per process, so a persisted hash("") would
                # not match after a restart.
                "note_keys_hash": int(hashlib.sha256(b"").hexdigest()[:12], 16),
                "settings": {},
            }
        )
        st.success(f'New project "{new_project_name}" created')
        st.session_state["new_project"] = False
        self.update_settings("current_project", new_project_name)
        sleep(1)  # let the success message be seen before the rerun clears it
        st.rerun()

    def show_project_notes(self):
        """Show the notes summary and every note of the project, with delete buttons."""
        with st.expander("Show summarised notes"):
            st.markdown(self.project.notes_summary)

        with st.expander("Show project notes"):
            notes = self.user_arango.get_project_notes(
                project_name=self.project.name, username=self.username
            )
            if not notes:
                st.write("No notes in this project.")
                return
            for note in notes:
                st.markdown(f'_{note.get("timestamp", "")}_')
                # Markdown needs two trailing spaces for a hard line break.
                st.markdown(note.get("text", "").replace("\n", "  \n"))
                st.button(
                    key=f'delete_note_{note["_id"]}',
                    label=":red[Delete note]",
                    on_click=self.project.delete_note,
                    args=(note["_id"],),
                )
                st.write("---")

    def show_project_interviews(self):
        """List the project's interviews with preview, download and delete actions."""
        with st.expander("Show project interviews"):
            if not self.user_arango.has_collection("interviews"):
                self.user_arango.create_collection("interviews")

            interviews = list(
                self.user_arango.execute_aql(
                    """
                    FOR doc IN interviews
                    FILTER doc.project == @project_name
                    RETURN doc
                    """,
                    bind_vars={"project_name": self.project.name},
                )
            )
            if not interviews:
                st.write("No interviews in this project.")
                return
            for interview in interviews:
                self._show_interview(interview)

    def _show_interview(self, interview: dict):
        """Render one interview document: metadata, a 6-line preview and buttons."""
        st.markdown(f'_{interview.get("timestamp", "")}_')
        # "intervievees" (sic) is the field name stored in the database.
        if interview.get("intervievees"):
            st.markdown(f"**Interviewees:** {', '.join(interview['intervievees'])}")
        if interview.get("interviewer"):
            st.markdown(f"**Interviewer:** {interview['interviewer']}")

        transcript = interview.get("transcript", "")
        lines = transcript.split("\n")
        preview = "  \n".join(lines[:6])
        if len(lines) > 6:
            preview += "  \n(...)"
        # Grey out every "[...]" timestamp in one pass; a per-match
        # str.replace would re-prefix timestamps that occur more than once.
        preview = re.sub(r"\[(.*?)\]", r":grey[\1]", preview)
        st.markdown(preview)

        c1, c2 = st.columns(2)
        with c1:
            st.download_button(
                label="Download Transcript",
                key=f"download_transcript_{interview['_key']}",
                data=transcript,
                file_name=interview.get("filename") or f"{interview['_key']}.vtt",
                mime="text/vtt",
            )
        with c2:
            st.button(
                key=f'delete_interview_{interview["_key"]}',
                label=":red[Delete interview]",
                on_click=self.project.delete_interview,
                args=(interview["_key"],),
            )
        st.write("---")

    def manage_project(self):
        """Main-area view of the selected project plus its sidebar actions."""
        self.update_settings("current_project", self.selected_project_name)
        try:
            self.project = Project(
                self.username, self.selected_project_name, self.user_arango
            )
        except ValueError as err:
            # Project.load_project raises ValueError when the project is gone.
            st.error(str(err))
            self.update_settings("current_project", None)
            return

        st.write(f"## {self.project.name}")
        self.show_project_interviews()
        self.show_project_notes()
        self.relate_collections()
        self.sidebar_actions()
        self.project.update_notes_hash()

        if st.button(f":red[Remove project *{self.project.name}*]"):
            self.user_arango.delete_project(
                project_name=self.project.name, username=self.username
            )
            self.update_settings("current_project", None)
            st.success(f'Project "{self.project.name}" removed')
            sleep(1)  # let the success message be seen before the rerun clears it
            st.rerun()
        self.update_session_state(self.page_name)

    def relate_collections(self):
        """Relate existing article collections to the project, or create and relate one."""
        collection_names = list(
            self.user_arango.execute_aql("FOR c IN article_collections RETURN c.name")
        )
        selected_collections = st.multiselect(
            "Relate existing collections", options=collection_names
        )
        if st.button("Relate Collections"):
            if selected_collections:
                self.project.add_collections(selected_collections)
                st.success("Collections related to the project")
            else:
                st.warning("Select at least one collection to relate.")

        new_collection_name = st.text_input(
            "Enter the name of the new collection to create and relate"
        ).strip()
        if st.button("Create and Relate Collection"):
            if not new_collection_name:
                st.warning("Please enter a name for the new collection.")
            elif new_collection_name in collection_names:
                # Don't create a second collection document with the same name.
                self.project.add_collection(new_collection_name)
                st.success(
                    f'Existing collection "{new_collection_name}" related to the project'
                )
            else:
                self.user_arango.insert_document(
                    collection_name="article_collections",
                    document={"name": new_collection_name, "articles": []},
                )
                self.project.add_collection(new_collection_name)
                st.success(
                    f'New collection "{new_collection_name}" created and related to the project'
                )
        self.update_session_state(self.page_name)

    def sidebar_actions(self):
        """Render the sidebar sections for adding interviews and notes."""
        self.sidebar_interview()
        self.sidebar_notes()
        self.update_session_state(self.page_name)

    def sidebar_notes(self):
        """Sidebar section with every way of adding notes to the project."""
        with st.sidebar:
            st.markdown(f"### Add new notes to {self.project.name}")
            self.upload_notes_form()
            self.add_text_note()
            self.add_wikipedia_data()
        self.update_session_state(self.page_name)

    def sidebar_interview(self):
        """Sidebar section for uploading an interview."""
        with st.sidebar:
            st.markdown(f"### Add new interview to {self.project.name}")
            self.upload_interview_form()
        self.update_session_state(self.page_name)

    def upload_notes_form(self):
        """Form for uploading PDF/image notes."""
        with st.expander("Upload notes"):
            with st.form("add_notes", clear_on_submit=True):
                files = st.file_uploader(
                    "Upload PDF or image",
                    type=["png", "jpg", "pdf"],
                    accept_multiple_files=True,
                )
                submitted = st.form_submit_button("Upload")
            if submitted:
                if files:
                    self.project.process_uploaded_notes(files)
                else:
                    st.warning("Please choose at least one file to upload.")
        self.update_session_state(self.page_name)

    def upload_interview_form(self):
        """Form for uploading an interview recording or transcript with metadata."""
        with st.expander("Upload interview"):
            with st.form("add_interview", clear_on_submit=True):
                interview = st.file_uploader("Upload interview audio file or transcript")
                interviewees = st.text_input(
                    "Enter the names of the interviewees, separated by commas"
                )
                interviewer = st.text_input(
                    "Enter the interviewer's name",
                    help="If left blank, the current user will be used",
                )
                interview_date = st.date_input(
                    "Date of interview", value=None, format="YYYY-MM-DD"
                )
                submitted = st.form_submit_button("Upload")
            # Handled outside the form: add_interview() renders regular buttons
            # and download buttons, which Streamlit does not allow in st.form.
            if submitted:
                if interview is None:
                    st.warning("Please choose an interview file to upload.")
                else:
                    self.project.add_interview(
                        interview, interviewees, interviewer, interview_date
                    )
        self.update_session_state(self.page_name)

    def add_text_note(self):
        """Free-text note input; empty notes are rejected with a warning."""
        help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies."
        note_text = st.text_area("Write or paste anything.", help=help_text)
        if st.button("Add Note"):
            if not note_text.strip():
                # Project.add_note rejects empty text; warn instead of crashing.
                st.warning("The note is empty.")
            else:
                self.project.add_note(
                    {
                        "text": note_text,
                        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"),
                    }
                )
                st.success("Note added to the project")
        self.update_session_state(self.page_name)

    def add_wikipedia_data(self):
        """Fetch a Wikipedia page, summarise it and store it as a project note."""
        wiki_url = st.text_input(
            "Paste the address to a Wikipedia page to add its summary as a note",
            placeholder="Paste Wikipedia URL",
        ).strip()
        if not st.button("Add Wikipedia data"):
            return
        with st.spinner("Fetching Wikipedia data..."):
            wiki_data = self.project.get_wikipedia_data(wiki_url)
            if not wiki_data:
                # get_wikipedia_data() already showed a warning/error; a rerun
                # here would wipe it before the user could read it.
                return
            self.project.process_wikipedia_data(wiki_data, wiki_url)
        st.success("Wikipedia data added to notes")
        self.update_session_state(self.page_name)
        sleep(1)  # let the success message be seen before the rerun clears it
        st.rerun()
|
|
|
|
|
class Project(StreamlitBaseClass):
    """
    A class to represent a project in the Streamlit application.

    Attributes:
    -----------
    username : str
        The username of the project owner.
    user_arango : ArangoDB
        The ArangoDB instance for the user.
    name : str
        The name of the project.
    description : str
        The description of the project.
    collections : list
        Names of the article collections related to the project.
    notes : list
        ``_id`` strings of the note documents linked to the project.
    note_keys_hash : int
        Deterministic fingerprint of ``notes``; when it changes, the notes
        summary is regenerated.
    settings : dict
        A dictionary of settings for the project.
    notes_summary : str
        LLM-generated summary of the notes in the project.
    _key : str | None
        Key of the project document in ArangoDB.
    """

    def __init__(self, username: str, project_name: str, user_arango: ArangoDB):
        super().__init__(username=username)
        self.name = project_name
        self.user_arango = user_arango
        self.description = ""
        self.collections = []
        self.notes = []
        self.note_keys_hash = 0
        self.settings = {}
        self.notes_summary = ""
        self._key = None

        # Initialize attributes from arango doc if available
        self.load_project()

    # ------------------------------------------------------------------ #
    # Persistence
    # ------------------------------------------------------------------ #

    def load_project(self):
        """Load the project document into attributes.

        Raises:
            ValueError: if no project with this name exists for the user.
        """
        print_blue("Project name:", self.name)
        project = self.user_arango.get_project(
            project_name=self.name, username=self.username
        )
        if not project:
            raise ValueError(f"Project '{self.name}' not found.")

        self._key = project["_key"]
        self.name = project.get("name", "")
        self.description = project.get("description", "")
        self.collections = project.get("collections", [])
        self.notes = project.get("notes", [])
        self.note_keys_hash = project.get("note_keys_hash", 0)
        self.settings = project.get("settings", {})
        self.notes_summary = project.get("notes_summary", "")

    def update_project(self):
        """Write all project attributes back to the project document."""
        updated_doc = {
            "_id": f"projects/{self._key}",
            "_key": self._key,
            "name": self.name,
            "description": self.description,
            "collections": self.collections,
            "notes": self.notes,
            "note_keys_hash": self.note_keys_hash,
            "settings": self.settings,
            "notes_summary": self.notes_summary,
            "username": self.username,
        }
        self.user_arango.update_project(updated_doc)
        self.update_session_state()

    # ------------------------------------------------------------------ #
    # Collections and notes
    # ------------------------------------------------------------------ #

    def add_collections(self, collections):
        """Relate several collections; duplicates are dropped, first-seen order kept."""
        # dict.fromkeys de-duplicates deterministically, unlike set().
        self.collections = list(dict.fromkeys([*self.collections, *collections]))
        self.update_project()

    def add_collection(self, collection_name):
        """Relate a single collection to the project."""
        self.add_collections([collection_name])

    def add_note(self, note: dict):
        """Store ``note`` (a dict with a "text" key) and link it to the project.

        Fills in "timestamp" if missing plus "project"/"username".

        Raises:
            ValueError: if the note text is empty or whitespace only.
        """
        text = (note.get("text") or "").strip()
        if not text:
            # Explicit check instead of assert, which is stripped under -O.
            raise ValueError("Note text cannot be empty")
        note["text"] = text
        note.setdefault("timestamp", datetime.now().strftime("%Y-%m-%d %H:%M"))
        note["project"] = self.name
        note["username"] = self.username

        note_doc = self.user_arango.add_note_to_project(note)
        if note_doc["_id"] not in self.notes:
            self.notes.append(note_doc["_id"])
        self.update_project()

    def delete_note(self, note_id):
        """Delete the note document ``note_id`` and unlink it from the project.

        The document is deleted even when the id is not in ``self.notes``: the
        UI lists notes via get_project_notes(), which may return notes that
        were never linked here, and their delete button should still work.
        """
        collection_name, _, document_key = note_id.partition("/")
        if not document_key:  # bare key without a "collection/" prefix
            collection_name, document_key = "notes", note_id
        if note_id in self.notes:
            self.notes.remove(note_id)
        self.user_arango.delete_document(
            collection_name=collection_name, document_key=document_key
        )
        self.update_project()

    def update_notes_hash(self):
        """Regenerate and persist the notes summary if the linked notes changed."""
        current_hash = self.make_project_notes_hash()
        if current_hash != self.note_keys_hash:
            self.note_keys_hash = current_hash
            with st.spinner("Summarizing notes for chatbot..."):
                self.create_notes_summary()
            self.update_project()

    def make_project_notes_hash(self):
        """Return a deterministic fingerprint of the linked note ids.

        SHA-256 is used instead of the built-in hash(): str hashes are salted
        per process, so a persisted value would never match after a restart
        and the notes would be re-summarised every time. The digest is
        truncated to 48 bits so it stays an exact JSON number.
        """
        digest = hashlib.sha256("".join(self.notes).encode("utf-8")).hexdigest()
        return int(digest[:12], 16)

    def create_notes_summary(self):
        """Summarise the text of all linked notes with the LLM into ``notes_summary``."""
        notes_list = []
        for note_id in self.notes:
            note = self.user_arango.get_document(note_id)
            if note and "text" in note:
                notes_list.append(note["text"])

        if not notes_list:
            # Nothing to summarise; skip the LLM call.
            self.notes_summary = ""
            self.update_session_state()
            return

        notes_string = "\n---\n".join(notes_list)
        llm = LLM(model="small")
        query = get_note_summary_prompt(self, notes_string)
        summary = llm.generate(query).content
        print_purple("New summary of notes:", summary)
        self.notes_summary = summary
        self.update_session_state()

    # ------------------------------------------------------------------ #
    # Interviews
    # ------------------------------------------------------------------ #

    def add_interview(
        self,
        interview: UploadedFile,
        intervievees: str,
        interviewer: str,
        date_of_interveiw: "date | None" = None,
    ):
        """Add an uploaded interview (audio, PDF, JSON or text transcript).

        Audio is transcribed and previewed with "Add to project"/download
        buttons; JSON/text transcripts are stored directly and the script
        reruns. Unsupported types show an error and stop the run.
        """
        if interview is None:
            st.warning("Please choose an interview file to upload.")
            return

        file_type = interview.type or ""
        if file_type.startswith("audio/"):
            # transcribe() validates the extension (.m4a/.mp3/.wav/.flac).
            self._add_audio_interview(
                interview, intervievees, interviewer, date_of_interveiw
            )
            # No rerun here: it would wipe the preview and the buttons the
            # user needs in order to add or download the transcript.
            return

        if file_type == "application/pdf":
            # TODO: the processed PDF interview is not yet linked to the project.
            PDFProcessor(
                pdf_file=interview,
                is_sci=False,
                document_type="interview",
                is_image=False,
            )
        elif file_type in ("application/json", "text/plain", "text/vtt"):
            transcript = self._read_text_transcript(interview)
            self.add_interview_transcript(
                transcript,
                interview.name,
                intervievees=intervievees,
                interviewer=interviewer,
                date_of_interveiw=date_of_interveiw,
            )
        else:
            print(file_type)
            st.error("Unsupported file type")
            st.stop()

        st.rerun()

    def _add_audio_interview(
        self, interview, intervievees, interviewer, date_of_interveiw
    ):
        """Transcribe an audio upload, preview it and offer add/download buttons."""
        transcription = self.transcribe(interview)
        if not transcription:
            # transcribe() has already shown an error message.
            return

        preview_lines = transcription.split("\n")[:4]
        st.markdown("  \n".join(preview_lines) + "  \n(...)")
        transcription_filename = os.path.splitext(interview.name)[0] + ".vtt"

        c1, c2 = st.columns(2)
        with c1:
            st.button(
                "Add to project",
                on_click=self.add_interview_transcript,
                args=(
                    transcription,
                    transcription_filename,
                    intervievees,
                    interviewer,
                    date_of_interveiw,
                ),
            )
        with c2:
            st.download_button(
                label="Download Transcription",
                data=transcription,
                file_name=transcription_filename,
                mime="text/vtt",
            )

    def _read_text_transcript(self, interview) -> str:
        """Decode a JSON/text upload into "[timestamp] text" transcript lines.

        MacWhisper JSON/DOT exports go through format_json_transcription;
        content that is not JSON is treated as a VTT/SRT transcript when it
        contains cue arrows ("-->"), otherwise it is returned as-is.
        """
        # utf-8-sig also strips a BOM, which json.loads would reject.
        content = interview.getvalue().decode("utf-8-sig")
        try:
            data = json.loads(content)
        except json.JSONDecodeError:
            data = None
        if isinstance(data, (dict, list)):
            print_purple("JSON file processing")
            return self.format_json_transcription(data)
        if "-->" in content:
            return self.format_transcription(content)
        return content

    def add_interview_transcript(
        self,
        transcript,
        filename,
        intervievees: str = None,
        interviewer: str = None,
        date_of_interveiw: "date | str | None" = None,
    ):
        """Chunk a transcript and store it in the "interviews" collection.

        ``intervievees`` may be a comma-separated string. ``interviewer``
        defaults to the current user. ``date_of_interveiw`` may be a date,
        datetime or "YYYY-MM-DD" string and is stored as "YYYY-MM-DD".
        The document key combines filename and minute-resolution timestamp,
        so re-uploading the same file within a minute overwrites it.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M")
        _key = fix_key(f"{filename}_{timestamp}")
        if isinstance(intervievees, str) and intervievees:
            intervievees = [
                name.strip() for name in intervievees.split(",") if name.strip()
            ]
        if not interviewer:
            interviewer = self.username

        if not self.user_arango.has_collection("interviews"):
            self.user_arango.create_collection("interviews")

        # Normalise to a plain date so isoformat() is always "YYYY-MM-DD".
        if isinstance(date_of_interveiw, str):
            date_of_interveiw = datetime.strptime(date_of_interveiw, "%Y-%m-%d").date()
        elif isinstance(date_of_interveiw, datetime):
            date_of_interveiw = date_of_interveiw.date()

        from article2db import Document

        document = Document(
            text=transcript,
            is_sci=False,
            _key=_key,
            filename=filename,
            arango_db_name=self.username,
            username=self.username,
            arango_collection="interviews",
        )
        document.make_chunks(len_chunks=600)

        self.user_arango.insert_document(
            collection_name="interviews",
            document={
                "_key": _key,
                "transcript": transcript,
                "project": self.name,
                "filename": filename,
                "timestamp": timestamp,
                # "intervievees"/"date_of_interveiw" (sic): existing DB field names.
                "intervievees": intervievees,
                "interviewer": interviewer,
                "date_of_interveiw": (
                    date_of_interveiw.isoformat() if date_of_interveiw else None
                ),
                "chunks": document.chunks,
            },
            overwrite=True,
        )
        document.make_summary_in_background()

    def transcribe(self, uploaded_file: UploadedFile):
        """Convert an audio upload to mono MP3 and send it to TRANSCRIBE_URL.

        Returns the formatted transcript, or None when the service is not
        configured or the request fails (an error is shown in the UI).
        Stops the script run for unsupported file extensions.
        """
        from pydub import AudioSegment
        import requests
        import io

        filename = uploaded_file.name
        file_extension = os.path.splitext(filename)[1].lower()
        progress_bar = st.progress(0)
        status_text = st.empty()

        if file_extension not in (".m4a", ".mp3", ".wav", ".flac"):
            st.error("Unsupported file type")
            st.stop()

        audio = AudioSegment.from_file(
            io.BytesIO(uploaded_file.getvalue()), format=file_extension.lstrip(".")
        )
        audio = audio.set_channels(1)  # Convert to mono
        file_buffer = io.BytesIO()
        audio.export(file_buffer, format="mp3", bitrate="64k")
        file_buffer.seek(0)
        progress_bar.progress(50)
        status_text.text("Audio file converted.")

        # os.getenv never raises, so fall back to .env explicitly when unset.
        url = os.getenv("TRANSCRIBE_URL")
        if not url:
            import dotenv

            dotenv.load_dotenv()
            url = os.getenv("TRANSCRIBE_URL")
        if not url:
            st.error("TRANSCRIBE_URL is not configured.")
            return None

        files = {"file": (filename, file_buffer, "audio/mp3")}
        try:
            response = requests.post(url, files=files, timeout=3600)
        except requests.exceptions.Timeout:
            st.error("The request timed out. Please try again later.")
            return None
        except requests.exceptions.RequestException as err:
            st.error(f"Failed to reach the transcription service: {err}")
            return None

        progress_bar.progress(100)
        status_text.text("File uploaded and processed.")

        # Check the status before parsing: error pages are often not JSON.
        if response.status_code != 200:
            st.error("Failed to upload and process the file.")
            return None
        try:
            payload = response.json()
        except ValueError:
            st.error("Failed to upload and process the file.")
            return None
        transcription_content = (
            payload.get("transcription", "") if isinstance(payload, dict) else ""
        )
        return self.format_transcription(transcription_content)

    def format_transcription(self, transcription: str):
        """Convert WebVTT/SRT text into "[hh:mm:ss] text" lines.

        Each cue's start time, without fractional seconds, prefixes its text.
        Multi-line cue text is joined with spaces. Headers, cue numbers and
        blank lines are dropped.
        """
        transcript = []
        timestamp = None  # start time of the cue currently being read
        cue_has_text = False
        for raw_line in (transcription or "").splitlines():
            line = raw_line.strip()
            if "-->" in line:
                start = line.split("-->", 1)[0].strip()
                # Drop ".mmm" (VTT) or ",mmm" (SRT) fractional seconds.
                timestamp = re.split(r"[.,]", start, maxsplit=1)[0]
                cue_has_text = False
            elif not line:
                timestamp = None  # a blank line ends the cue
            elif timestamp:
                if cue_has_text:
                    transcript[-1] = f"{transcript[-1]} {line}"
                else:
                    transcript.append(f"[{timestamp}] {line}")
                    cue_has_text = True
        return "\n".join(transcript)

    def format_json_transcription(self, transcription: "dict | list"):
        """Convert a MacWhisper export into "[timestamp] [speaker: ]text" lines.

        Accepts the JSON export (a list of items with "timestamp", "text" and
        optional "speaker") and the DOT export (a dict whose "lines" items have
        "startTime", "text" and optional "speaker"). Anything else yields "".
        """
        if isinstance(transcription, list):
            # For the JSON format in MacWhisper
            entries = [
                (item["timestamp"], item["text"], item.get("speaker"))
                for item in transcription
            ]
        elif isinstance(transcription, dict):
            # For the DOT format in MacWhisper
            entries = [
                (item["startTime"], item["text"], item.get("speaker"))
                for item in transcription.get("lines", [])
            ]
        else:
            entries = []
        return "\n".join(
            f"[{ts}] {speaker}: {text}" if speaker else f"[{ts}] {text}"
            for ts, text, speaker in entries
        )

    def delete_interview(self, interview_id):
        """Delete the interview document with key ``interview_id``."""
        self.user_arango.delete_document(
            collection_name="interviews", document_key=interview_id
        )

    # ------------------------------------------------------------------ #
    # Uploaded notes (images / PDFs)
    # ------------------------------------------------------------------ #

    def analyze_image(self, image_base64, text=None):
        """Ask the vision LLM to describe an image and return the description."""
        llm = LLM(system_message=get_image_system_prompt(self))
        prompt = (
            f'Analyze the image. The text found in it read: "{text}"'
            if text
            else "Analyze the image."
        )
        response = llm.generate(query=prompt, images=[image_base64], stream=False)
        # Elsewhere in this class generate() results expose .content; fall
        # back to the raw value in case a plain string is returned.
        description = getattr(response, "content", response)
        print_green("Image description:", description)
        return description

    def process_uploaded_notes(self, files):
        """Turn each uploaded image/PDF into a note, then rerun the script.

        Images are OCR'd and captioned by the LLM. PDFs are text-extracted
        only, because PIL cannot open them as images.
        """
        with st.spinner("Processing files..."):
            for file in files:
                st.write("Processing...")
                filename = fix_key(file.name)

                if file.type == "application/pdf":
                    # NOTE(review): assumes process_document() returns the
                    # extracted text for regular PDFs as it does for images.
                    pdf = PDFProcessor(
                        pdf_file=file,
                        is_sci=False,
                        document_type="notes",
                        is_image=False,
                        process=False,
                    )
                    note_text = pdf.process_document() or ""
                else:
                    image_file = self.file2img(file)
                    pdf_file = self.convert_image_to_pdf(image_file)
                    pdf = PDFProcessor(
                        pdf_file=pdf_file,
                        is_sci=False,
                        document_type="notes",
                        is_image=True,
                        process=False,
                    )
                    text = pdf.process_document()
                    # getvalue(), not read(): file2img already consumed the stream.
                    base64_str = base64.b64encode(file.getvalue())
                    image_caption = self.analyze_image(base64_str, text=text)
                    note_text = f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}"

                if not note_text.strip():
                    st.warning(f"No text could be extracted from {file.name}")
                    continue
                self.add_note({"_id": f"notes/{filename}", "text": note_text})
            st.success("Done!")
            sleep(1.5)
        self.update_session_state()
        st.rerun()

    def file2img(self, file):
        """Open an uploaded file as a PIL image (reads from the start of the buffer).

        Raises:
            ValueError: if the file is empty.
        """
        img_bytes = file.getvalue() if hasattr(file, "getvalue") else file.read()
        if not img_bytes:
            raise ValueError("Uploaded file is empty.")
        return Image.open(BytesIO(img_bytes))

    def convert_image_to_pdf(self, img):
        """OCR ``img`` with Tesseract into a searchable PDF held in memory."""
        import pytesseract

        pdf_bytes = pytesseract.image_to_pdf_or_hocr(img)
        pdf_file = BytesIO(pdf_bytes)
        pdf_file.name = (
            "converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf"
        )
        return pdf_file

    # ------------------------------------------------------------------ #
    # Wikipedia and DOIs
    # ------------------------------------------------------------------ #

    def get_wikipedia_data(self, page_url: str) -> dict:
        """Fetch title, summary, content, url and references of a Wikipedia page.

        Returns None (after showing a warning/error) on an invalid URL or
        fetch failure.
        """
        import wikipedia
        from urllib.parse import unquote, urlparse

        parsed_url = urlparse(page_url)
        page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path)
        # URLs carry percent-encoded, underscore-separated titles ("Caf%C3%A9").
        page_name = (
            unquote(page_name_match.group(0)).replace("_", " ").strip()
            if page_name_match
            else ""
        )
        if not page_name:
            st.warning("Invalid Wikipedia URL")
            return None

        try:
            page = wikipedia.page(page_name, auto_suggest=False)
            return {
                "title": page.title,
                "summary": page.summary,
                "content": page.content,
                "url": page.url,
                "references": page.references,
            }
        except Exception as e:  # UI boundary: report any fetch failure
            st.error(f"Error fetching Wikipedia data: {e}")
            return None

    def process_wikipedia_data(self, wiki_data, wiki_url):
        """Summarise Wikipedia data into a project note and look for DOIs.

        ``wiki_data`` is the dict from get_wikipedia_data(); it is mutated:
        "summary"/"content" are replaced by a "text" field.
        """
        summary_source = wiki_data.pop("summary", None)
        wiki_data.pop("content", None)
        if summary_source:
            llm = LLM(
                system_message="You are an assistant summarisen wikipedia data. Answer ONLY with the summary, nothing else!",
                model="small",
            )
            query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{summary_source}"""\nMake a detailed and concise summary of the text.'''
            summary = llm.generate(query).content
            wiki_data["text"] = (
                f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}"
            )
        else:
            # add_note() requires text; fall back to a link to the page.
            wiki_data["text"] = f"[{wiki_data.get('title') or wiki_url}]({wiki_url})"

        # add_note() stores the note itself; inserting it directly as well
        # would create a duplicate, unlinked note document.
        self.add_note(wiki_data)

        processor = PDFProcessor(process=False)
        references = wiki_data.get("references") or []
        print_rainbow(references)
        dois = []
        for ref in references:
            doi = processor.extract_doi(ref)
            if doi:
                print_blue("Found DOI:", doi)
                dois.append(doi)

        if dois:
            settings = st.session_state.get("settings") or {}
            current_collection = settings.get("current_collection")
            if not current_collection:
                st.info(
                    f"Found {len(dois)} references with DOI numbers. Select a collection to add them to."
                )
            else:
                st.markdown(
                    f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?"
                )
                # NOTE(review): this button is only rendered in the run where
                # "Add Wikipedia data" was clicked, and the caller reruns right
                # after, so a click here does not reach process_dois(). Keeping
                # the DOIs in session state would fix this.
                if st.button("Add DOIs"):
                    self.process_dois(current_collection, dois=dois)
        self.update_session_state()

    def process_dois(
        self, article_collection_name: str, text: str = None, dois: list = None
    ) -> None:
        """Download/ingest articles for ``dois`` (or DOIs found in ``text``).

        Articles available in the database are added to
        ``article_collection_name``. DOIs that could not be downloaded are
        recorded in ``st.session_state["not_downloaded"]`` with their URL.
        """
        processor = PDFProcessor(process=False)
        if not dois and text:
            dois = processor.extract_doi(text, multi=True)
        if "not_downloaded" not in st.session_state:
            st.session_state["not_downloaded"] = {}

        for doi in dois or []:
            downloaded, url, path, in_db = processor.doi2pdf(doi)
            if downloaded and not in_db:
                processor.process_pdf(path)
                in_db = True
            elif not downloaded and not in_db:
                st.session_state["not_downloaded"][doi] = url

            if in_db:
                st.success(f"Article with DOI {doi} added")
                self.articles2collection(
                    collection=article_collection_name,
                    db="base",
                    _id=f"sci_articles/{fix_key(doi)}",
                )
        self.update_session_state()

    def articles2collection(self, collection, db, _id):
        """Add article ``_id`` to the article collection named ``collection`` in database ``db``.

        The collection document is created when it does not exist yet.
        """
        arango = ArangoDB(db_name=db)
        cursor = arango.execute_aql(
            "FOR c IN article_collections FILTER c.name == @name RETURN c",
            bind_vars={"name": collection},
        )
        # iter() works for both cursors and plain lists.
        collection_doc = next(iter(cursor), None)
        if collection_doc is None:
            arango.insert_document(
                collection_name="article_collections",
                document={"name": collection, "articles": [_id]},
            )
            return

        articles = collection_doc.setdefault("articles", [])
        if _id not in articles:
            articles.append(_id)
            arango.update_document(collection_doc)
|
|
|