import json
import os

import requests
from requests.auth import HTTPBasicAuth
import tiktoken

import env_manager  # imported for its side effects; presumably loads the environment variables used below
from colorprinter.print_color import *

tokenizer = tiktoken.get_encoding("cl100k_base")


class LLM:
    def __init__(
        self,
        system_message='You are an assistant.',
        num_ctx=4096,
        temperature=0,
        chat=True,
        model='standard',
        max_length_answer=4000,
    ) -> None:
        if model == 'standard':
            self.model = os.getenv("LLM_MODEL")
        elif model == 'small':
            self.model = os.getenv('LLM_MODEL_SMALL')
        else:
            # Fail early instead of leaving self.model unset.
            raise ValueError(f"Unknown model alias: {model!r}")
        self.system_message = system_message
        self.options = {"temperature": temperature, "num_ctx": num_ctx}
        self.messages = [{"role": "system", "content": self.system_message}]
        self.chat = chat
        self.max_length_answer = max_length_answer

    def count_tokens(self):
        """Count the tokens in all message contents using the cl100k_base encoding."""
        num_tokens = 0
        for message in self.messages:
            content = message.get("content", "")
            if not isinstance(content, str):
                content = str(content)
            num_tokens += len(tokenizer.encode(content))
        return num_tokens

    def read_stream(self, response):
        """Yield the message content of each line in a line-delimited JSON stream."""
        # iter_lines() handles buffering and avoids splitting multi-byte
        # UTF-8 characters across fixed-size chunks.
        for line in response.iter_lines():
            if not line:
                continue
            try:
                json_data = json.loads(line)
                yield json_data["message"]["content"]
            except (json.JSONDecodeError, KeyError):
                continue

    def generate(self, query, stream=False):
        self.messages.append({"role": "user", "content": query})

        # Size the context window: the tokens already in the conversation
        # plus half the maximum answer length as headroom.
        num_tokens = self.count_tokens() + self.max_length_answer // 2
        if num_tokens < 4096:
            # Fits in the default context window; pop() instead of del so
            # repeated calls don't raise KeyError once the key is gone.
            self.options.pop("num_ctx", None)
        else:
            self.options["num_ctx"] = num_tokens

        headers = {"Content-Type": "application/json"}
        data = {
            "messages": self.messages,
            "stream": stream,
            "keep_alive": 3600 * 24 * 7,  # keep the model loaded for a week
            "model": self.model,
            "options": self.options,
        }
        response = requests.post(
            os.getenv("LLM_API_URL"),
            headers=headers,
            json=data,
            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
            stream=stream,
            timeout=3600,
        )

        if response.status_code == 404:
            return "Target endpoint not found"
        if response.status_code == 504:
            return f"Gateway Timeout: {response.content.decode('utf-8')}"

        if stream:
            # Note: streamed replies are not appended to the message history.
            return self.read_stream(response)

        try:
            result = response.json()["message"]["content"]
        except requests.exceptions.JSONDecodeError:
            print_red("Error: ", response.status_code, response.text)
            return "An error occurred."
        self.messages.append({"role": "assistant", "content": result})
        if not self.chat:
            # Stateless mode: keep only the system message between calls.
            self.messages = self.messages[:1]
        return result


if __name__ == "__main__":
    llm = LLM()
    for chunk in llm.generate("Why is the sky red?", stream=True):
        print(chunk, end='', flush=True)
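
# --- Example: non-streaming, stateless usage (a minimal sketch, not part of
# the original demo). Assumes the same environment variables as above
# (LLM_MODEL_SMALL, LLM_API_URL, LLM_API_USER, LLM_API_PWD_LASSE) are set,
# e.g. via env_manager; the prompt and temperature are illustrative.
#
#     llm = LLM(model='small', chat=False, temperature=0.2)
#     answer = llm.generate("Explain HTTP basic auth in one sentence.")
#     print(answer)
#
# With chat=False the history is trimmed back to the system message after
# each call, so every generate() starts a fresh conversation.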