You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
335 lines
14 KiB
335 lines
14 KiB
from _llm import LLM |
|
import os |
|
import re |
|
import random |
|
from typing import List, Dict |
|
from atproto import models |
|
from types import SimpleNamespace |
|
import requests |
|
from atproto import ( |
|
CAR, |
|
AtUri, |
|
Client, |
|
FirehoseSubscribeReposClient, |
|
firehose_models, |
|
models, |
|
parse_subscribe_repos_message, |
|
models, |
|
) |
|
from colorprinter.print_color import * |
|
from datetime import datetime |
|
|
|
from semantic_text_splitter import MarkdownSplitter |
|
|
|
|
|
# Project-local environment bootstrap, executed at import time.
# NOTE(review): presumably this populates os.environ with the
# BLUESKY_USERNAME / BLUESKY_PASSWORD values that Bot.__init__ reads via
# os.getenv — confirm against env_manager, which is not visible here.
from env_manager import set_env

set_env()
|
|
|
|
|
class Chat:
    """Mutable state for one conversation between the bot and a poster.

    Attributes are plain instance fields:

    * ``bot_username`` / ``poster_username`` -- the two handles taking part.
    * ``thread_posts`` -- raw posts collected from the thread
      (appended to by ``Bot.traverse_thread``).
    * ``messages`` -- role/content dicts built for the LLM
      (appended to by ``Bot.make_llm_messages``).
    """

    def __init__(self, bot_username, poster_username):
        # Every conversation starts with its own empty buffers.
        self.thread_posts: list = []
        self.messages: list = []

        # Who is talking to whom.
        self.poster_username = poster_username
        self.bot_username = bot_username
|
|
|
class Bot:
    """Bluesky bot that answers posts mentioning it, using an LLM.

    The constructor logs in, then subscribes to the repository firehose and
    hands every incoming frame to :meth:`on_message_handler`. When a newly
    created post mentions the bot's handle, the surrounding thread is fetched,
    turned into LLM chat messages, and the generated answer is posted back as
    a reply (or as a thread of replies when it is too long for one post).
    """

    # Maximum length of a single Bluesky post (app.bsky.feed.post allows 300
    # graphemes). ``len()`` counts code points, which is a close approximation.
    POST_CHAR_LIMIT = 300

    # Directory (relative to the working directory) where full-length answers
    # are stored before they are condensed into a thread.
    ANSWERS_DIR = 'bluesky_bot_answers'

    # Seconds to wait for the handle-resolution endpoint before giving up.
    RESOLVE_TIMEOUT = 10

    # Handle-shaped mention ("@" + domain name with at least one dot), anchored
    # at the start of the text or after a non-word byte so that e-mail
    # addresses ("me@example.com") are not treated as mentions, and trailing
    # punctuation ("@a.example.com.") is not swallowed into the handle.
    _MENTION_RE = re.compile(
        rb"(?:^|\W)"
        rb"(@(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+"
        rb"[a-zA-Z](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)"
    )

    def __init__(self):
        """Log in to Bluesky, set up both LLMs and start the firehose.

        Credentials are read from the ``BLUESKY_USERNAME`` and
        ``BLUESKY_PASSWORD`` environment variables.

        NOTE(review): ``self.firehose.start`` presumably blocks while the
        subscription is running, in which case this constructor does not
        return during normal operation -- confirm against the atproto SDK.
        """
        self.username = os.getenv("BLUESKY_USERNAME")
        # Answers longer than this are reformatted into a thread of posts.
        self.max_length_answer = 280
        system_message = '''
        You are a research assistant bot chatting with a user on Bluesky, a social media platform similar to Twitter.
        Your speciality is electric cars, and you will use facts in articles to answer the questions
        Use ONLY the information in the articles to answer the questions. Do not add any additional information or speculation.
        IF you don't know the answer, you can say "I don't know" or "I'm not sure". You can also ask the user to specify the question.
        Your answers should be concise and NOT EXCEED 1000 CHARACTERS to fit the character limit on Bluesky.
        Always answer in English.
        '''
        self.llm: LLM = LLM(system_message=system_message)

        # Second, smaller model whose only job is to cut a long answer into
        # post-sized pieces separated by "---".
        post_maker_system_message = f'''
        You will get a text and you have to format it for Bluesky, a plaform similar to Twitter.
        You should format the text in a thread of posts with a maximum of {self.max_length_answer} characters per post.
        It's VERY important to keep the text as close as possible to the original text.
        Format the thread without any additional formatation, just plain text.
        Add "---" to separate the posts. Don't add a counter of any type, that will be added automatically.
        '''
        self.post_maker = LLM(system_message=post_maker_system_message, model="small")

        # Client used for all authenticated Bluesky API calls.
        self.client = Client()
        self.client.login(self.username, os.getenv("BLUESKY_PASSWORD"))
        self.chat = None
        # Host used to resolve handles to DIDs when building mention facets.
        self.pds_url = 'https://bsky.social'

        print("🐟 Bot is listening")

        # Subscribe to repository events; every frame goes to the handler.
        self.firehose = FirehoseSubscribeReposClient()
        self.firehose.start(self.on_message_handler)

    def answer_message(self, message):
        """Generate an answer to *message* and publish it as a new post."""
        response = self.llm.generate(message).content
        self.client.send_post(response)

    def get_file_extension(self, file_path):
        """Return the extension of *file_path* including the dot ('' if none)."""
        return os.path.splitext(file_path)[1]

    def bot_mentioned(self, text: str) -> bool:
        """Return True when *text* contains the bot's handle (case-insensitive).

        Returns False for empty text or when no username is configured; an
        empty username would otherwise match every post on the network.
        """
        if not text or not self.username:
            return False
        return self.username.lower() in text.lower()

    def parse_thread(self, author_did, thread):
        """Collect the texts of posts written by *author_did*.

        Starts at *thread* and walks up the ``parent`` chain. Nodes without a
        ``post`` attribute (e.g. a deleted or blocked parent) end the walk
        instead of raising.

        Returns:
            The matching post texts, ordered from *thread* upwards.
        """
        entries = []
        current = thread
        while current is not None:
            post = getattr(current, "post", None)
            if post is None:
                break
            if post.author.did == author_did:
                print(post.record.text)
                entries.append(post.record.text)
            current = getattr(current, "parent", None)
        return entries

    def process_operation(
        self,
        op: "models.ComAtprotoSyncSubscribeRepos.RepoOp",
        car: "CAR",
        commit: "models.ComAtprotoSyncSubscribeRepos.Commit",
    ) -> None:
        """Handle one repository operation from a firehose commit.

        Only ``create`` operations on feed posts are acted upon; ``update``
        and ``delete`` are not implemented and are ignored. When the created
        post mentions the bot, a reply is generated and posted.

        Args:
            op: The operation taken from ``commit.ops``.
            car: Decoded CAR archive holding the commit's record blocks.
            commit: The commit the operation belongs to.
        """
        if op.action != "create" or not op.cid:
            return

        uri = AtUri.from_str(f"at://{commit.repo}/{op.path}")
        if uri.collection != models.ids.AppBskyFeedPost:
            return

        # Look the record up in the CAR blocks by its content ID.
        raw_record = car.blocks.get(op.cid)
        if not raw_record:
            return

        # Metadata first so that keys of the raw record keep precedence.
        record = {
            "uri": str(uri),
            "cid": str(op.cid),
            "author": commit.repo,
            **raw_record,
        }

        if self._is_own_post(record):
            # Never answer ourselves: an answer that happens to contain the
            # bot's handle would otherwise trigger an endless reply loop.
            return
        if not self.bot_mentioned(record.get("text", "")):
            return

        self._reply_to_mention(record)

    def _is_own_post(self, record: dict) -> bool:
        """Return True when *record* was authored by the logged-in account."""
        me = getattr(self.client, "me", None)
        own_did = getattr(me, "did", None)
        return own_did is not None and record.get("author") == own_did

    def _reply_to_mention(self, record: dict) -> None:
        """Build the conversation around *record*, ask the LLM, post the answer."""
        poster_username = self.client.get_profile(actor=record["author"]).handle
        self.chat = Chat(self.username, poster_username)

        posts_in_thread = self.client.get_post_thread(uri=record["uri"])
        self.traverse_thread(posts_in_thread.thread)
        self.chat.thread_posts.sort(key=lambda x: x["timestamp"])
        self.make_llm_messages()
        print_purple(self.chat.messages)
        if not self.chat.messages:
            # Nothing usable was found in the thread; do not query the LLM
            # with an empty conversation.
            return

        answer = self.llm.generate(messages=self.chat.messages)
        print_green(answer.content)
        print_purple('Length of answer:', len(answer.content))

        parent, root_post = self._reply_refs(record)

        mention_handle = f"@{poster_username}"
        text = f"{mention_handle} {answer.content}"
        # The mention prefix counts towards the post limit as well, so check
        # the final text and not only the raw answer.
        fits_single_post = (
            len(answer.content) <= self.max_length_answer
            and len(text) <= self.POST_CHAR_LIMIT
        )

        if fits_single_post:
            facets = self.parse_facets(text)
            print('Handle:', mention_handle)
            print(f"Facets")
            print(facets)
            reply_ref = models.AppBskyFeedPost.ReplyRef(parent=parent, root=root_post)
            self.client.send_post(text=text, facets=facets, reply_to=reply_ref, langs=["en-US"])
            return

        self._send_answer_thread(answer.content, record["cid"], parent, root_post)

    def _reply_refs(self, record: dict):
        """Return ``(parent, root)`` strong refs for replying to *record*.

        A post that is itself a reply carries the thread root in
        ``record['reply']['root']``. A top-level post has no ``reply`` entry,
        in which case the post is its own root.
        """
        parent = models.create_strong_ref(
            SimpleNamespace(cid=record["cid"], uri=record["uri"])
        )
        root_data = (record.get("reply") or {}).get("root")
        if not root_data:
            return parent, parent
        root_post = models.create_strong_ref(
            SimpleNamespace(cid=root_data["cid"], uri=root_data["uri"])
        )
        return parent, root_post

    def _send_answer_thread(self, answer_text: str, record_cid: str, parent, root_post) -> None:
        """Store the full answer on disk and post a condensed version as a thread.

        Each post replies to the previous one; a closing post links to the
        stored original answer.
        """
        # Save the full answer under a unique file name.
        os.makedirs(self.ANSWERS_DIR, exist_ok=True)
        filename_answer = f'bluesky_bot_answers/{record_cid}_{random.randint(a=10000, b=99000)}.html'
        with open(filename_answer, 'w', encoding='utf-8') as f:
            f.write(answer_text)

        formated_answer = self.post_maker.generate(query=f"Format the text below as a thread of posts with a maximum of {self.max_length_answer} characters per post.\n\n{answer_text}")

        # A leading/trailing or doubled "---" yields empty pieces; drop them so
        # no post consists of a bare counter and the counter total is right.
        chunks = [piece.strip() for piece in formated_answer.content.split('---')]
        chunks = [piece for piece in chunks if piece]
        print_yellow('Formated answer')
        print_rainbow(chunks)

        total = len(chunks)
        for n, chunk in enumerate(chunks, start=1):
            text = f"{chunk}\n({n}/{total})"
            facets = self.parse_facets(text)
            reply_ref = models.AppBskyFeedPost.ReplyRef(parent=parent, root=root_post)
            sent_answer = self.client.send_post(text=text, facets=facets, reply_to=reply_ref, langs=["en-US"])
            # The next post in the thread replies to the one just sent.
            parent = models.create_strong_ref(sent_answer)

        text = f'The answers above are a blueskyified version of the original answer. The original answer can be found at https://sci.assistant.fish/answers/{filename_answer}'
        reply_ref = models.AppBskyFeedPost.ReplyRef(parent=parent, root=root_post)
        self.client.send_post(text=text, reply_to=reply_ref, langs=["en-US"])

    def traverse_thread(self, thread_view_post):
        """Append the post of *thread_view_post*, its parents and its replies
        to ``self.chat.thread_posts``.

        Each entry is a dict with ``user`` (author handle), ``text`` (newlines
        replaced by spaces) and ``timestamp`` (``indexed_at`` as integer epoch
        seconds). Nodes that carry no ``post`` attribute (e.g. a deleted or
        blocked post) are skipped.
        """
        post = getattr(thread_view_post, "post", None)
        if post is None:
            return

        indexed_at = datetime.fromisoformat(post.indexed_at.replace("Z", "+00:00"))
        self.chat.thread_posts.append(
            {
                "user": post.author.handle,
                "text": post.record.text.replace("\n", " "),
                "timestamp": int(indexed_at.timestamp()),
            }
        )

        parent = getattr(thread_view_post, "parent", None)
        if parent:
            self.traverse_thread(parent)

        for reply in getattr(thread_view_post, "replies", None) or []:
            self.traverse_thread(reply)

    def make_llm_messages(self):
        """Compile ``self.chat.thread_posts`` into ``self.chat.messages``.

        Rules:
          * Posts are skipped until the first one that mentions the bot
            (case-insensitive, consistent with :meth:`bot_mentioned`).
          * Posts by the bot become ``assistant`` messages, posts by the
            poster become ``user`` messages; posts by anyone else are ignored.
          * The conversation must open with a ``user`` message, so bot posts
            seen before any user message are dropped.
          * Consecutive posts with the same role are merged into one message,
            separated by a blank line.
          * The ``@<bot handle>`` mention is stripped from the text (this
            also updates the ``text`` of the entry in ``thread_posts``).

        Returns:
            None
        """
        bot_handle = self.chat.bot_username
        bot_handle_lower = bot_handle.lower()
        poster_handle = self.chat.poster_username
        mention_pattern = re.compile(re.escape(f"@{bot_handle}"), re.IGNORECASE)
        messages = self.chat.messages

        started = False
        for post in self.chat.thread_posts:
            if not started:
                if bot_handle_lower not in post["text"].lower():
                    continue
                started = True

            if post["user"].lower() == bot_handle_lower:
                role = "assistant"
            elif post["user"] == poster_handle:
                role = "user"
            else:
                continue

            if role == "assistant" and not messages:
                continue

            post["text"] = mention_pattern.sub("", post["text"]).strip()
            if messages and messages[-1]["role"] == role:
                messages[-1]["content"] += f"\n\n{post['text']}"
            else:
                messages.append({"role": role, "content": post["text"]})

    def on_message_handler(self, message: "firehose_models.MessageFrame") -> None:
        """Firehose callback: decode a frame and process its operations.

        Frames that are not commits, or whose ``blocks`` are not bytes, are
        ignored.
        """
        commit = parse_subscribe_repos_message(message)
        if not isinstance(commit, models.ComAtprotoSyncSubscribeRepos.Commit):
            return
        if not isinstance(commit.blocks, bytes):
            return

        # The CAR archive holds the record blocks referenced by the operations.
        car = CAR.from_bytes(commit.blocks)
        for op in commit.ops:
            self.process_operation(op, car, commit)

    def parse_facets(self, text: str) -> List[Dict]:
        """Build mention facets for *text*, resolving each handle to a DID.

        Handles that cannot be resolved (network failure, timeout, error
        status, or an unexpected response body) are skipped; they are then
        rendered as plain text in the post instead of a link.

        Returns:
            A list of ``app.bsky.richtext.facet`` dicts with UTF-8 byte
            offsets.
        """
        facets = []
        for m in self.parse_mentions(text):
            try:
                resp = requests.get(
                    self.pds_url + "/xrpc/com.atproto.identity.resolveHandle",
                    params={"handle": m["handle"]},
                    timeout=self.RESOLVE_TIMEOUT,
                )
            except requests.RequestException:
                continue
            if not resp.ok:
                continue
            try:
                did = resp.json()["did"]
            except (ValueError, KeyError, TypeError):
                continue
            facets.append({
                "index": {
                    "byteStart": m["start"],
                    "byteEnd": m["end"],
                },
                "features": [{"$type": "app.bsky.richtext.facet#mention", "did": did}],
            })

        return facets

    def parse_mentions(self, text: str) -> List[Dict]:
        """Find ``@handle`` mentions in *text*.

        Matching is done on the UTF-8 encoded text because facet offsets are
        byte offsets.

        Returns:
            One dict per mention with ``start``/``end`` (byte offsets of the
            mention including the ``@``) and ``handle`` (without the ``@``).
        """
        text_bytes = text.encode("UTF-8")
        return [
            {
                "start": m.start(1),
                "end": m.end(1),
                "handle": m.group(1)[1:].decode("UTF-8"),
            }
            for m in self._MENTION_RE.finditer(text_bytes)
        ]
|
|
|
def main() -> None:
    """Script entry point: construct the bot.

    The instance is not kept because ``Bot.__init__`` itself logs in and
    starts the firehose subscription.
    """
    Bot()


# Run the bot only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|