You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

335 lines
14 KiB

from _llm import LLM
import os
import re
import random
from typing import List, Dict
from atproto import models
from types import SimpleNamespace
import requests
from atproto import (
CAR,
AtUri,
Client,
FirehoseSubscribeReposClient,
firehose_models,
models,
parse_subscribe_repos_message,
models,
)
from colorprinter.print_color import *
from datetime import datetime
from semantic_text_splitter import MarkdownSplitter
from env_manager import set_env
set_env()
class Chat:
    """State of one conversation between the bot and the user who pinged it.

    Attributes are plain, publicly mutable containers:

    * ``bot_username`` / ``poster_username`` -- the two handles taking part.
    * ``messages`` -- role/content dicts later handed to the LLM.
    * ``thread_posts`` -- raw posts gathered from the Bluesky thread.
    """

    def __init__(self, bot_username, poster_username):
        # Participants of the conversation.
        self.bot_username, self.poster_username = bot_username, poster_username
        # Both collections start empty and are filled in by the bot.
        self.messages, self.thread_posts = [], []
class Bot:
    """Bluesky bot that answers posts mentioning it with an LLM reply.

    Constructing a ``Bot`` logs in to Bluesky and starts a firehose
    subscription. Every newly created post whose text contains the bot's
    username is answered in the same thread: as one post when the reply
    fits in ``max_length_answer`` characters, otherwise as a numbered
    thread produced by a second, formatting-only LLM.
    """

    # A mention is "@" followed by handle characters. It must not be glued
    # to a preceding word character (so "foo@bar.com" is not a mention) and
    # must end on a letter or digit (so sentence punctuation is excluded).
    _MENTION_RE = re.compile(rb"(?<![\w.-])(@[a-zA-Z0-9._-]*[a-zA-Z0-9])")

    def __init__(self):
        """Log in, create both LLM helpers and start listening.

        Reads ``BLUESKY_USERNAME`` and ``BLUESKY_PASSWORD`` from the
        environment.
        """
        self.username = os.getenv("BLUESKY_USERNAME")
        # Character budget for a single post; also used in the thread prompt.
        self.max_length_answer = 280
        system_message = '''
You are a research assistant bot chatting with a user on Bluesky, a social media platform similar to Twitter.
Your speciality is electric cars, and you will use facts in articles to answer the questions
Use ONLY the information in the articles to answer the questions. Do not add any additional information or speculation.
IF you don't know the answer, you can say "I don't know" or "I'm not sure". You can also ask the user to specify the question.
Your answers should be concise and NOT EXCEED 1000 CHARACTERS to fit the character limit on Bluesky.
Always answer in English.
'''
        self.llm: LLM = LLM(system_message=system_message)
        post_maker_system_message = f'''
You will get a text and you have to format it for Bluesky, a plaform similar to Twitter.
You should format the text in a thread of posts with a maximum of {self.max_length_answer} characters per post.
It's VERY important to keep the text as close as possible to the original text.
Format the thread without any additional formatation, just plain text.
Add "---" to separate the posts. Don't add a counter of any type, that will be added automatically.
'''
        self.post_maker = LLM(system_message=post_maker_system_message, model="small")
        # Client used for every authenticated call (profiles, threads, posts).
        self.client = Client()
        self.client.login(self.username, os.getenv("BLUESKY_PASSWORD"))
        # Conversation currently being answered; replaced on every mention.
        self.chat = None
        # Host used to resolve handles to DIDs when building facets.
        self.pds_url = 'https://bsky.social'
        print("🐟 Bot is listening")
        # Subscribe to repository events; each message goes to the handler.
        self.firehose = FirehoseSubscribeReposClient()
        self.firehose.start(self.on_message_handler)

    def answer_message(self, message):
        """Generate an answer to ``message`` and publish it as a new post."""
        response = self.llm.generate(message).content
        self.client.send_post(response)

    def get_file_extension(self, file_path):
        """Return the extension of ``file_path`` including the leading dot."""
        return os.path.splitext(file_path)[1]

    def bot_mentioned(self, text: str) -> bool:
        """Return True when ``text`` contains the bot username (any case)."""
        return self.username.lower() in (text or "").lower()

    def parse_thread(self, author_did, thread):
        """Collect the texts of ``author_did``'s posts from a post upwards.

        Walks from ``thread`` through each ``parent`` link and returns the
        matching post texts in that order (deepest post first). Nodes that
        carry no ``post`` attribute are skipped.
        """
        entries = []
        current = thread
        while current is not None:
            post = getattr(current, "post", None)
            if post is not None and post.author.did == author_did:
                print(post.record.text)
                entries.append(post.record.text)
            current = getattr(current, "parent", None)
        return entries

    def process_operation(
        self,
        op: "models.ComAtprotoSyncSubscribeRepos.RepoOp",
        car: "CAR",
        commit: "models.ComAtprotoSyncSubscribeRepos.Commit",
    ) -> None:
        """Handle one repository operation from a firehose commit.

        Only ``create`` operations on feed posts are acted upon; ``update``
        and ``delete`` operations are ignored. When the created post
        mentions the bot, a reply is generated and posted.

        Args:
            op: The operation (action, path and content ID).
            car: Decoded CAR archive holding the commit's record blocks.
            commit: The commit the operation belongs to.
        """
        if op.action != "create" or not op.cid:
            return
        uri = AtUri.from_str(f"at://{commit.repo}/{op.path}")
        # Look the record up in the CAR archive by its content ID.
        raw_record = car.blocks.get(op.cid)
        if not raw_record:
            return
        if uri.collection != models.ids.AppBskyFeedPost:
            return
        # Record enriched with the metadata needed to reply to it.
        record = {
            "uri": str(uri),
            "cid": str(op.cid),
            "author": commit.repo,
            **raw_record,
        }
        if self.bot_mentioned(record.get("text") or ""):
            self._reply_to_mention(record)

    def _reply_to_mention(self, record: dict) -> None:
        """Build the conversation for ``record``, ask the LLM and reply."""
        poster_username = self.client.get_profile(actor=record["author"]).handle
        if poster_username.lower() == (self.username or "").lower():
            # Never answer our own posts; that could loop forever.
            return
        self.chat = Chat(self.username, poster_username)
        posts_in_thread = self.client.get_post_thread(uri=record["uri"])
        self.traverse_thread(posts_in_thread.thread)
        self.chat.thread_posts.sort(key=lambda x: x["timestamp"])
        self.make_llm_messages()
        print_purple(self.chat.messages)
        answer = self.llm.generate(messages=self.chat.messages)
        print_green(answer.content)
        print_purple('Length of answer:', len(answer.content))

        parent = models.create_strong_ref(
            SimpleNamespace(cid=record["cid"], uri=record["uri"])
        )
        root_post = self._root_ref(record, parent)
        reply_ref = models.AppBskyFeedPost.ReplyRef(parent=parent, root=root_post)

        mention_handle = f"@{poster_username}"
        text = f"{mention_handle} {answer.content}"
        # The mention prefix is part of what gets posted, so the budget is
        # checked against the complete text rather than the bare answer.
        if len(text) <= self.max_length_answer:
            facets = self.parse_facets(text)
            print('Handle:', mention_handle)
            print(f"Facets")
            print(facets)
            self.client.send_post(
                text=text, facets=facets, reply_to=reply_ref, langs=["en-US"]
            )
        else:
            self._reply_as_thread(record, answer.content, reply_ref, root_post)

    def _root_ref(self, record: dict, parent):
        """Return the strong ref of the thread root for a reply to ``record``.

        A post that is itself a reply names its root in ``record["reply"]``.
        A top-level post has no such entry, in which case the post itself
        (``parent``) is the root.
        """
        root = (record.get("reply") or {}).get("root")
        if not root:
            return parent
        return models.create_strong_ref(
            SimpleNamespace(cid=root["cid"], uri=root["uri"])
        )

    def _save_answer(self, cid: str, content: str) -> str:
        """Write the unabridged answer to an HTML file and return its path."""
        filename_answer = f'bluesky_bot_answers/{cid}_{random.randint(a=10000, b=99000)}.html'
        # The directory may not exist yet on a fresh deployment.
        os.makedirs(os.path.dirname(filename_answer), exist_ok=True)
        with open(filename_answer, 'w', encoding="utf-8") as f:
            f.write(content)
        return filename_answer

    def _reply_as_thread(self, record: dict, content: str, reply_ref, root_post) -> None:
        """Post a long answer as a numbered thread plus a link to the original.

        The full answer is saved to disk, reformatted into posts by the
        post-maker LLM (posts separated by ``---``) and sent as a chain in
        which every post replies to the previous one.
        """
        filename_answer = self._save_answer(record["cid"], content)
        formated_answer = self.post_maker.generate(query=f"Format the text below as a thread of posts with a maximum of {self.max_length_answer} characters per post.\n\n{content}")
        # Drop blank pieces (e.g. from a leading or trailing separator) so
        # no empty post is sent and the counter stays correct.
        chunks = [
            piece.strip()
            for piece in formated_answer.content.split('---')
            if piece.strip()
        ]
        print_yellow('Formated answer')
        print_rainbow(chunks)
        for n, chunk in enumerate(chunks, start=1):
            text = f"{chunk}\n({n}/{len(chunks)})"
            facets = self.parse_facets(text)
            sent_answer = self.client.send_post(
                text=text, facets=facets, reply_to=reply_ref, langs=["en-US"]
            )
            # The next post replies to the one just sent.
            parent = models.create_strong_ref(sent_answer)
            reply_ref = models.AppBskyFeedPost.ReplyRef(parent=parent, root=root_post)
        text = f'The answers above are a blueskyified version of the original answer. The original answer can be found at https://sci.assistant.fish/answers/{filename_answer}'
        self.client.send_post(text=text, reply_to=reply_ref, langs=["en-US"])

    def traverse_thread(self, thread_view_post):
        """Append every post reachable from ``thread_view_post`` to the chat.

        Records the node's author handle, its text (newlines replaced by
        spaces) and its ``indexed_at`` time as a Unix timestamp, then
        recurses into the parent and the replies. Nodes without a ``post``
        attribute (placeholders for unavailable posts) are skipped.
        """
        post = getattr(thread_view_post, "post", None)
        if post is None:
            return
        timestamp = int(
            datetime.fromisoformat(post.indexed_at.replace("Z", "+00:00")).timestamp()
        )
        self.chat.thread_posts.append(
            {
                "user": post.author.handle,
                "text": post.record.text.replace("\n", " "),
                "timestamp": timestamp,
            }
        )
        parent = getattr(thread_view_post, "parent", None)
        if parent:
            self.traverse_thread(parent)
        for reply in getattr(thread_view_post, "replies", None) or []:
            self.traverse_thread(reply)

    def _strip_bot_mention(self, text: str) -> str:
        """Remove every ``@<bot username>`` (any case) and trim whitespace."""
        pattern = re.escape(f"@{self.chat.bot_username}")
        return re.sub(pattern, "", text, flags=re.IGNORECASE).strip()

    def make_llm_messages(self):
        """Compile ``self.chat.thread_posts`` into ``self.chat.messages``.

        * Posts are ignored until the first one that mentions the bot.
        * Posts by the poster become ``"user"`` messages and posts by the
          bot become ``"assistant"`` messages; posts by anyone else are
          ignored.
        * The conversation always opens with a user message: bot posts
          seen before any user message are dropped.
        * Consecutive posts with the same role are joined with a blank
          line.
        * The bot's ``@mention`` is removed from the message text.

        Returns:
            None. ``self.chat.messages`` is extended in place.
        """
        bot_name = self.chat.bot_username
        messages = self.chat.messages
        started = False
        for entry in self.chat.thread_posts:
            if not started:
                if bot_name.lower() not in entry["text"].lower():
                    continue
                started = True

            if entry["user"] == bot_name:
                if not messages:
                    continue
                role = "assistant"
            elif entry["user"] == self.chat.poster_username:
                role = "user"
            else:
                continue

            entry["text"] = self._strip_bot_mention(entry["text"])
            if messages and messages[-1]["role"] == role:
                messages[-1]["content"] += f"\n\n{entry['text']}"
            else:
                messages.append({"role": role, "content": entry["text"]})

    def on_message_handler(self, message: "firehose_models.MessageFrame") -> None:
        """Firehose callback: dispatch every operation of a commit message.

        Messages that are not commits, or whose ``blocks`` are not bytes,
        are ignored.
        """
        commit = parse_subscribe_repos_message(message)
        if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit):
            return
        if not isinstance(commit.blocks, bytes):
            return
        # The CAR archive holds the data blocks the operations refer to.
        car = CAR.from_bytes(commit.blocks)
        for op in commit.ops:
            self.process_operation(op, car, commit)

    def parse_facets(self, text: str) -> List[Dict]:
        """Build mention facets for ``text``, resolving handles to DIDs.

        A handle that cannot be resolved (request failure, error status,
        malformed response) is skipped, so it is rendered as plain text
        instead of a link.

        Returns:
            A list of facet dicts with UTF-8 byte offsets.
        """
        facets = []
        for m in self.parse_mentions(text):
            try:
                resp = requests.get(
                    self.pds_url + "/xrpc/com.atproto.identity.resolveHandle",
                    params={"handle": m["handle"]},
                    timeout=10,
                )
                if not resp.ok:
                    continue
                payload = resp.json()
            except (requests.RequestException, ValueError):
                continue
            did = payload.get("did") if isinstance(payload, dict) else None
            if not did:
                continue
            facets.append({
                "index": {
                    "byteStart": m["start"],
                    "byteEnd": m["end"],
                },
                "features": [{"$type": "app.bsky.richtext.facet#mention", "did": did}],
            })
        return facets

    def parse_mentions(self, text: str) -> List[Dict]:
        """Find ``@handle`` mentions in ``text``.

        Returns:
            One dict per mention with ``start`` / ``end`` (offsets into the
            UTF-8 encoding of ``text``, covering the ``@``) and ``handle``
            (the handle without the ``@``).
        """
        text_bytes = text.encode("UTF-8")
        return [
            {
                "start": m.start(1),
                "end": m.end(1),
                "handle": m.group(1)[1:].decode("UTF-8"),
            }
            for m in self._MENTION_RE.finditer(text_bytes)
        ]
def main() -> None:
    """Entry point: construct the bot.

    ``Bot.__init__`` performs the login and starts the firehose
    subscription, so building the instance is all that is required here.
    """
    Bot()


# Run the bot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()