parent
00fd42b32d
commit
732793b79f
23 changed files with 3101 additions and 1665 deletions
@ -0,0 +1,440 @@ |
||||
import os |
||||
import base64 |
||||
import re |
||||
import json |
||||
from typing import Any, Callable, Iterator, Literal, Mapping, Optional, Sequence, Union |
||||
|
||||
import tiktoken |
||||
from ollama import Client, AsyncClient, ResponseError, ChatResponse, Message, Tool, Options |
||||
from ollama._types import JsonSchemaValue, ChatRequest |
||||
|
||||
import env_manager |
||||
from colorprinter.print_color import * |
||||
|
||||
env_manager.set_env() |
||||
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base") |
||||
|
||||
# Define a base class for common functionality |
||||
class BaseClient: |
||||
def chat( |
||||
self, |
||||
model: str = '', |
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, |
||||
*, |
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, |
||||
stream: bool = False, |
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, |
||||
options: Optional[Union[Mapping[str, Any], Options]] = None, |
||||
keep_alive: Optional[Union[float, str]] = None, |
||||
) -> Union[ChatResponse, Iterator[ChatResponse]]: |
||||
return self._request( |
||||
ChatResponse, |
||||
'POST', |
||||
'/api/chat', |
||||
json=ChatRequest( |
||||
model=model, |
||||
messages=[message for message in messages or []], |
||||
tools=[tool for tool in tools or []], |
||||
stream=stream, |
||||
format=format, |
||||
options=options, |
||||
keep_alive=keep_alive, |
||||
).model_dump(exclude_none=True), |
||||
stream=stream, |
||||
) |
||||
|
||||
# Define your custom MyAsyncClient class |
||||
class MyAsyncClient(AsyncClient, BaseClient): |
||||
async def _request(self, response_type, method, path, headers=None, **kwargs): |
||||
# Merge default headers with per-call headers |
||||
all_headers = {**self._client.headers, **(headers or {})} |
||||
|
||||
# Handle streaming separately |
||||
if kwargs.get('stream'): |
||||
kwargs.pop('stream') |
||||
async with self._client.stream(method, path, headers=all_headers, **kwargs) as response: |
||||
self.last_response = response # Store the response object |
||||
if response.status_code >= 400: |
||||
await response.aread() |
||||
raise ResponseError(response.text, response.status_code) |
||||
return self._stream(response_type, response) |
||||
else: |
||||
# Make the HTTP request with the combined headers |
||||
kwargs.pop('stream') |
||||
response = await self._request_raw(method, path, headers=all_headers, **kwargs) |
||||
self.last_response = response # Store the response object |
||||
|
||||
if response.status_code >= 400: |
||||
raise ResponseError(response.text, response.status_code) |
||||
return response_type.model_validate_json(response.content) |
||||
|
||||
async def chat( |
||||
self, |
||||
model: str = '', |
||||
messages: Optional[Sequence[Union[Mapping[str, Any], Message]]] = None, |
||||
*, |
||||
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, |
||||
stream: bool = False, |
||||
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, |
||||
options: Optional[Union[Mapping[str, Any], Options]] = None, |
||||
keep_alive: Optional[Union[float, str]] = None, |
||||
) -> Union[ChatResponse, Iterator[ChatResponse]]: |
||||
return await self._request( |
||||
ChatResponse, |
||||
'POST', |
||||
'/api/chat', |
||||
json=ChatRequest( |
||||
model=model, |
||||
messages=[message for message in messages or []], |
||||
tools=[tool for tool in tools or []], |
||||
stream=stream, |
||||
format=format, |
||||
options=options, |
||||
keep_alive=keep_alive, |
||||
).model_dump(exclude_none=True), |
||||
stream=stream, |
||||
) |
||||
|
||||
# Define your custom MyClient class |
||||
class MyClient(Client, BaseClient): |
||||
def _request(self, response_type, method, path, headers=None, **kwargs): |
||||
# Merge default headers with per-call headers |
||||
all_headers = {**self._client.headers, **(headers or {})} |
||||
|
||||
# Handle streaming separately |
||||
if kwargs.get('stream'): |
||||
kwargs.pop('stream') |
||||
with self._client.stream(method, path, headers=all_headers, **kwargs) as response: |
||||
self.last_response = response # Store the response object |
||||
if response.status_code >= 400: |
||||
raise ResponseError(response.text, response.status_code) |
||||
return self._stream(response_type, response) |
||||
else: |
||||
# Make the HTTP request with the combined headers |
||||
kwargs.pop('stream') |
||||
response = self._request_raw(method, path, headers=all_headers, **kwargs) |
||||
self.last_response = response # Store the response object |
||||
|
||||
if response.status_code >= 400: |
||||
raise ResponseError(response.text, response.status_code) |
||||
return response_type.model_validate_json(response.content) |
||||
|
||||
class LLM: |
||||
""" |
||||
LLM class for interacting with a language model. |
||||
""" |
||||
|
||||
def __init__( |
||||
self, |
||||
system_message="You are an assistant.", |
||||
temperature=0.01, |
||||
model: Optional[Literal["small", "standard", "vision"]] = "standard", |
||||
max_length_answer=4096, |
||||
messages=None, |
||||
chat=True, |
||||
chosen_backend=None, |
||||
) -> None: |
||||
|
||||
self.model = self.get_model(model) |
||||
self.system_message = system_message |
||||
self.options = {"temperature": temperature} |
||||
self.messages = messages or [{"role": "system", "content": self.system_message}] |
||||
self.max_length_answer = max_length_answer |
||||
self.chat = chat |
||||
self.chosen_backend = chosen_backend |
||||
|
||||
# Initialize the client with the host and default headers |
||||
credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}" |
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode() |
||||
default_headers = { |
||||
"Authorization": f"Basic {encoded_credentials}", |
||||
} |
||||
host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/") |
||||
self.client = MyClient(host=host_url, headers=default_headers) |
||||
self.async_client = MyAsyncClient(host=host_url, headers=default_headers) |
||||
|
||||
def get_model(self, model_alias): |
||||
models = { |
||||
"standard": "LLM_MODEL", |
||||
"small": "LLM_MODEL_SMALL", |
||||
"vision": "LLM_MODEL_VISION", |
||||
"standard_64k": "LLM_MODEL_64K", |
||||
} |
||||
return os.getenv(models.get(model_alias, "LLM_MODEL")) |
||||
|
||||
def count_tokens(self): |
||||
num_tokens = 0 |
||||
for i in self.messages: |
||||
for k, v in i.items(): |
||||
if k == "content": |
||||
if not isinstance(v, str): |
||||
v = str(v) |
||||
tokens = tokenizer.encode(v) |
||||
num_tokens += len(tokens) |
||||
return int(num_tokens) |
||||
|
||||
def generate( |
||||
self, |
||||
query: str = None, |
||||
user_input: str = None, |
||||
context: str = None, |
||||
stream: bool = False, |
||||
tools: list = None, |
||||
function_call: dict = None, |
||||
images: list = None, |
||||
model: Optional[Literal["small", "standard", "vision"]] = None, |
||||
temperature: float = None, |
||||
): |
||||
""" |
||||
Generates a response from the language model based on the provided inputs. |
||||
""" |
||||
|
||||
# Prepare the model and temperature |
||||
model = self.get_model(model) if model else self.model |
||||
temperature = temperature if temperature else self.options["temperature"] |
||||
|
||||
# Normalize whitespace and add the query to the messages |
||||
query = re.sub(r"\s*\n\s*", "\n", query) |
||||
message = {"role": "user", "content": query} |
||||
|
||||
# Handle images if any |
||||
if images: |
||||
import base64 |
||||
|
||||
base64_images = [] |
||||
base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$") |
||||
|
||||
for image in images: |
||||
if isinstance(image, str): |
||||
if base64_pattern.match(image): |
||||
base64_images.append(image) |
||||
else: |
||||
with open(image, "rb") as image_file: |
||||
base64_images.append( |
||||
base64.b64encode(image_file.read()).decode("utf-8") |
||||
) |
||||
elif isinstance(image, bytes): |
||||
base64_images.append(base64.b64encode(image).decode("utf-8")) |
||||
else: |
||||
print_red("Invalid image type") |
||||
|
||||
message["images"] = base64_images |
||||
# Use the vision model |
||||
model = self.get_model("vision") |
||||
|
||||
self.messages.append(message) |
||||
|
||||
# Prepare headers |
||||
headers = {} |
||||
if self.chosen_backend: |
||||
headers["X-Chosen-Backend"] = self.chosen_backend |
||||
|
||||
if model == self.get_model("small"): |
||||
headers["X-Model-Type"] = "small" |
||||
|
||||
# Prepare options |
||||
options = Options(**self.options) |
||||
options.temperature = temperature |
||||
|
||||
# Prepare tools if any |
||||
if tools: |
||||
tools = [ |
||||
Tool(**tool) if isinstance(tool, dict) else tool |
||||
for tool in tools |
||||
] |
||||
|
||||
# Adjust the options for long messages |
||||
if self.chat or len(self.messages) > 15000: |
||||
num_tokens = self.count_tokens() + self.max_length_answer // 2 |
||||
if num_tokens > 8000: |
||||
model = self.get_model("standard_64k") |
||||
headers["X-Model-Type"] = "large" |
||||
|
||||
# Call the client.chat method |
||||
try: |
||||
response = self.client.chat( |
||||
model=model, |
||||
messages=self.messages, |
||||
headers=headers, |
||||
tools=tools, |
||||
stream=stream, |
||||
options=options, |
||||
keep_alive=3600 * 24 * 7, |
||||
) |
||||
except ResponseError as e: |
||||
print_red("Error!") |
||||
print(e) |
||||
return "An error occurred." |
||||
|
||||
# If user_input is provided, update the last message |
||||
if user_input: |
||||
if context: |
||||
if len(context) > 2000: |
||||
context = self.make_summary(context) |
||||
user_input = ( |
||||
f"{user_input}\n\nUse the information below to answer the question.\n" |
||||
f'"""{context}"""\n[This is a summary of the context provided in the original message.]' |
||||
) |
||||
system_message_info = "\nSometimes some of the messages in the chat history are summarised, then that is clearly indicated in the message." |
||||
if system_message_info not in self.messages[0]["content"]: |
||||
self.messages[0]["content"] += system_message_info |
||||
self.messages[-1] = {"role": "user", "content": user_input} |
||||
|
||||
self.chosen_backend = self.client.last_response.headers.get("X-Chosen-Backend") |
||||
|
||||
# Handle streaming response |
||||
if stream: |
||||
return self.read_stream(response) |
||||
else: |
||||
# Process the response |
||||
if isinstance(response, ChatResponse): |
||||
result = response.message.content.strip('"') |
||||
self.messages.append({"role": "assistant", "content": result.strip('"')}) |
||||
if tools and not response.message.get("tool_calls"): |
||||
print_yellow("No tool calls in response".upper()) |
||||
if not self.chat: |
||||
self.messages = [self.messages[0]] |
||||
return result |
||||
else: |
||||
print_red("Unexpected response type") |
||||
return "An error occurred." |
||||
|
||||
def make_summary(self, text): |
||||
# Implement your summary logic using self.client.chat() |
||||
summary_message = { |
||||
"role": "user", |
||||
"content": f'Summarize the text below:\n"""{text}"""\nRemember to be concise and detailed. Answer in English.', |
||||
} |
||||
messages = [ |
||||
{"role": "system", "content": "You are summarizing a text. Make it detailed and concise. Answer ONLY with the summary. Don't add any new information."}, |
||||
summary_message, |
||||
] |
||||
try: |
||||
response = self.client.chat( |
||||
model=self.get_model("small"), |
||||
messages=messages, |
||||
options=Options(temperature=0.01), |
||||
keep_alive=3600 * 24 * 7, |
||||
) |
||||
summary = response.message.content.strip() |
||||
print_blue("Summary:", summary) |
||||
return summary |
||||
except ResponseError as e: |
||||
print_red("Error generating summary:", e) |
||||
return "Summary generation failed." |
||||
|
||||
def read_stream(self, response): |
||||
# Implement streaming response handling if needed |
||||
buffer = "" |
||||
message = "" |
||||
first_chunk = True |
||||
prev_content = None |
||||
for chunk in response: |
||||
if chunk: |
||||
content = chunk.message.content |
||||
if first_chunk and content.startswith('"'): |
||||
content = content[1:] |
||||
first_chunk = False |
||||
|
||||
if chunk.done: |
||||
if prev_content and prev_content.endswith('"'): |
||||
prev_content = prev_content[:-1] |
||||
if prev_content: |
||||
yield prev_content |
||||
break |
||||
else: |
||||
if prev_content: |
||||
yield prev_content |
||||
prev_content = content |
||||
self.messages.append({"role": "assistant", "content": message.strip('"')}) |
||||
|
||||
async def async_generate( |
||||
self, |
||||
query: str = None, |
||||
user_input: str = None, |
||||
context: str = None, |
||||
stream: bool = False, |
||||
tools: list = None, |
||||
function_call: dict = None, |
||||
images: list = None, |
||||
model: Optional[Literal["small", "standard", "vision"]] = None, |
||||
temperature: float = None, |
||||
): |
||||
""" |
||||
Asynchronous method to generate a response from the language model. |
||||
""" |
||||
|
||||
# Prepare the model and temperature |
||||
model = self.get_model(model) if model else self.model |
||||
temperature = temperature if temperature else self.options["temperature"] |
||||
|
||||
# Normalize whitespace and add the query to the messages |
||||
query = re.sub(r"\s*\n\s*", "\n", query) |
||||
message = {"role": "user", "content": query} |
||||
|
||||
# Handle images if any |
||||
if images: |
||||
# (Image handling code as in the generate method) |
||||
... |
||||
|
||||
self.messages.append(message) |
||||
|
||||
# Prepare headers |
||||
headers = {} |
||||
if self.chosen_backend: |
||||
headers["X-Chosen-Backend"] = self.chosen_backend |
||||
|
||||
if model == self.get_model("small"): |
||||
headers["X-Model-Type"] = "small" |
||||
|
||||
# Prepare options |
||||
options = Options(**self.options) |
||||
options.temperature = temperature |
||||
|
||||
# Prepare tools if any |
||||
if tools: |
||||
tools = [ |
||||
Tool(**tool) if isinstance(tool, dict) else tool |
||||
for tool in tools |
||||
] |
||||
|
||||
# Adjust options for long messages |
||||
# (Adjustments as needed) |
||||
... |
||||
|
||||
# Call the async client's chat method |
||||
try: |
||||
response = await self.async_client.chat( |
||||
model=model, |
||||
messages=self.messages, |
||||
tools=tools, |
||||
stream=stream, |
||||
options=options, |
||||
keep_alive=3600 * 24 * 7, |
||||
) |
||||
except ResponseError as e: |
||||
print_red("Error!") |
||||
print(e) |
||||
return "An error occurred." |
||||
|
||||
# Process the response |
||||
if isinstance(response, ChatResponse): |
||||
result = response.message.content.strip('"') |
||||
self.messages.append({"role": "assistant", "content": result.strip('"')}) |
||||
return result |
||||
else: |
||||
print_red("Unexpected response type") |
||||
return "An error occurred." |
||||
|
||||
# Usage example |
||||
if __name__ == "__main__": |
||||
import asyncio |
||||
|
||||
llm = LLM() |
||||
|
||||
async def main(): |
||||
result = await llm.async_generate(query="Hello, how are you?") |
||||
print(result) |
||||
|
||||
asyncio.run(main()) |
||||
@ -0,0 +1,232 @@ |
||||
from _llm import LLM |
||||
import os, re |
||||
|
||||
from atproto import ( |
||||
CAR, |
||||
AtUri, |
||||
Client, |
||||
FirehoseSubscribeReposClient, |
||||
firehose_models, |
||||
models, |
||||
parse_subscribe_repos_message, |
||||
) |
||||
from colorprinter.print_color import * |
||||
from datetime import datetime |
||||
|
||||
from env_manager import set_env |
||||
set_env() |
||||
|
||||
|
||||
class Chat: |
||||
def __init__(self, bot_username, poster_username): |
||||
self.bot_username = bot_username |
||||
self.poster_username = poster_username |
||||
self.messages = [] |
||||
self.thread_posts = [] |
||||
|
||||
class Bot: |
||||
def __init__(self): |
||||
|
||||
# Create a client instance to interact with Bluesky |
||||
self.username = os.getenv("BLUESKY_USERNAME") |
||||
system_message = ''' |
||||
You are a research assistant bot chatting with a user on Bluesky, a social media platform similar to Twitter. |
||||
Your speciality is electric cars, and you will use facts in articles to answer the questions |
||||
Use ONLY the information in the articles to answer the questions. Do not add any additional information or speculation. |
||||
IF you don't know the answer, you can say "I don't know" or "I'm not sure". You can also ask the user to specify the question. |
||||
Your answers should be concise and not exceed 250 characters to fit the character limit on Bluesky. |
||||
Answer in English. |
||||
''' |
||||
self.llm: LLM = LLM(system_message=system_message, max_length_answer=200) |
||||
self.client = Client() |
||||
self.client.login(self.username, os.getenv("BLUESKY_PASSWORD")) |
||||
self.chat = None |
||||
|
||||
print("🐟 Bot is listening") |
||||
|
||||
# Create a firehose client to subscribe to repository events |
||||
self.firehose = FirehoseSubscribeReposClient() |
||||
# Start the firehose to listen for repository events |
||||
self.firehose.start(self.on_message_handler) |
||||
|
||||
def answer_message(self, message): |
||||
response = self.llm.generate(message).content |
||||
self.client.send_post(response) |
||||
|
||||
def get_file_extension(self, file_path): |
||||
# Utility function to get the file extension from a file path |
||||
return os.path.splitext(file_path)[1] |
||||
|
||||
def bot_mentioned(self, text: str) -> bool: |
||||
# Check if the text contains 'cc: dreambot' (case-insensitive) |
||||
return self.username.lower() in text.lower() |
||||
|
||||
def parse_thread(self, author_did, thread): |
||||
# Traverse the thread to collect prompts from posts by the author |
||||
entries = [] |
||||
stack = [thread] |
||||
|
||||
while stack: |
||||
current_thread = stack.pop() |
||||
|
||||
if current_thread is None: |
||||
continue |
||||
|
||||
if current_thread.post.author.did == author_did: |
||||
print(current_thread.post.record.text) |
||||
# Extract prompt from the current post |
||||
entries.append(current_thread.post.record.text) |
||||
|
||||
# Add parent thread to the stack for further traversal |
||||
stack.append(current_thread.parent) |
||||
|
||||
return entries |
||||
|
||||
def process_operation( |
||||
self, |
||||
op: models.ComAtprotoSyncSubscribeRepos.RepoOp, |
||||
car: CAR, |
||||
commit: models.ComAtprotoSyncSubscribeRepos.Commit, |
||||
) -> None: |
||||
# Construct the URI for the operation |
||||
uri = AtUri.from_str(f"at://{commit.repo}/{op.path}") |
||||
|
||||
if op.action == "create": |
||||
if not op.cid: |
||||
return |
||||
|
||||
# Retrieve the record from the CAR file using the content ID (CID) |
||||
record = car.blocks.get(op.cid) |
||||
if not record: |
||||
return |
||||
|
||||
# Build the record with additional metadata |
||||
record = { |
||||
"uri": str(uri), |
||||
"cid": str(op.cid), |
||||
"author": commit.repo, |
||||
**record, |
||||
} |
||||
|
||||
# Check if the operation is a post in the feed |
||||
if uri.collection == models.ids.AppBskyFeedPost: |
||||
if self.bot_mentioned(record["text"]): |
||||
poster_username = self.client.get_profile(actor=record["author"]).handle |
||||
self.chat = Chat(self.username, poster_username) |
||||
posts_in_thread = self.client.get_post_thread(uri=record["uri"]) |
||||
self.traverse_thread(posts_in_thread.thread) |
||||
self.chat.thread_posts.sort(key=lambda x: x["timestamp"]) |
||||
self.make_llm_messages() |
||||
answer = self.llm.generate(messages=self.chat.messages) |
||||
self.client.send_post(f'@{poster_username} {answer.content} ') |
||||
|
||||
|
||||
if op.action == "delete": |
||||
# Handle delete operations (not implemented) |
||||
return |
||||
|
||||
if op.action == "update": |
||||
# Handle update operations (not implemented) |
||||
return |
||||
|
||||
return |
||||
|
||||
|
||||
def traverse_thread(self, thread_view_post): |
||||
# Process the current post |
||||
post = thread_view_post.post |
||||
author_handle = post.author.handle |
||||
post_text = post.record.text |
||||
timestamp = int( |
||||
datetime.fromisoformat(post.indexed_at.replace("Z", "+00:00")).timestamp() |
||||
) |
||||
self.chat.thread_posts.append( |
||||
{ |
||||
"user": author_handle, |
||||
"text": post_text.replace("\n", " "), |
||||
"timestamp": timestamp, |
||||
} |
||||
) |
||||
|
||||
# If there's a parent, process it |
||||
if thread_view_post.parent: |
||||
self.traverse_thread(thread_view_post.parent) |
||||
|
||||
# If there are replies, process them |
||||
if getattr(thread_view_post, "replies", None): |
||||
for reply in thread_view_post.replies: |
||||
self.traverse_thread(reply) |
||||
|
||||
def make_llm_messages(self): |
||||
""" |
||||
Processes the chat thread posts and compiles them into a list of messages |
||||
formatted for a language model (LLM). |
||||
|
||||
The function performs the following steps: |
||||
1. Iterates through the chat thread posts. |
||||
2. Starts processing messages only after encountering a message mentioning the bot. |
||||
3. Adds messages from the bot and the poster to the `self.chat.messages` list in the |
||||
appropriate format for the LLM. |
||||
|
||||
The messages are formatted as follows: |
||||
- Messages from the bot are added with the role "assistant". |
||||
- Messages from the poster are added with the role "user". |
||||
- Consecutive messages from the same user are concatenated. |
||||
|
||||
Returns: |
||||
None |
||||
""" |
||||
print_rainbow(self.chat.thread_posts) |
||||
start = False |
||||
for i in self.chat.thread_posts: |
||||
# Make the messages start with a message mentioning the bot |
||||
if self.chat.bot_username in i["text"]: |
||||
start = True |
||||
elif self.chat.bot_username not in i["text"] and not start: |
||||
continue |
||||
# Compile the messages int a list for LLM |
||||
if ( |
||||
i["user"] == self.chat.bot_username |
||||
and len(self.chat.messages) > 0 |
||||
and self.chat.messages[-1] != self.chat.bot_username |
||||
): |
||||
i['text'] = i['text'].replace(f"@{self.chat.poster_username}", "").strip() |
||||
self.chat.messages.append({"role": "assistant", "content": i["text"]}) |
||||
elif i["user"] == self.chat.poster_username: |
||||
i['text'] = i['text'].replace(f"@{self.chat.bot_username}", "").strip() |
||||
if len(self.chat.messages) > 0 and self.chat.messages[-1]['role'] == 'user': |
||||
self.chat.messages[-1]['content'] += f"\n\n{i['text']}" |
||||
else: |
||||
self.chat.messages.append({"role": "user", "content": i["text"]}) |
||||
|
||||
def on_message_handler(self, message: firehose_models.MessageFrame) -> None: |
||||
# Callback function that handles incoming messages from the firehose subscription |
||||
|
||||
# Parse the incoming message to extract the commit information |
||||
commit = parse_subscribe_repos_message(message) |
||||
|
||||
# Check if the parsed message is a Commit and if the commit contains blocks of data |
||||
if not isinstance( |
||||
commit, models.ComAtprotoSyncSubscribeRepos.Commit |
||||
) or not isinstance(commit.blocks, bytes): |
||||
# If the message is not a valid commit or blocks are missing, exit early |
||||
return |
||||
|
||||
# Parse the CAR (Content Addressable aRchive) file from the commit's blocks |
||||
# The CAR file contains the data blocks referenced in the commit operations |
||||
car = CAR.from_bytes(commit.blocks) |
||||
|
||||
# Iterate over each operation (e.g., create, delete, update) in the commit |
||||
for op in commit.ops: |
||||
# Process each operation using the process_operation method |
||||
# This method handles the logic based on the type of operation |
||||
self.process_operation(op, car, commit) |
||||
|
||||
|
||||
def main() -> None: |
||||
bot = Bot() |
||||
bot.answer_message("Hello, world!") |
||||
|
||||
|
||||
if __name__ == "__main__": |
||||
main() |
||||
@ -0,0 +1,412 @@ |
||||
import streamlit as st |
||||
from time import sleep |
||||
from article2db import PDFProcessor |
||||
|
||||
from info import country_emojis |
||||
from utils import fix_key |
||||
from _base_class import StreamlitBaseClass |
||||
from colorprinter.print_color import * |
||||
|
||||
|
||||
class ArticleCollectionsPage(StreamlitBaseClass): |
||||
def __init__(self, username: str): |
||||
super().__init__(username=username) |
||||
self.collection = None |
||||
self.page_name = "Article Collections" |
||||
|
||||
# Initialize attributes from session state if available |
||||
for k, v in st.session_state[self.page_name].items(): |
||||
setattr(self, k, v) |
||||
|
||||
def run(self): |
||||
if self.user_arango.db.collection("article_collections").count() == 0: |
||||
self.create_new_collection() |
||||
|
||||
self.update_current_page(self.page_name) |
||||
|
||||
self.choose_collection_method() |
||||
self.choose_project_method() |
||||
|
||||
if self.collection: |
||||
self.display_collection() |
||||
self.sidebar_actions() |
||||
|
||||
if st.session_state.get("new_collection"): |
||||
self.create_new_collection() |
||||
|
||||
# Persist state to session_state |
||||
self.update_session_state(page_name=self.page_name) |
||||
|
||||
def choose_collection_method(self): |
||||
self.collection = self.choose_collection() |
||||
# Persist state after choosing collection |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def choose_project_method(self): |
||||
# If you have a project selection similar to collection, implement here |
||||
pass # Placeholder for project-related logic |
||||
|
||||
def choose_collection(self): |
||||
collections = self.get_article_collections() |
||||
current_collection = self.collection |
||||
preselected = ( |
||||
collections.index(current_collection) |
||||
if current_collection in collections |
||||
else None |
||||
) |
||||
with st.sidebar: |
||||
collection = st.selectbox( |
||||
"Select a collection of favorite articles", |
||||
collections, |
||||
index=preselected, |
||||
) |
||||
if collection: |
||||
self.collection = collection |
||||
self.update_settings("current_collection", collection) |
||||
return self.collection |
||||
|
||||
def create_new_collection(self): |
||||
with st.form("create_collection_form", clear_on_submit=True): |
||||
new_collection_name = st.text_input("Enter the name of the new collection") |
||||
submitted = st.form_submit_button("Create Collection") |
||||
if submitted: |
||||
if new_collection_name: |
||||
self.user_arango.db.collection("article_collections").insert( |
||||
{"name": new_collection_name, "articles": []} |
||||
) |
||||
st.success(f'New collection "{new_collection_name}" created') |
||||
self.collection = new_collection_name |
||||
self.update_settings("current_collection", new_collection_name) |
||||
# Persist state after creating a new collection |
||||
self.update_session_state(page_name=self.page_name) |
||||
sleep(1) |
||||
st.rerun() |
||||
|
||||
def display_collection(self): |
||||
with st.sidebar: |
||||
col1, col2 = st.columns(2) |
||||
with col1: |
||||
if st.button("Create new collection"): |
||||
st.session_state["new_collection"] = True |
||||
with col2: |
||||
if st.button(f':red[Remove collection "{self.collection}"]'): |
||||
self.user_arango.db.collection("article_collections").delete_match( |
||||
{"name": self.collection} |
||||
) |
||||
st.success(f'Collection "{self.collection}" removed') |
||||
self.collection = None |
||||
self.update_settings("current_collection", None) |
||||
# Persist state after removing a collection |
||||
self.update_session_state(page_name=self.page_name) |
||||
st.rerun() |
||||
|
||||
self.show_articles_in_collection() |
||||
|
||||
def show_articles_in_collection(self): |
||||
collection_articles_cursor = self.user_arango.db.aql.execute( |
||||
f""" |
||||
FOR doc IN article_collections |
||||
FILTER doc["name"] == @collection |
||||
FOR article IN doc["articles"] |
||||
RETURN article["_id"] |
||||
""", |
||||
bind_vars={"collection": self.collection}, |
||||
) |
||||
|
||||
collection_article_ids = list(collection_articles_cursor) |
||||
sci_articles = [ |
||||
_id for _id in collection_article_ids if _id.startswith("sci_articles") |
||||
] |
||||
other_articles = [ |
||||
_id for _id in collection_article_ids if not _id.startswith("sci_articles") |
||||
] |
||||
|
||||
collection_articles = [] |
||||
if sci_articles: |
||||
cursor = self.base_arango.db.aql.execute( |
||||
""" |
||||
FOR doc IN sci_articles |
||||
FILTER doc["_id"] IN @article_ids |
||||
RETURN doc |
||||
""", |
||||
bind_vars={"article_ids": sci_articles}, |
||||
) |
||||
collection_articles += list(cursor) |
||||
if other_articles: |
||||
cursor = self.user_arango.db.aql.execute( |
||||
""" |
||||
FOR doc IN other_documents |
||||
FILTER doc["_id"] IN @article_ids |
||||
RETURN doc |
||||
""", |
||||
bind_vars={"article_ids": other_articles}, |
||||
) |
||||
collection_articles += list(cursor) |
||||
|
||||
# Sort articles by title |
||||
collection_articles = sorted( |
||||
collection_articles, |
||||
key=lambda x: x.get("metadata", {}).get("title", "No Title"), |
||||
) |
||||
|
||||
if collection_articles: |
||||
st.markdown(f"#### Articles in *{self.collection}*:") |
||||
for article in collection_articles: |
||||
if article is None: |
||||
continue |
||||
metadata = article.get("metadata") |
||||
if metadata is None: |
||||
continue |
||||
|
||||
title = metadata.get("title", "No Title").strip() |
||||
journal = metadata.get("journal", "No Journal").strip() |
||||
published_year = metadata.get("published_year", "No Year") |
||||
published_date = metadata.get("published_date", None) |
||||
language = metadata.get("language", "No Language") |
||||
icon = country_emojis.get(language.upper(), "") if language else "" |
||||
|
||||
expander_title = f"**{title}** *{journal}* ({published_year}) {icon}" |
||||
|
||||
with st.expander(expander_title): |
||||
if not title == "No Title": |
||||
st.markdown(f"**Title:** \n{title}") |
||||
if not journal == "No Journal": |
||||
st.markdown(f"**Journal:** \n{journal}") |
||||
|
||||
if published_date: |
||||
st.markdown(f"**Published Date:** \n{published_date}") |
||||
for key, value in article.items(): |
||||
if key in [ |
||||
"_key", |
||||
"text", |
||||
"file", |
||||
"_rev", |
||||
"chunks", |
||||
"user_access", |
||||
"_id", |
||||
"metadata", |
||||
"doi", |
||||
"title", |
||||
"user_notes", |
||||
]: |
||||
continue |
||||
if isinstance(value, list): |
||||
value = ", ".join(value) |
||||
st.markdown(f"**{key.capitalize()}**: \n{value} ") |
||||
if "doi" in article: |
||||
if article["doi"]: |
||||
st.markdown( |
||||
f"**DOI:** \n[{article['doi']}](https://doi.org/{article['doi']}) " |
||||
) |
||||
|
||||
# Let the user add notes to the article, if it's not a scientific article |
||||
# if not article._id.startswith("sci_articles"): |
||||
if "user_notes" in article and article["user_notes"]: |
||||
st.markdown( |
||||
f":blue[**Your notes:**]" |
||||
) |
||||
note_number = 0 |
||||
for note in article["user_notes"]: |
||||
note_number += 1 |
||||
c1, c2 = st.columns([4, 1]) |
||||
with c1: |
||||
st.markdown(f":blue[{note}]") |
||||
with c2: |
||||
st.button(key=f'{article["_key"]}_{note_number}', |
||||
label=f":red[Delete note]", |
||||
on_click=self.delete_article_note, |
||||
args=(article, note), |
||||
) |
||||
|
||||
with st.form(f"add_info_form_{article['_id']}", clear_on_submit=True): |
||||
new_info = st.text_area( |
||||
":blue[Add a note about the article]", |
||||
key=f'new_info_{article["_id"]}', |
||||
help="Add information such as what kind of article it is, what it's about, who's the author, etc.", |
||||
) |
||||
submitted = st.form_submit_button(":blue[Add note]") |
||||
if submitted: |
||||
self.update_article(article, "user_notes", new_info) |
||||
|
||||
st.button( |
||||
key=f'delete_{article["_id"]}', |
||||
label=":red[Delete article from collection]", |
||||
on_click=self.delete_article, |
||||
args=(self.collection, article["_id"]), |
||||
) |
||||
|
||||
else: |
||||
st.write("No articles in this collection.") |
||||
|
||||
def sidebar_actions(self): |
||||
with st.sidebar: |
||||
st.markdown(f"### Add new articles to {self.collection}") |
||||
with st.form("add_articles_form", clear_on_submit=True): |
||||
pdf_files = st.file_uploader( |
||||
"Upload PDF file(s)", type=["pdf"], accept_multiple_files=True |
||||
) |
||||
is_sci = st.checkbox("All articles are from scientific journals") |
||||
submitted = st.form_submit_button("Upload") |
||||
if submitted and pdf_files: |
||||
self.add_articles(pdf_files, is_sci) |
||||
# Persist state after adding articles |
||||
self.update_session_state(page_name=self.page_name) |
||||
st.rerun() |
||||
|
||||
help_text = 'Paste a text containing DOIs, e.g., the reference section of a paper, and click "Add Articles" to add them to the collection.' |
||||
new_articles = st.text_area( |
||||
"Add articles to this collection", help=help_text |
||||
) |
||||
if st.button("Add Articles"): |
||||
with st.spinner("Processing..."): |
||||
self.process_dois( |
||||
article_collection_name=self.collection, text=new_articles |
||||
) |
||||
# Persist state after processing DOIs |
||||
self.update_session_state(page_name=self.page_name) |
||||
st.rerun() |
||||
|
||||
self.write_not_downloaded() |
||||
|
||||
def add_articles(self, pdf_files: list, is_sci: bool) -> None: |
||||
|
||||
for pdf_file in pdf_files: |
||||
status_container = st.empty() |
||||
with status_container: |
||||
is_sci = is_sci if is_sci else None |
||||
with st.status(f"Processing {pdf_file.name}..."): |
||||
processor = PDFProcessor( |
||||
pdf_file=pdf_file, |
||||
filename=pdf_file.name, |
||||
process=False, |
||||
username=st.session_state["username"], |
||||
document_type="other_documents", |
||||
is_sci=is_sci, |
||||
) |
||||
_id, db, doi = processor.process_document() |
||||
print_rainbow(_id, db, doi) |
||||
if doi in st.session_state.get("not_downloaded", {}): |
||||
st.session_state["not_downloaded"].pop(doi) |
||||
self.articles2collection(collection=self.collection, db=db, _id=_id) |
||||
st.success("Done!") |
||||
sleep(1.5) |
||||
|
||||
def articles2collection(self, collection: str, db: str, _id: str = None) -> None: |
||||
info = self.get_article_info(db, _id=_id) |
||||
info = { |
||||
k: v for k, v in info.items() if k in ["_id", "doi", "title", "metadata"] |
||||
} |
||||
doc_cursor = self.user_arango.db.aql.execute( |
||||
f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc' |
||||
) |
||||
doc = next(doc_cursor, None) |
||||
if doc: |
||||
articles = doc.get("articles", []) |
||||
keys = [i["_id"] for i in articles] |
||||
if info["_id"] not in keys: |
||||
articles.append(info) |
||||
self.user_arango.db.collection("article_collections").update_match( |
||||
filters={"name": collection}, |
||||
body={"articles": articles}, |
||||
merge=True, |
||||
) |
||||
# Persist state after updating articles |
||||
self.update_session_state(page_name=self.page_name) |
||||
|
||||
def get_article_info(self, db: str, _id: str = None, doi: str = None) -> dict: |
||||
assert _id or doi, "Either _id or doi must be provided." |
||||
arango = self.get_arango(db_name=db) |
||||
if _id: |
||||
query = """ |
||||
RETURN { |
||||
"_id": DOCUMENT(@doc_id)._id, |
||||
"doi": DOCUMENT(@doc_id).doi, |
||||
"title": DOCUMENT(@doc_id).title, |
||||
"metadata": DOCUMENT(@doc_id).metadata, |
||||
"summary": DOCUMENT(@doc_id).summary |
||||
} |
||||
""" |
||||
|
||||
info_cursor = arango.db.aql.execute(query, bind_vars={"doc_id": _id}) |
||||
elif doi: |
||||
info_cursor = arango.db.aql.execute( |
||||
f'FOR doc IN sci_articles FILTER doc["doi"] == "{doi}" LIMIT 1 RETURN {{"_id": doc["_id"], "doi": doc["doi"], "title": doc["title"], "metadata": doc["metadata"], "summary": doc["summary"]}}' |
||||
) |
||||
return next(info_cursor, None) |
||||
|
||||
def process_dois( |
||||
self, article_collection_name: str, text: str = None, dois: list = None |
||||
) -> None: |
||||
processor = PDFProcessor(process=False) |
||||
if not dois and text: |
||||
dois = processor.extract_doi(text, multi=True) |
||||
if "not_downloaded" not in st.session_state: |
||||
st.session_state["not_downloaded"] = {} |
||||
for doi in dois: |
||||
downloaded, url, path, in_db = processor.doi2pdf(doi) |
||||
if downloaded and not in_db: |
||||
processor.process_pdf(path) |
||||
in_db = True |
||||
elif not downloaded and not in_db: |
||||
st.session_state["not_downloaded"][doi] = url |
||||
|
||||
if in_db: |
||||
st.success(f"Article with DOI {doi} added") |
||||
self.articles2collection( |
||||
collection=article_collection_name, |
||||
db="base", |
||||
_id=f"sci_articles/{fix_key(doi)}", |
||||
) |
||||
# Persist state after processing DOIs |
||||
self.update_session_state(page_name=self.page_name) |
||||
|
||||
def write_not_downloaded(self): |
||||
not_downloaded = st.session_state.get("not_downloaded", {}) |
||||
if not_downloaded: |
||||
st.markdown( |
||||
"*The articles below were not downloaded. Download them yourself and add them to the collection by dropping them in the area above. Some of them can be downloaded using the link.*" |
||||
) |
||||
for doi, url in not_downloaded.items(): |
||||
if url: |
||||
st.markdown(f"- [{doi}]({url})") |
||||
else: |
||||
st.markdown(f"- {doi}") |
||||
|
||||
def delete_article(self, collection, _id): |
||||
doc_cursor = self.user_arango.db.aql.execute( |
||||
f'FOR doc IN article_collections FILTER doc["name"] == "{collection}" RETURN doc' |
||||
) |
||||
doc = next(doc_cursor, None) |
||||
if doc: |
||||
articles = [ |
||||
article for article in doc.get("articles", []) if article["_id"] != _id |
||||
] |
||||
self.user_arango.db.collection("article_collections").update_match( |
||||
filters={"_id": doc["_id"]}, |
||||
body={"articles": articles}, |
||||
) |
||||
# Persist state after deleting an article |
||||
self.update_session_state(page_name=self.page_name) |
||||
|
||||
def update_article(self, article, field, value): |
||||
"Update a field in an article document" |
||||
value = str(value.strip()) |
||||
print(value) |
||||
print(type(value)) |
||||
if field in article: |
||||
if isinstance(article[field], list): |
||||
article[field].append(value) |
||||
else: |
||||
article[field] = [article[field], value] |
||||
else: |
||||
article[field] = [value] |
||||
self.user_arango.db.update_document(article, check_rev=False, silent=True) |
||||
sleep(0.2) |
||||
st.rerun() |
||||
|
||||
def delete_article_note(self, article: dict, note: str): |
||||
"Delete a note from a list of notes in an article document." |
||||
if "user_notes" in article and note in article["user_notes"]: |
||||
article["user_notes"].remove(note) |
||||
self.user_arango.db.update_document(article, check_rev=False, silent=True) |
||||
sleep(0.1) |
||||
@ -1,27 +1,59 @@ |
||||
|
||||
from typing import Callable, Dict, Any, List |
||||
|
||||
class ToolRegistry: |
||||
_tools = [] |
||||
""" |
||||
A registry for managing and accessing tools (functions). |
||||
|
||||
This class provides methods to register functions as tools and retrieve them by name. |
||||
|
||||
Attributes: |
||||
_tools (Dict[str, Callable]): A dictionary mapping tool names to their corresponding functions. |
||||
|
||||
Methods: |
||||
register(func: Callable) -> Callable: |
||||
Registers a function as a tool. The function's name is used as the key in the registry. |
||||
|
||||
get_tools(tools: List[str] = None) -> List[Callable]: |
||||
Retrieves a list of registered tools. If a list of tool names is provided, only the tools |
||||
with those names are returned. If no list is provided, all registered tools are returned. |
||||
""" |
||||
_tools: Dict[str, Callable] = {} |
||||
|
||||
@classmethod |
||||
def register(cls, name: str, description: str, parameters: Dict[str, Any] = None): |
||||
def decorator(func: Callable): |
||||
cls._tools.append({ |
||||
"type": "function", |
||||
"function": { |
||||
"name": name, |
||||
"description": description, |
||||
"parameters": parameters or {} |
||||
} |
||||
}) |
||||
# No need for the wrapper since we're not adding any extra logic |
||||
return func |
||||
return decorator |
||||
def register(cls, func: Callable): |
||||
""" |
||||
Registers a function as a tool in the class. |
||||
|
||||
This method adds the given function to the class's `_tools` dictionary, |
||||
using the function's name as the key. |
||||
|
||||
Args: |
||||
func (Callable): The function to be registered. |
||||
|
||||
Returns: |
||||
Callable: The same function that was passed in, allowing for decorator usage. |
||||
""" |
||||
cls._tools[func.__name__] = func |
||||
return func |
||||
|
||||
@classmethod |
||||
def get_tools(cls, tools: list = None) -> List[Dict[str, Any]]: |
||||
def get_tools(cls, tools: List[str] = None) -> List[Callable]: |
||||
""" |
||||
Retrieve a list of tool callables. |
||||
|
||||
This method returns a list of tool callables based on the provided tool names. |
||||
If no tool names are provided, it returns all available tool callables. |
||||
|
||||
Args: |
||||
tools (List[str], optional): A list of tool names to retrieve. Defaults to None. |
||||
|
||||
Returns: |
||||
List[Callable]: A list of tool callables. |
||||
""" |
||||
print(tools) |
||||
|
||||
if tools: |
||||
return [tool for tool in cls._tools if tool['function']['name'] in tools] |
||||
print(cls._tools) |
||||
return [cls._tools[name] for name in tools if name in cls._tools] |
||||
else: |
||||
return cls._tools |
||||
return list(cls._tools.values()) |
||||
@ -0,0 +1,725 @@ |
||||
import re |
||||
import os |
||||
import streamlit as st |
||||
from streamlit.runtime.uploaded_file_manager import UploadedFile |
||||
from time import sleep |
||||
from datetime import datetime |
||||
from PIL import Image |
||||
from io import BytesIO |
||||
import base64 |
||||
from article2db import PDFProcessor |
||||
|
||||
from utils import fix_key |
||||
from _arango import ArangoDB |
||||
from _llm import LLM |
||||
from _base_class import StreamlitBaseClass |
||||
from colorprinter.print_color import * |
||||
|
||||
from prompts import get_note_summary_prompt, get_image_system_prompt |
||||
|
||||
import env_manager |
||||
|
||||
env_manager.set_env() |
||||
print_green("Environment variables set.") |
||||
|
||||
|
||||
class ProjectsPage(StreamlitBaseClass): |
||||
def __init__(self, username: str): |
||||
super().__init__(username=username) |
||||
self.projects = [] |
||||
self.selected_project_name = None |
||||
self.project = None |
||||
self.page_name = "Projects" |
||||
|
||||
# Initialize attributes from session state if available |
||||
page_state = st.session_state.get(self.page_name, {}) |
||||
for k, v in page_state.items(): |
||||
setattr(self, k, v) |
||||
|
||||
def run(self): |
||||
self.update_current_page(self.page_name) |
||||
self.load_projects() |
||||
self.display_projects() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def load_projects(self): |
||||
projects_cursor = self.user_arango.db.aql.execute( |
||||
"FOR doc IN projects RETURN doc", count=True |
||||
) |
||||
self.projects = list(projects_cursor) |
||||
|
||||
def display_projects(self): |
||||
with st.sidebar: |
||||
self.new_project_button() |
||||
self.selected_project_name = st.selectbox( |
||||
"Select a project to manage", |
||||
options=[proj["name"] for proj in self.projects], |
||||
) |
||||
if self.selected_project_name: |
||||
self.project = Project( |
||||
username=self.username, |
||||
project_name=self.selected_project_name, |
||||
user_arango=self.user_arango, |
||||
) |
||||
self.manage_project() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def new_project_button(self): |
||||
st.session_state.setdefault("new_project", False) |
||||
with st.sidebar: |
||||
if st.button("New project", type="primary"): |
||||
st.session_state["new_project"] = True |
||||
if st.session_state["new_project"]: |
||||
self.create_new_project() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def create_new_project(self): |
||||
new_project_name = st.text_input("Enter the name of the new project") |
||||
new_project_description = st.text_area( |
||||
"Enter the description of the new project" |
||||
) |
||||
if st.button("Create Project"): |
||||
if new_project_name: |
||||
self.user_arango.db.collection("projects").insert( |
||||
{ |
||||
"name": new_project_name, |
||||
"description": new_project_description, |
||||
"collections": [], |
||||
"notes": [], |
||||
"note_keys_hash": hash(""), |
||||
"settings": {}, |
||||
} |
||||
) |
||||
st.success(f'New project "{new_project_name}" created') |
||||
st.session_state["new_project"] = False |
||||
self.update_settings("current_project", new_project_name) |
||||
sleep(1) |
||||
st.rerun() |
||||
|
||||
def show_project_notes(self): |
||||
|
||||
with st.expander("Show summarised notes"): |
||||
st.markdown(self.project.notes_summary) |
||||
|
||||
with st.expander("Show project notes"): |
||||
notes_cursor = self.user_arango.db.aql.execute( |
||||
"FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc", |
||||
bind_vars={"note_ids": self.project.notes}, |
||||
) |
||||
notes = list(notes_cursor) |
||||
if notes: |
||||
for note in notes: |
||||
st.markdown(f'_{note.get("timestamp", "")}_') |
||||
st.markdown(note["text"].replace("\n", " \n")) |
||||
st.button( |
||||
key=f'delete_note_{note["_id"]}', |
||||
label=":red[Delete note]", |
||||
on_click=self.project.delete_note, |
||||
args=(note["_id"],), |
||||
) |
||||
st.write("---") |
||||
else: |
||||
st.write("No notes in this project.") |
||||
|
||||
def show_project_interviews(self): |
||||
with st.expander("Show project interviews"): |
||||
if not self.user_arango.db.has_collection("interviews"): |
||||
self.user_arango.db.create_collection("interviews") |
||||
interviews_cursor = self.user_arango.db.aql.execute( |
||||
"FOR doc IN interviews FILTER doc.project == @project_name RETURN doc", |
||||
bind_vars={"project_name": self.project.name}, |
||||
) |
||||
interviews = list(interviews_cursor) |
||||
if interviews: |
||||
for interview in interviews: |
||||
st.markdown(f'_{interview.get("timestamp", "")}_') |
||||
st.markdown( |
||||
f"**Interviewees:** {', '.join(interview['intervievees'])}" |
||||
) |
||||
st.markdown(f"**Interviewer:** {interview['interviewer']}") |
||||
if len(interview["transcript"].split("\n")) > 6: |
||||
preview = ( |
||||
" \n".join(interview["transcript"].split("\n")[:6]) |
||||
+ " \n(...)" |
||||
) |
||||
else: |
||||
preview = interview["transcript"] |
||||
timestamps = re.findall(r"\[(.*?)\]", preview) |
||||
for ts in timestamps: |
||||
preview = preview.replace(f"[{ts}]", f":grey[{ts}]") |
||||
st.markdown(preview) |
||||
c1, c2 = st.columns(2) |
||||
with c1: |
||||
st.download_button( |
||||
label="Download Transcript", |
||||
key=f"download_transcript_{interview['_key']}", |
||||
data=interview["transcript"], |
||||
file_name=interview["filename"], |
||||
mime="text/vtt", |
||||
) |
||||
with c2: |
||||
st.button( |
||||
key=f'delete_interview_{interview["_key"]}', |
||||
label=":red[Delete interview]", |
||||
on_click=self.project.delete_interview, |
||||
args=(interview["_key"],), |
||||
) |
||||
st.write("---") |
||||
else: |
||||
st.write("No interviews in this project.") |
||||
|
||||
def manage_project(self): |
||||
self.update_settings("current_project", self.selected_project_name) |
||||
# Initialize the Project instance |
||||
self.project = Project( |
||||
self.username, self.selected_project_name, self.user_arango |
||||
) |
||||
st.write(f"## {self.project.name}") |
||||
self.show_project_interviews() |
||||
self.show_project_notes() |
||||
self.relate_collections() |
||||
self.sidebar_actions() |
||||
self.project.update_notes_hash() |
||||
if st.button(f":red[Remove project *{self.project.name}*]"): |
||||
self.user_arango.db.collection("projects").delete_match( |
||||
{"name": self.project.name} |
||||
) |
||||
self.update_settings("current_project", None) |
||||
st.success(f'Project "{self.project.name}" removed') |
||||
st.rerun() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def relate_collections(self): |
||||
collections = [ |
||||
col["name"] |
||||
for col in self.user_arango.db.collection("article_collections").all() |
||||
] |
||||
selected_collections = st.multiselect( |
||||
"Relate existing collections", options=collections |
||||
) |
||||
if st.button("Relate Collections"): |
||||
self.project.add_collections(selected_collections) |
||||
st.success("Collections related to the project") |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
new_collection_name = st.text_input( |
||||
"Enter the name of the new collection to create and relate" |
||||
) |
||||
if st.button("Create and Relate Collection"): |
||||
if new_collection_name: |
||||
self.user_arango.db.collection("article_collections").insert( |
||||
{"name": new_collection_name, "articles": []} |
||||
) |
||||
self.project.add_collection(new_collection_name) |
||||
st.success( |
||||
f'New collection "{new_collection_name}" created and related to the project' |
||||
) |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def sidebar_actions(self): |
||||
self.sidebar_interview() |
||||
self.sidebar_notes() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def sidebar_notes(self): |
||||
with st.sidebar: |
||||
st.markdown(f"### Add new notes to {self.project.name}") |
||||
self.upload_notes_form() |
||||
self.add_text_note() |
||||
self.add_wikipedia_data() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def sidebar_interview(self): |
||||
with st.sidebar: |
||||
st.markdown(f"### Add new interview to {self.project.name}") |
||||
self.upload_interview_form() |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def upload_notes_form(self): |
||||
with st.expander("Upload notes"): |
||||
with st.form("add_notes", clear_on_submit=True): |
||||
files = st.file_uploader( |
||||
"Upload PDF or image", |
||||
type=["png", "jpg", "pdf"], |
||||
accept_multiple_files=True, |
||||
) |
||||
submitted = st.form_submit_button("Upload") |
||||
if submitted: |
||||
self.project.process_uploaded_notes(files) |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def upload_interview_form(self): |
||||
with st.expander("Upload interview"): |
||||
with st.form("add_interview", clear_on_submit=True): |
||||
interview = st.file_uploader("Upload interview audio file") |
||||
interviewees = st.text_input( |
||||
"Enter the names of the interviewees, separated by commas" |
||||
) |
||||
interviewer = st.text_input( |
||||
"Enter the interviewer's name", |
||||
help="If left blank, the current user will be used", |
||||
) |
||||
date_of_interveiw = st.date_input( |
||||
"Date of interview", value=None, format="YYYY-MM-DD" |
||||
) |
||||
submitted = st.form_submit_button("Upload") |
||||
if submitted: |
||||
self.project.add_interview( |
||||
interview, interviewees, interviewer, date_of_interveiw |
||||
) |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def add_text_note(self): |
||||
help_text = "Add notes to the project. Notes can be anything you want to affect how the editor bot replies." |
||||
note_text = st.text_area("Write or paste anything.", help=help_text) |
||||
if st.button("Add Note"): |
||||
self.project.add_note( |
||||
{ |
||||
"text": note_text, |
||||
"timestamp": datetime.now().strftime("%Y-%m-%d %H:%M"), |
||||
} |
||||
) |
||||
st.success("Note added to the project") |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
|
||||
def add_wikipedia_data(self): |
||||
wiki_url = st.text_input( |
||||
"Paste the address to a Wikipedia page to add its summary as a note", |
||||
placeholder="Paste Wikipedia URL", |
||||
) |
||||
if st.button("Add Wikipedia data"): |
||||
with st.spinner("Fetching Wikipedia data..."): |
||||
wiki_data = self.project.get_wikipedia_data(wiki_url) |
||||
if wiki_data: |
||||
self.project.process_wikipedia_data(wiki_data, wiki_url) |
||||
st.success("Wikipedia data added to notes") |
||||
# Update session state |
||||
self.update_session_state(self.page_name) |
||||
st.rerun() |
||||
|
||||
|
||||
class Project(StreamlitBaseClass): |
||||
def __init__(self, username: str, project_name: str, user_arango: ArangoDB): |
||||
super().__init__(username=username) |
||||
self.name = project_name |
||||
self.user_arango = user_arango |
||||
self.description = "" |
||||
self.collections = [] |
||||
self.notes = [] |
||||
self.note_keys_hash = 0 |
||||
self.settings = {} |
||||
self.notes_summary = "" |
||||
|
||||
# Initialize attributes from arango doc if available |
||||
self.load_project() |
||||
|
||||
def load_project(self): |
||||
print_blue("Project name:", self.name) |
||||
project_cursor = self.user_arango.db.aql.execute( |
||||
"FOR doc IN projects FILTER doc.name == @name RETURN doc", |
||||
bind_vars={"name": self.name}, |
||||
) |
||||
project = next(project_cursor, None) |
||||
if not project: |
||||
raise ValueError(f"Project '{self.name}' not found.") |
||||
self._key = project["_key"] |
||||
self.name = project.get("name", "") |
||||
self.description = project.get("description", "") |
||||
self.collections = project.get("collections", []) |
||||
self.notes = project.get("notes", []) |
||||
self.note_keys_hash = project.get("note_keys_hash", 0) |
||||
self.settings = project.get("settings", {}) |
||||
self.notes_summary = project.get("notes_summary", "") |
||||
|
||||
def update_project(self): |
||||
updated_doc = { |
||||
"_key": self._key, |
||||
"name": self.name, |
||||
"description": self.description, |
||||
"collections": self.collections, |
||||
"notes": self.notes, |
||||
"note_keys_hash": self.note_keys_hash, |
||||
"settings": self.settings, |
||||
"notes_summary": self.notes_summary, |
||||
} |
||||
self.user_arango.db.collection("projects").update(updated_doc, check_rev=False) |
||||
self.update_session_state() |
||||
|
||||
def add_collections(self, collections): |
||||
self.collections.extend(collections) |
||||
self.update_project() |
||||
|
||||
def add_collection(self, collection_name): |
||||
self.collections.append(collection_name) |
||||
self.update_project() |
||||
|
||||
def add_note(self, note: dict): |
||||
assert note["text"], "Note text cannot be empty" |
||||
note["text"] = note["text"].strip().strip("\n") |
||||
if "timestamp" not in note: |
||||
note["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M") |
||||
note_doc = self.user_arango.db.collection("notes").insert(note) |
||||
if note_doc["_id"] not in self.notes: |
||||
self.notes.append(note_doc["_id"]) |
||||
self.update_project() |
||||
|
||||
def add_interview( |
||||
self, |
||||
interview: UploadedFile, |
||||
intervievees: str, |
||||
interviewer: str, |
||||
date_of_interveiw: datetime.date = None, |
||||
): |
||||
# TODO Implement this method |
||||
# Check if interview is a sound (WAV, Mp3, AAC, etc) file or a text file (PDF, DOCX, TXT, etc) |
||||
if interview.type in ["audio/x-wav", "audio/mpeg"]: |
||||
transcription = self.transcribe(interview) |
||||
transcription_preview = ( |
||||
" \n".join(transcription.split("\n")[:4]) + " \n(...)" |
||||
) |
||||
st.markdown(transcription_preview) |
||||
transcription_filename = os.path.splitext(interview.name)[0] + ".vtt" |
||||
c1, c2 = st.columns(2) |
||||
with c1: |
||||
st.button( |
||||
"Add to project", |
||||
on_click=self.add_interview_transcript, |
||||
args=( |
||||
transcription, |
||||
transcription_filename, |
||||
intervievees, |
||||
interviewer, |
||||
date_of_interveiw, |
||||
), |
||||
) |
||||
with c2: |
||||
st.download_button( |
||||
label="Download Transcription", |
||||
data=transcription, |
||||
file_name=transcription_filename, |
||||
mime="text/vtt", |
||||
) |
||||
elif interview.type in ["application/pdf"]: |
||||
PDFProcessor( |
||||
pdf_file=interview, |
||||
is_sci=False, |
||||
document_type="interview", |
||||
is_image=False, |
||||
) |
||||
elif interview.type in ["plain/text"]: |
||||
# TODO Implement text file processing |
||||
pass |
||||
|
||||
def add_interview_transcript( |
||||
self, |
||||
transcript, |
||||
filename, |
||||
intervievees: str = None, |
||||
interviewer: str = None, |
||||
date_of_interveiw: datetime.date = None, |
||||
): |
||||
print_yellow(transcript) |
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M") |
||||
_key = fix_key(f"{filename}_{timestamp}") |
||||
if intervievees: |
||||
intervievees = [ |
||||
i.strip() for i in intervievees.split(",") if len(i.strip()) > 0 |
||||
] |
||||
if not interviewer: |
||||
interviewer = self.username |
||||
if not self.user_arango.db.has_collection("interviews"): |
||||
self.user_arango.db.create_collection("interviews") |
||||
if date_of_interveiw: |
||||
date_of_interveiw = datetime.strptime(date_of_interveiw, "%Y-%m-%d") |
||||
|
||||
from article2db import Document |
||||
|
||||
document = Document( |
||||
text=transcript, |
||||
is_sci=False, |
||||
_key=_key, |
||||
filename=filename, |
||||
arango_db_name=self.username, |
||||
username=self.username, |
||||
arango_collection="interviews", |
||||
) |
||||
print_rainbow(document.__dict__) |
||||
print(document.text) |
||||
document.make_chunks(len_chunks=600) |
||||
|
||||
self.user_arango.db.collection("interviews").insert( |
||||
{ |
||||
"_key": _key, |
||||
"transcript": transcript, |
||||
"project": self.name, |
||||
"filename": filename, |
||||
"timestamp": timestamp, |
||||
"intervievees": intervievees, |
||||
"interviewer": interviewer, |
||||
"date_of_interveiw": date_of_interveiw, |
||||
"chunks": document.chunks, |
||||
}, |
||||
overwrite=True, |
||||
silent=True, |
||||
) |
||||
|
||||
document.make_summary_in_background() |
||||
|
||||
def transcribe(self, uploaded_file: UploadedFile): |
||||
from pydub import AudioSegment |
||||
import requests |
||||
import io |
||||
|
||||
file_extension = os.path.splitext(uploaded_file.name)[1].lower() |
||||
filename = uploaded_file.name |
||||
input_file_buffer = io.BytesIO(uploaded_file.getvalue()) |
||||
|
||||
progress_bar = st.progress(0) |
||||
status_text = st.empty() |
||||
|
||||
if file_extension in [".m4a", ".mp3", ".wav", ".flac"]: |
||||
# Handle audio files |
||||
audio = AudioSegment.from_file( |
||||
input_file_buffer, format=file_extension.replace(".", "") |
||||
) |
||||
audio = audio.set_channels(1) # Convert to mono |
||||
file_buffer = io.BytesIO() |
||||
audio.export(file_buffer, format="mp3", bitrate="64k") |
||||
file_buffer.seek(0) |
||||
progress_bar.progress(50) |
||||
status_text.text("Audio file converted.") |
||||
else: |
||||
st.error("Unsupported file type") |
||||
st.stop() |
||||
|
||||
# Send the converted audio data to the transcription service |
||||
try: |
||||
try: |
||||
url = os.getenv("TRANSCRIBE_URL") |
||||
except: |
||||
import dotenv |
||||
|
||||
dotenv.load_dotenv() |
||||
url = os.getenv("TRANSCRIBE_URL") |
||||
|
||||
# Prepare the files dictionary for the POST request |
||||
files = {"file": (filename, file_buffer, "audio/mp3")} |
||||
# Send the POST request with the file buffer |
||||
response = requests.post(url, files=files, timeout=3600) |
||||
|
||||
response_json = response.json() |
||||
progress_bar.progress(100) |
||||
status_text.text("File uploaded and processed.") |
||||
|
||||
if response.status_code == 200: |
||||
transcription_content = response_json.get("transcription", "") |
||||
transcription_content = self.format_transcription(transcription_content) |
||||
return transcription_content |
||||
else: |
||||
st.error("Failed to upload and process the file.") |
||||
except requests.exceptions.Timeout: |
||||
st.error("The request timed out. Please try again later.") |
||||
|
||||
def format_transcription(self, transcription: str): |
||||
lines = transcription.split("\n") |
||||
transcript = [] |
||||
timestamp = None |
||||
for line in lines: |
||||
if "-->" in line: |
||||
timestamp = line[: line.find(".")] |
||||
elif timestamp: |
||||
line = f"[{timestamp}] {line}" |
||||
transcript.append(line) |
||||
timestamp = None |
||||
return "\n".join(transcript) |
||||
|
||||
def delete_note(self, note_id): |
||||
if note_id in self.notes: |
||||
self.notes.remove(note_id) |
||||
self.update_project() |
||||
|
||||
def delete_interview(self, interview_id): |
||||
self.user_arango.db.collection("interviews").delete_match( |
||||
{"_key": interview_id} |
||||
) |
||||
|
||||
def update_notes_hash(self): |
||||
current_hash = self.make_project_notes_hash() |
||||
if current_hash != self.note_keys_hash: |
||||
self.note_keys_hash = current_hash |
||||
with st.spinner("Summarizing notes for chatbot..."): |
||||
self.create_notes_summary() |
||||
self.update_project() |
||||
|
||||
def make_project_notes_hash(self): |
||||
if not self.notes: |
||||
return hash("") |
||||
note_keys_str = "".join(self.notes) |
||||
return hash(note_keys_str) |
||||
|
||||
def create_notes_summary(self): |
||||
notes_cursor = self.user_arango.db.aql.execute( |
||||
"FOR doc IN notes FILTER doc._id IN @note_ids RETURN doc.text", |
||||
bind_vars={"note_ids": self.notes}, |
||||
) |
||||
notes = list(notes_cursor) |
||||
notes_string = "\n---\n".join(notes) |
||||
llm = LLM(model="small") |
||||
query = get_note_summary_prompt(self, notes_string) |
||||
summary = llm.generate(query) |
||||
print_purple("New summary of notes:", summary) |
||||
self.notes_summary = summary |
||||
self.update_session_state() |
||||
|
||||
def analyze_image(self, image_base64, text=None): |
||||
project_data = {"name": self.name} |
||||
llm = LLM(system_message=get_image_system_prompt(self)) |
||||
prompt = ( |
||||
f'Analyze the image. The text found in it read: "{text}"' |
||||
if text |
||||
else "Analyze the image." |
||||
) |
||||
print_blue(type(image_base64)) |
||||
description = llm.generate(query=prompt, images=[image_base64], stream=False) |
||||
print_green("Image description:", description) |
||||
|
||||
def process_uploaded_notes(self, files): |
||||
with st.spinner("Processing files..."): |
||||
for file in files: |
||||
st.write("Processing...") |
||||
filename = fix_key(file.name) |
||||
|
||||
image_file = self.file2img(file) |
||||
pdf_file = self.convert_image_to_pdf(image_file) |
||||
pdf = PDFProcessor( |
||||
pdf_file=pdf_file, |
||||
is_sci=False, |
||||
document_type="notes", |
||||
is_image=True, |
||||
process=False, |
||||
) |
||||
text = pdf.process_document() |
||||
base64_str = base64.b64encode(file.read()) |
||||
image_caption = self.analyze_image(base64_str, text=text) |
||||
self.add_note( |
||||
{ |
||||
"_id": f"notes/{filename}", |
||||
"text": f"## Image caption: \n{image_caption} \n#### Text extracted from image: \n{text}", |
||||
} |
||||
) |
||||
st.success("Done!") |
||||
sleep(1.5) |
||||
self.update_session_state() |
||||
st.rerun() |
||||
|
||||
def file2img(self, file): |
||||
img_bytes = file.read() |
||||
if not img_bytes: |
||||
raise ValueError("Uploaded file is empty.") |
||||
return Image.open(BytesIO(img_bytes)) |
||||
|
||||
def convert_image_to_pdf(self, img): |
||||
import pytesseract |
||||
|
||||
pdf_bytes = pytesseract.image_to_pdf_or_hocr(img) |
||||
pdf_file = BytesIO(pdf_bytes) |
||||
pdf_file.name = ( |
||||
"converted_image_" + datetime.now().strftime("%Y%m%d%H%M%S") + ".pdf" |
||||
) |
||||
return pdf_file |
||||
|
||||
def get_wikipedia_data(self, page_url: str) -> dict: |
||||
import wikipedia |
||||
from urllib.parse import urlparse |
||||
|
||||
parsed_url = urlparse(page_url) |
||||
page_name_match = re.search(r"(?<=/wiki/)[^?#]*", parsed_url.path) |
||||
if page_name_match: |
||||
page_name = page_name_match.group(0) |
||||
else: |
||||
st.warning("Invalid Wikipedia URL") |
||||
return None |
||||
|
||||
try: |
||||
page = wikipedia.page(page_name) |
||||
data = { |
||||
"title": page.title, |
||||
"summary": page.summary, |
||||
"content": page.content, |
||||
"url": page.url, |
||||
"references": page.references, |
||||
} |
||||
return data |
||||
except Exception as e: |
||||
st.error(f"Error fetching Wikipedia data: {e}") |
||||
return None |
||||
|
||||
def process_wikipedia_data(self, wiki_data, wiki_url): |
||||
llm = LLM( |
||||
system_message="You are an assistant summarisen wikipedia data. Answer ONLY with the summary, nothing else!", |
||||
model="small", |
||||
) |
||||
if wiki_data.get("summary"): |
||||
query = f'''Summarize the text below. It's from a Wikipedia page about {wiki_data["title"]}. \n\n"""{wiki_data['summary']}"""\nMake a detailed and concise summary of the text.''' |
||||
summary = llm.generate(query) |
||||
wiki_data["text"] = ( |
||||
f"(_Summarised using AI, read original [here]({wiki_url})_)\n{summary}" |
||||
) |
||||
wiki_data.pop("summary", None) |
||||
wiki_data.pop("content", None) |
||||
self.user_arango.db.collection("notes").insert( |
||||
wiki_data, overwrite=True, silent=True |
||||
) |
||||
self.add_note(wiki_data) |
||||
|
||||
processor = PDFProcessor(process=False) |
||||
dois = [ |
||||
processor.extract_doi(ref) |
||||
for ref in wiki_data.get("references", []) |
||||
if processor.extract_doi(ref) |
||||
] |
||||
if dois: |
||||
current_collection = st.session_state["settings"].get("current_collection") |
||||
st.markdown( |
||||
f"Found {len(dois)} references with DOI numbers. Do you want to add them to {current_collection}?" |
||||
) |
||||
if st.button("Add DOIs"): |
||||
self.process_dois(current_collection, dois=dois) |
||||
self.update_session_state() |
||||
|
||||
def process_dois( |
||||
self, article_collection_name: str, text: str = None, dois: list = None |
||||
) -> None: |
||||
processor = PDFProcessor(process=False) |
||||
if not dois and text: |
||||
dois = processor.extract_doi(text, multi=True) |
||||
if "not_downloaded" not in st.session_state: |
||||
st.session_state["not_downloaded"] = {} |
||||
for doi in dois: |
||||
downloaded, url, path, in_db = processor.doi2pdf(doi) |
||||
if downloaded and not in_db: |
||||
processor.process_pdf(path) |
||||
in_db = True |
||||
elif not downloaded and not in_db: |
||||
st.session_state["not_downloaded"][doi] = url |
||||
|
||||
if in_db: |
||||
st.success(f"Article with DOI {doi} added") |
||||
self.articles2collection( |
||||
collection=article_collection_name, |
||||
db="base", |
||||
_id=f"sci_articles/{fix_key(doi)}", |
||||
) |
||||
self.update_session_state() |
||||
File diff suppressed because it is too large
Load Diff
@ -1,32 +1,38 @@ |
||||
import os |
||||
import base64 |
||||
from ollama import Client |
||||
from ollama import Client, ChatResponse |
||||
import env_manager |
||||
from colorprinter.print_color import * |
||||
import httpx |
||||
|
||||
env_manager.set_env() |
||||
|
||||
# Encode the credentials |
||||
credentials = f"{os.getenv('LLM_API_USER')}:{os.getenv('LLM_API_PWD_LASSE')}" |
||||
encoded_credentials = base64.b64encode(credentials.encode()).decode() |
||||
|
||||
# Set up the headers with authentication details |
||||
headers = { |
||||
'Authorization': f'Basic {encoded_credentials}' |
||||
} |
||||
|
||||
# Get the host URL (base URL only) |
||||
host_url = os.getenv("LLM_API_URL").rstrip('/api/chat/') |
||||
|
||||
|
||||
# Initialize the client with the host and headers |
||||
auth = httpx.BasicAuth( |
||||
username='lasse', password=os.getenv("LLM_API_PWD_LASSE") |
||||
) |
||||
client = httpx.Client(auth=auth) |
||||
client = Client( |
||||
host=host_url, |
||||
headers=headers |
||||
host="http://localhost:11434", |
||||
headers={ |
||||
"X-Chosen-Backend": "backend_ollama" # Add this header to specify the chosen backend |
||||
}, |
||||
auth=auth |
||||
) |
||||
response = client.chat( |
||||
model=os.getenv("LLM_MODEL"), |
||||
messages=[ |
||||
{ |
||||
"role": "user", |
||||
"content": "Why is the sky blue?", |
||||
}, |
||||
], |
||||
) |
||||
|
||||
# Example usage of the client |
||||
try: |
||||
response = client.chat(model=os.getenv('LLM_MODEL') , messages=[{'role': 'user', 'content': 'Why is the sky blue?'}]) |
||||
print_rainbow(response) |
||||
except Exception as e: |
||||
print(f"Error: {e}") |
||||
# Print the response headers |
||||
|
||||
# Print the chosen backend from the headers |
||||
print("Chosen Backend:", response.headers.get("X-Chosen-Backend")) |
||||
|
||||
# Print the response content |
||||
print(response) |
||||
@ -0,0 +1,9 @@ |
||||
from _llm import LLM |
||||
|
||||
llm = LLM() |
||||
|
||||
image = '/home/lasse/sci/test_image.png' |
||||
image_bytes = open(image, 'rb').read() |
||||
print(type(image_bytes)) |
||||
response = llm.generate('What is this?', images=[image_bytes]) |
||||
print(response) |
||||
@ -0,0 +1,59 @@ |
||||
import io |
||||
import os |
||||
import requests |
||||
from pydub import AudioSegment |
||||
import streamlit as st |
||||
|
||||
def streamlit_audio(uploaded_file): |
||||
if uploaded_file is not None: |
||||
# Read the uploaded file into a BytesIO buffer |
||||
file_extension = os.path.splitext(uploaded_file.name)[1].lower() |
||||
filename = uploaded_file.name |
||||
input_file_buffer = io.BytesIO(uploaded_file.getvalue()) |
||||
|
||||
progress_bar = st.progress(0) |
||||
status_text = st.empty() |
||||
|
||||
if file_extension in ['.m4a', '.mp3', '.wav', '.flac']: |
||||
# Handle audio files |
||||
audio = AudioSegment.from_file(input_file_buffer, format=file_extension.replace('.', '')) |
||||
audio = audio.set_channels(1) # Convert to mono |
||||
file_buffer = io.BytesIO() |
||||
audio.export(file_buffer, format="mp3", bitrate="64k") |
||||
file_buffer.seek(0) |
||||
progress_bar.progress(50) |
||||
status_text.text("Audio file converted.") |
||||
else: |
||||
st.error("Unsupported file type") |
||||
st.stop() |
||||
|
||||
# Send the converted audio data to the transcription service |
||||
try: |
||||
response = transcribe(file_buffer, filename) |
||||
response_json = response.json() |
||||
progress_bar.progress(100) |
||||
status_text.text("File uploaded and processed.") |
||||
|
||||
if response.status_code == 200: |
||||
transcription_content = response_json.get("transcription", "") |
||||
st.subheader("Transcription") |
||||
st.text_area("Transcription Content", transcription_content, height=300) |
||||
transcription_filename = os.path.splitext(filename)[0] + '.vtt' |
||||
st.download_button( |
||||
label="Download Transcription", |
||||
data=transcription_content, |
||||
file_name=transcription_filename, |
||||
mime='text/vtt' |
||||
) |
||||
else: |
||||
st.error("Failed to upload and process the file.") |
||||
except requests.exceptions.Timeout: |
||||
st.error("The request timed out. Please try again later.") |
||||
|
||||
def transcribe(file_buffer, filename): |
||||
url = "http://98.128.172.165:4001/upload" |
||||
# Prepare the files dictionary for the POST request |
||||
files = {'file': (filename, file_buffer, 'audio/mp3')} |
||||
# Send the POST request with the file buffer |
||||
response = requests.post(url, files=files, timeout=3600) |
||||
return response |
||||
Loading…
Reference in new issue