from time import sleep

import concurrent.futures
import queue
import threading

import requests

from _arango import arango


class LLM:
    def __init__(
        self,
        chat=False,
        model="llama3:8b-instruct-q5_K_M",
        keep_alive=3600 * 24,
        start=False,
    ):
        """
        Initializes an LLM client instance.

        Args:
            chat (bool, optional): Whether the instance keeps chat history between calls. Defaults to False.
            model (str, optional): The model to be used. Defaults to "llama3:8b-instruct-q5_K_M".
            keep_alive (int, optional): How long, in seconds, the server should keep the model loaded. Defaults to 3600 * 24.
            start (bool, optional): If True, a separate thread running the generate_concurrent method
                is started on initialization, so queued requests are processed concurrently.
                Defaults to False.
        """

        self.server = "192.168.1.12"
        self.port = 3300  # 11440 All 4 GPU # 4500 "SW"
        self.model = model
        self.temperature = 0
        # System prompt (Swedish), roughly: "Always answer in Swedish. Only answer
        # what is asked. If you cannot answer, write 'I don't know'."
        self.system_message = 'Svara alltid på svenska. Svara bara på det som efterfrågas. Om du inte kan svara, skriv "Jag vet inte".'
        self.messages = [{"role": "system", "content": self.system_message}]
        self.chat = chat
        self.max_tokens = 24000
        self.keep_alive = keep_alive
        self.request_queue = queue.Queue()
        self.result_queue = queue.Queue()
        self.all_requests_added_event = threading.Event()
        self.all_results_processed_event = threading.Event()
        self.stop_event = threading.Event()

        if start:
            self.start()

    def generate(self, message):
        # Prepare the request options
        options = {
            "temperature": self.temperature,
        }

        if self.chat:
            # Keep a running conversation, trimming old messages if needed
            self.build_message(message)
            messages = self.messages
        else:
            # Build a one-off message list so independent (and concurrent) requests
            # do not share or accumulate history
            messages = [
                {"role": "system", "content": self.system_message},
                {"role": "user", "content": message},
            ]

        data = {
            "model": self.model,
            "messages": messages,
            "options": options,
            "keep_alive": self.keep_alive,
            "stream": False,
        }

        # Make a POST request to the API endpoint
        result = requests.post(
            f"http://{self.server}:{self.port}/api/chat", json=data
        ).json()

        if "message" in result:
            answer = result["message"]["content"]
        else:
            from pprint import pprint

            pprint(result)
            raise Exception("Error occurred during API request")

        if self.chat:
            self.messages.append({"role": "assistant", "content": answer})

        return answer

    def generate_concurrent(
        self,
        request_queue,
        result_queue,
        all_requests_added_event,
        all_results_processed_event,
    ):
        self.chat = False
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_message = {}
            buffer_size = 6  # The number of tasks to keep in the executor
            while True:
                if self.stop_event.is_set():
                    break
                try:
                    # If fewer than buffer_size tasks are being processed, add new tasks
                    while len(future_to_message) < buffer_size:
                        # Take a request from the queue
                        doc_id, message = request_queue.get(timeout=1)
                        # Submit the generate method to the executor for execution
                        future = executor.submit(self.generate, message)
                        future_to_message[future] = doc_id
                except queue.Empty:
                    # The queue is empty: if all requests have been added and no
                    # futures are still pending, everything is done
                    if all_requests_added_event.is_set() and not future_to_message:
                        break

                # Process completed futures
                done_futures = [f for f in future_to_message if f.done()]
                for future in done_futures:
                    doc_id = future_to_message.pop(future)
                    try:
                        summary = future.result()
                    except Exception as exc:
                        print("Document %r generated an exception: %s" % (doc_id, exc))
                    else:
                        # Put the document ID and the summary into the result queue
                        result_queue.put((doc_id, summary))

                # Avoid a tight loop while the buffer is full and nothing has finished yet
                if not done_futures:
                    sleep(0.1)

        all_results_processed_event.set()

    def start(self):
        # Start a separate thread that runs the generate_concurrent method
        threading.Thread(
            target=self.generate_concurrent,
            args=(
                self.request_queue,
                self.result_queue,
                self.all_requests_added_event,
                self.all_results_processed_event,
            ),
        ).start()

    def stop(self):
        """
        Stops the instance from processing further requests.
        """
        self.stop_event.set()

    def add_request(self, id, prompt):
        # Add a request to the request queue
        self.request_queue.put((id, prompt))

    def finish_adding_requests(self):
        # Signal that all requests have been added
        print("\033[92mAll requests added\033[0m")
        self.all_requests_added_event.set()

    def get_results(self):
        # Return the next (doc_id, summary) pair, or None once all results
        # have been processed
        while True:
            try:
                # Take a result from the result queue
                doc_id, summary = self.result_queue.get(timeout=1)
                return doc_id, summary

            except queue.Empty:
                # If the result queue is empty and all results have been processed,
                # stop waiting (and implicitly return None)
                if self.all_results_processed_event.is_set():
                    break
                else:
                    sleep(0.2)
                    continue

    def build_message(self, message):
        # Add the new message to the list
        self.messages.append({"role": "user", "content": message})

        # Approximate the total token count by the character length of the messages
        total_tokens = sum(len(msg["content"]) for msg in self.messages)

        # While the total length exceeds the limit, remove the oldest messages
        while total_tokens > self.max_tokens:
            # Remove the oldest message (but never the system message at index 0)
            removed_message = self.messages.pop(1)
            total_tokens -= len(removed_message["content"])

    def unload_model(self):
        data = {
            "model": self.model,
            "messages": self.messages,
            "keep_alive": 0,
            "stream": False,
        }

        # Posting with keep_alive set to 0 asks the server to unload the model
        requests.post(f"http://{self.server}:{self.port}/api/chat", json=data)
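

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative sketch of the queue-based workflow using the class's
# own helpers (start, add_request, finish_adding_requests, get_results, stop).
# It assumes an Ollama-compatible server is reachable at the host/port configured
# in __init__; the function name and the numeric ids below are hypothetical and
# not part of the original module.
def _example_batch(prompts):
    llm = LLM(chat=False, start=True)  # start=True spawns the worker thread
    for i, prompt in enumerate(prompts):
        llm.add_request(str(i), prompt)  # queue each prompt under an id
    llm.finish_adding_requests()  # signal that no more requests will be added
    results = {}
    while True:
        item = llm.get_results()  # returns None once everything is processed
        if item is None:
            break
        doc_id, summary = item
        results[doc_id] = summary
    llm.stop()  # optional here; the worker exits on its own after the last result
    return results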


if __name__ == "__main__":
    # Initialize the LLM object
    llm = LLM(chat=False, model="llama3:8b-instruct-q5_K_M")

    # Create a queue for requests and a queue for results
    request_queue = queue.Queue()
    result_queue = queue.Queue()

    # Create events to signal when all requests have been added
    # and when all results have been processed
    all_requests_added_event = threading.Event()
    all_results_processed_event = threading.Event()

    # Start a separate thread that runs the generate_concurrent method
    threading.Thread(
        target=llm.generate_concurrent,
        args=(
            request_queue,
            result_queue,
            all_requests_added_event,
            all_results_processed_event,
        ),
    ).start()

    # Add requests to the request queue (arango is imported at module level)
    interrogations = arango.db.collection("interrogations").all()
    for doc in interrogations:
        text = doc["text"]
        # Swedish summarization prompt, roughly: "Look at the text below. Summarize the
        # interrogation focusing on what was said, not where it was held or other
        # formalities. Answer as briefly as possible but be careful with details such as
        # described events, names, dates and places. Short summary:"
        prompt = f'Kolla på texten nedan: \n\n """{text}""" \n\n Sammanfatta förhöret med fokus på vad som sades, inte var det hölls eller annat formalia. Svara så kort som möjligt men var noga med detaljer som händelser som beskrivs, namn, datum och platser.\nKort sammanfattning:'
        request_queue.put((doc["_key"], prompt))

    # Signal that all requests have been added
    all_requests_added_event.set()

    # Process the results
    while True:
        try:
            # Take a result from the result queue
            doc_id, summary = result_queue.get(timeout=1)
            print("\033[92m" + doc_id + "\033[0m", summary)
            # Update the document with the summary
            arango.db.collection("interrogations").update_match(
                {"_key": doc_id}, {"summary": summary}
            )
        except queue.Empty:
            # If the result queue is empty and all results have been processed, break the loop
            if all_results_processed_event.is_set():
                break
            else:
                continue

    # import argparse
    # parser = argparse.ArgumentParser()
    # parser.add_argument("--unload", action="store_true", help="Unload the model")
    # args = parser.parse_args()

    # # llm = LLM(model='llama3:70b-text-q4_K_M', keep_alive=6000, chat=True)
    # llm = LLM(keep_alive=60, chat=True)

    # if args.unload:
    #     llm.unload_model()
    # else:
    #     while True:
    #         message = input(">>> ")
    #         print(llm.generate(message))