import concurrent.futures
import json
import os
import queue
import re
import threading
from time import sleep

import requests
from dotenv import load_dotenv
from print_color import *

load_dotenv()


class LLM:
    def __init__(
        self,
        chat: bool = False,
        model: str = os.getenv("LLM_MODEL"),
        keep_alive: int = 3600 * 24,
        start: bool = False,
        system_prompt: str = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".',
        temperature: float = 0,
        stream: bool = False,
        small: bool = False,
    ):
        """
        Initializes an LLM client.

        Args:
            chat (bool, optional): Whether the instance keeps a running conversation history. Defaults to False.
            model (str, optional): The model to use. Defaults to the LLM_MODEL environment variable.
            keep_alive (int, optional): How long, in seconds, the server keeps the model loaded. Defaults to 3600 * 24.
            start (bool, optional): If True, a background thread running generate_concurrent is started on
                initialization, so queued requests are processed concurrently. Defaults to False.
            system_prompt (str, optional): System prompt sent with every request. Defaults to a Swedish prompt
                ("Always answer in Swedish. Only answer what is asked. If you cannot answer, write 'I don't know'.").
            temperature (float, optional): Sampling temperature passed to the model. Defaults to 0.
            stream (bool, optional): Whether responses are streamed. Defaults to False.
            small (bool, optional): If True, use the LLM_SMALL_MODEL, LLM_SMALL_URL and LLM_SMALL_PORT
                environment variables instead of the defaults. Defaults to False.
        """

        self.model = model
        self.server = os.getenv("LLM_URL")
        self.port = os.getenv("LLM_PORT")
        if small:
            self.model = os.getenv("LLM_SMALL_MODEL")
            self.server = os.getenv("LLM_SMALL_URL")
            self.port = os.getenv("LLM_SMALL_PORT")
        self.temperature = temperature
        self.system_message = {"role": "system", "content": system_prompt}
        self.messages = [self.system_message]
        self.chat = chat
        self.max_length = 24000
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()
        self.stream = stream

        if start:
            self.start()

    def create_data_request(self, message):
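        """
        Builds the JSON payload for a chat request.

        Strips leading and trailing whitespace from every line of the message
        and, depending on self.chat, either appends it to the running
        conversation history or pairs it with the system message only.
        """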
        # Remove leading and trailing whitespace from each line
        message = '\n'.join(line.strip() for line in message.split('\n'))

        # Prepare the request data
        options = {
            "temperature": self.temperature,
        }

        if self.chat:
            self.build_message(message)
            messages = self.messages
        else:
            messages = [self.system_message, {"role": "user", "content": message}]

        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": self.stream,
        }

        return data

    def generate_stream(self, message):
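        """
        Posts the message to the /api/chat endpoint and yields the message
        content of each JSON line in the streamed response as it arrives.
        """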
        # Make a POST request to the API endpoint
        data = self.create_data_request(message)

        response = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data, stream=True
        )

        # Iterate over the response
        for line in response.iter_lines():
            # Filter out keep-alive new lines
            if line:
                decoded_line = line.decode('utf-8')
                json_line = json.loads(decoded_line)  # Parse the line as JSON
                yield json_line['message']['content']

    def generate(self, message):
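        """
        Sends the message to the /api/chat endpoint and returns the full answer.

        Raises an exception if the response cannot be decoded or does not
        contain a message. In chat mode the answer is appended to the
        conversation history.
        """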
        data = self.create_data_request(message)
        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        )

        try:
            if 'message' in result.json():
                answer = result.json()["message"]["content"]
            else:
                print_red(result.content)
                raise Exception("Error occurred during API request")
        except requests.exceptions.JSONDecodeError:
            print_red(result.content)
            raise Exception("Error occurred during API request")

        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})

        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
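        """
        Worker loop that processes queued requests concurrently.

        Pulls (doc_id, message) pairs from request_queue, keeps up to
        buffer_size generate() calls in flight in a thread pool, and puts
        (doc_id, answer) pairs on result_queue. Stops when the stop event is
        set, or when all requests have been added and the queue is empty,
        then signals all_results_processed_event.
        """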
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If there are fewer than buffer_size tasks being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # If the queue is empty and all requests have been added, break the loop
                    if all_requests_added_event.is_set():
                        break
                    else:
                        continue

                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))

                # Avoid busy-waiting when the buffer is full and nothing has finished yet
                if not done_futures:
                    sleep(0.1)

            # Collect results from any futures that were still running when the loop exited
            for future in concurrent.futures.as_completed(future_to_message):
                doc_id = future_to_message[future]
                try:
                    summary = future.result()
                except Exception as exc:
                    print("Document %r generated an exception: %s" % (doc_id, exc))
                else:
                    result_queue.put((doc_id, summary))

        all_results_processed_event.set()

    def start(self):
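        """
        Starts a background thread that runs generate_concurrent on the
        instance's request and result queues.
        """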
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
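        """
        Puts an (id, prompt) pair on the request queue for concurrent processing.
        """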
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
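        """
        Signals that no more requests will be added, allowing the worker loop
        to finish once the request queue is empty.
        """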
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
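        """
        Blocks until a (doc_id, summary) pair is available on the result queue
        and returns it. Returns None once all results have been processed.
        """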
        # Process the results
        while True:
            try:
                # Take a result from the result queue
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary

            except queue.Empty:
                # If the result queue is empty and all results have been processed, break the loop
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
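        """
        Appends the user message to the conversation history and trims the
        oldest non-system messages while the combined character length exceeds
        self.max_length.
        """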
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})

        # Calculate the total length of the messages
        total_length = sum(len(msg["content"]) for msg in self.messages)

        # While the total length exceeds the limit, remove the oldest messages
        while total_length > self.max_length:
            # Remove the oldest message (not the system message)
            removed_message = self.messages.pop(1)
            total_length -= len(removed_message["content"])

    def unload_model(self):
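        """
        Asks the server to unload the model by sending a request with
        keep_alive set to 0.
        """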
|
data = { |
|
"model": self.model, |
|
"messages": self.messages, |
|
"keep_alive": 0, |
|
"stream": False, |
|
} |
|
|
|
# Make a POST request to the API endpoint |
|
requests.post(f"http://{self.server}:{self.port}/api/chat", json=data).json()[ |
|
"message" |
|
]["content"] |
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--unload", action="store_true", help="Unload the model")
    args = parser.parse_args()

    llm = LLM(keep_alive=60000, chat=True, small=False)

    if args.unload:
        llm.unload_model()
    else:
        while True:
            message = input(">>> ")
            print(llm.generate(message))