diff --git a/_llm.py b/_llm.py
index 60db3f4..42108f4 100644
--- a/_llm.py
+++ b/_llm.py
@@ -1,37 +1,103 @@
-from ollama import Client
 import os
+import requests
+from requests.auth import HTTPBasicAuth
+import tiktoken
+import json
 import env_manager
+from colorprinter.print_color import *
 
-env_manager.set_env()
-
+tokenizer = tiktoken.get_encoding("cl100k_base")
 
 class LLM:
     def __init__(
-        self, system_message=None, num_ctx=20000, temperature=0, chat=True
+        self, system_message='You are an assistant.', num_ctx=4096, temperature=0, chat=True, model='standard', max_length_answer=4000
     ) -> None:
-        self.llm_model = "mistral-nemo:12b-instruct-2407-q5_K_M" #os.getenv("LLM_MODEL")
+        if model == 'standard':
+            self.model = os.getenv("LLM_MODEL")
+        elif model == 'small':
+            self.model = os.getenv('LLM_MODEL_SMALL')
         self.system_message = system_message
         self.options = {"temperature": temperature, "num_ctx": num_ctx}
         self.messages = [{"role": "system", "content": self.system_message}]
         self.chat = chat
-        self.ollama = Client(
-            host=f'{os.getenv("LLM_URL")}:{os.getenv("LLM_PORT")}',
-        )
+        self.max_length_answer = max_length_answer
+
+    def count_tokens(self):
+        # Approximate the prompt size by encoding the content of every message.
+        num_tokens = 0
+        for i in self.messages:
+            for k, v in i.items():
+                if k == "content":
+                    if not isinstance(v, str):
+                        v = str(v)
+                    tokens = tokenizer.encode(v)
+                    num_tokens += len(tokens)
+        return int(num_tokens)
+
+    def read_stream(self, response):
+        # Parse the line-delimited JSON stream and yield the content of each chunk.
+        buffer = ""
+        for chunk in response.iter_content(chunk_size=64):
+            if chunk:
+                buffer += chunk.decode('utf-8')
+                while "\n" in buffer:
+                    line, buffer = buffer.split("\n", 1)
+                    if line:
+                        try:
+                            json_data = json.loads(line)
+                            yield json_data["message"]["content"]
+                        except json.JSONDecodeError:
+                            continue
+
+    def generate(self, query, stream=False):
+        self.messages.append({"role": "user", "content": query})
+
+        # Size the context window: tokens already in the messages plus half of the
+        # maximum answer length. Below 4096 we drop num_ctx and let the server default apply.
+        num_tokens = int(self.count_tokens() + self.max_length_answer / 2)
+        if num_tokens < 4096:
+            self.options.pop("num_ctx", None)
+        else:
+            self.options["num_ctx"] = num_tokens
+
+        headers = {"Content-Type": "application/json"}
+        data = {
+            "messages": self.messages,
+            "stream": stream,
+            "keep_alive": 3600 * 24 * 7,
+            "model": self.model,
+            "options": self.options,
+        }
+
+        # POST the chat request to the API endpoint, authenticating with basic auth.
-    def generate(self, prompt: str) -> str:
-        self.messages.append({"role": "user", "content": prompt})
-        result = self.ollama.chat(
-            model=self.llm_model, messages=self.messages, options=self.options
+        response = requests.post(
+            os.getenv("LLM_API_URL"),
+            headers=headers,
+            json=data,
+            auth=HTTPBasicAuth(os.getenv('LLM_API_USER'), os.getenv('LLM_API_PWD_LASSE')),
+            stream=stream,
         )
-        answer = result["message"]["content"]
-        self.messages.append({"role": "assistant", "content": answer})
-        if not self.chat:
-            self.messages = [{"role": "system", "content": self.system_message}]
+        if response.status_code == 404:
+            return "Target endpoint not found"
+
+        if response.status_code == 504:
+            return f"Gateway Timeout: {response.content.decode('utf-8')}"
 
-        return answer
+        if stream:
+            return self.read_stream(response)
+
+        else:
+            try:
+                result = response.json()["message"]["content"]
+            except requests.exceptions.JSONDecodeError:
+                print_red("Error: ", response.status_code, response.text)
+                return "An error occurred."
+
+        self.messages.append({"role": "assistant", "content": result})
+        if not self.chat:
+            self.messages = self.messages[:1]
+        return result
 
 
 if __name__ == "__main__":
     llm = LLM()
-    print(llm.generate("Why is the sky red?"))
\ No newline at end of file
+    for chunk in llm.generate("Why is the sky red?", stream=True):
+        print(chunk, end='', flush=True)
\ No newline at end of file