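"""Small client for a chat-completion HTTP API.

Keeps a running message history, estimates prompt size with tiktoken, and
supports both streaming and blocking requests. The request/response shape
(keep_alive, num_ctx, {"message": {"content": ...}}) suggests an Ollama-style
chat endpoint, but the exact server behind LLM_API_URL is assumed rather than
documented here.
"""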

import os
import requests
from requests.auth import HTTPBasicAuth
import tiktoken
import json
import env_manager  # not referenced directly; presumably sets up environment variables on import
from colorprinter.print_color import *

# Shared tokenizer, used only to estimate the prompt size before a request.
tokenizer = tiktoken.get_encoding("cl100k_base")


class LLM:
    def __init__(
        self,
        system_message='You are an assistant.',
        num_ctx=4096,
        temperature=0,
        chat=True,
        model='standard',
        max_length_answer=4000,
    ) -> None:
        if model == 'standard':
            self.model = os.getenv("LLM_MODEL")
        elif model == 'small':
            self.model = os.getenv('LLM_MODEL_SMALL')
        else:
            # Unrecognised model names fall back to the standard model.
            self.model = os.getenv("LLM_MODEL")
        self.system_message = system_message
        self.options = {"temperature": temperature, "num_ctx": num_ctx}
        self.messages = [{"role": "system", "content": self.system_message}]
        self.chat = chat
        self.max_length_answer = max_length_answer

    def count_tokens(self):
        """Return the number of tokens in all message contents in the history."""
        num_tokens = 0
        for message in self.messages:
            for key, value in message.items():
                if key == "content":
                    if not isinstance(value, str):
                        value = str(value)
                    num_tokens += len(tokenizer.encode(value))
        return num_tokens

    def read_stream(self, response):
        """Yield message content from a newline-delimited JSON response stream."""
        buffer = ""
        for chunk in response.iter_content(chunk_size=64):
            if chunk:
                try:
                    buffer += chunk.decode('utf-8')
                except UnicodeDecodeError:
                    continue
                while "\n" in buffer:
                    line, buffer = buffer.split("\n", 1)
                    if line:
                        try:
                            json_data = json.loads(line)
                            yield json_data["message"]["content"]
                        except json.JSONDecodeError:
                            continue

    def generate(self, query, stream=False):
        self.messages.append({"role": "user", "content": query})
        # Size the context window: tokens already in the messages plus half of
        # the maximum answer length.
        num_tokens = self.count_tokens() + self.max_length_answer // 2
        if num_tokens < 4096:
            # The model's default context window is large enough.
            self.options.pop("num_ctx", None)
        else:
            self.options["num_ctx"] = num_tokens
        headers = {"Content-Type": "application/json"}
        data = {
            "messages": self.messages,
            "stream": stream,
            "keep_alive": 3600 * 24 * 7,
            "model": self.model,
            "options": self.options,
        }
        # Endpoint and basic-auth credentials are read from environment variables.
        response = requests.post(
            os.getenv("LLM_API_URL"),
            headers=headers,
            json=data,
            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
            stream=stream,
            timeout=3600,
        )
        if response.status_code == 404:
            return "Target endpoint not found"
        if response.status_code == 504:
            return f"Gateway Timeout: {response.content.decode('utf-8')}"
        if stream:
            return self.read_stream(response)
        try:
            result = response.json()["message"]["content"]
        except requests.exceptions.JSONDecodeError:
            print_red("Error: ", response.status_code, response.text)
            return "An error occurred."
        self.messages.append({"role": "assistant", "content": result})
        if not self.chat:
            # Keep only the system message when not running as a chat.
            self.messages = self.messages[:1]
        return result


if __name__ == "__main__":
    llm = LLM()
    for chunk in llm.generate("Why is the sky red?", stream=True):
        print(chunk, end='', flush=True)
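
# Non-streaming usage (sketch; assumes LLM_API_URL, LLM_API_USER and
# LLM_API_PWD_LASSE point at a reachable endpoint). With chat=False the
# history is reset to the system message after every call:
#     llm = LLM(chat=False)
#     print(llm.generate("Why is the sky red?"))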