diff --git a/_llm/llm.py b/_llm/llm.py
index a33a220..4705599 100644
--- a/_llm/llm.py
+++ b/_llm/llm.py
@@ -13,7 +13,10 @@ from ollama import (
 )
 import backoff
 import env_manager
-from colorprinter.print_color import *
+try:
+    from colorprinter.colorprinter.print_color import *
+except ImportError:
+    from colorprinter.print_color import *
 
 
 env_manager.set_env()
@@ -51,9 +54,9 @@ class LLM:
         system_message: str = "You are an assistant.",
         temperature: float = 0.01,
         model: Optional[
-            Literal["small", "standard", "vision", "reasoning", "tools"]
+            Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]
         ] = "standard",
-        max_length_answer: int = 4096,
+        max_length_answer: int = 8000,
         messages: list[dict] = None,
         chat: bool = True,
         chosen_backend: str = None,
@@ -61,6 +64,7 @@ class LLM:
         think: bool = False,
         timeout: int = 240,
         local_available: bool = False,
+        on_vpn: bool = False,
     ) -> None:
         """
         Initialize the assistant with the given parameters.
@@ -68,12 +72,13 @@ class LLM:
         Args:
             system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
             temperature (float): The temperature setting for the model, affecting randomness. Defaults to 0.01.
-            model (Optional[Literal["small", "standard", "vision", "reasoning"]]): The model type to use. Defaults to "standard".
+            model (Optional[Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]]): The model type to use. Defaults to "standard".
             max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
             messages (list[dict], optional): A list of initial messages. Defaults to None.
             chat (bool): Whether the assistant is in chat mode. Defaults to True.
             chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
             think (bool): Whether to use thinking mode for reasoning models. Defaults to False.
+            on_vpn (bool): Whether the connection is over VPN and a direct path to the server can be used. Defaults to False.
 
         Returns:
             None
@@ -91,7 +96,6 @@ class LLM:
         self.think = think
         self.tools = tools or []
         self.local_available = local_available
-
         self.chosen_backend = chosen_backend
 
         headers = {
@@ -101,7 +105,12 @@ class LLM:
         if self.chosen_backend:
             headers["X-Chosen-Backend"] = self.chosen_backend
 
-        self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
+        # If connected over VPN, reach the Ollama host directly
+        self.on_vpn = on_vpn
+        if on_vpn:
+            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}"
+        else:
+            self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
 
         self.client: Client = Client(host=self.host_url, headers=headers, timeout=timeout)
         self.async_client: AsyncClient = AsyncClient()
@@ -118,9 +127,9 @@ class LLM:
             "standard_64k": "LLM_MODEL_LARGE",
             "reasoning": "LLM_MODEL_REASONING",
             "tools": "LLM_MODEL_TOOLS",
+            "embeddings": "LLM_MODEL_EMBEDDINGS",
         }
         model = os.getenv(models.get(model_alias, "LLM_MODEL"))
-        print_purple(f"Using model: {model}")
         return model
 
     def count_tokens(self):
@@ -135,9 +144,16 @@ class LLM:
         return int(num_tokens)
 
     def _prepare_messages_and_model(
-        self, query, user_input, context, messages, images, model
+        self, query, user_input, context, messages, images, model, tools=None
     ):
         """Prepare messages and select the appropriate model, handling images if present."""
+
+        if model == "embeddings":
+            self.messages = [{"role": "user", "content": query}]
+            model = self.get_model("embeddings")
+            print_red(f"Using embeddings model: {model}")
+            return model
+
         if messages:
             messages = [
                 {"role": i["role"], "content": re.sub(r"\s*\n\s*", "\n", i["content"])}
@@ -154,6 +170,7 @@ class LLM:
         if images:
             message = self.prepare_images(images, message)
             model = self.get_model("vision")
+            print_blue(f"Using vision model: {model}")
         else:
             if model in [
                 "small",
@@ -182,6 +199,8 @@ class LLM:
             headers["X-Model-Type"] = "small"
         if model == self.get_model("tools"):
             headers["X-Model-Type"] = "tools"
+        if model == self.get_model("embeddings"):
+            headers["X-Model-Type"] = "embeddings"
 
         # No longer need to modify message content for thinking - handled by native API
         return headers
@@ -209,7 +228,26 @@ class LLM:
         """Call the remote Ollama API synchronously."""
         self.call_model = model
         self.client: Client = Client(host=self.host_url, headers=headers, timeout=300)
-        print_yellow(f"🤖 Generating using {model} (remote)...")
+        if self.on_vpn:
+            print_yellow(f"🤖 Generating using {model} (remote, on VPN)...")
+        else:
+            print_yellow(f"🤖 Generating using {model} (remote)...")
+
+        # If this is an embeddings model, call the embed endpoint instead of chat.
+        if model == self.get_model("embeddings"):
+            # Find the last user message content to embed
+            input_text = ""
+            for m in reversed(self.messages):
+                if m.get("role") == "user" and m.get("content"):
+                    input_text = m["content"]
+                    break
+            if not input_text and self.messages:
+                input_text = self.messages[-1].get("content", "")
+
+            # Use the embed API (synchronous)
+            response = self.client.embed(model=model, input=input_text, keep_alive=3600 * 24 * 7)
+            return response
+
         response = self.client.chat(
             model=model,
             messages=self.messages,
@@ -237,6 +275,20 @@ class LLM:
     ):
         """Call the remote Ollama API asynchronously."""
         print_yellow(f"🤖 Generating using {model} (remote, async)...")
+
+        # If embedding model, use the async embed endpoint
+        if model == self.get_model("embeddings"):
+            input_text = ""
+            for m in reversed(self.messages):
+                if m.get("role") == "user" and m.get("content"):
+                    input_text = m["content"]
+                    break
+            if not input_text and self.messages:
+                input_text = self.messages[-1].get("content", "")
+
+            response = await self.async_client.embed(model=model, input=input_text)
+            return response
+
         response = await self.async_client.chat(
             model=model,
             messages=self.messages,
@@ -332,6 +384,7 @@ class LLM:
     #         f"Retrying due to error: {details['exception']}"
     #     )
     # )
+
     async def _call_local_ollama_async(self, model, stream, temperature, think=False):
         """Call the local Ollama instance asynchronously (using a thread pool)."""
         import ollama
@@ -424,7 +477,7 @@ class LLM:
         tools: list = None,
         images: list = None,
         model: Optional[
-            Literal["small", "standard", "vision", "reasoning", "tools"]
+            Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]
         ] = None,
         temperature: float = None,
         messages: list[dict] = None,
@@ -443,7 +496,7 @@ class LLM:
             stream (bool, optional): Whether to stream the response. Defaults to False.
             tools (list, optional): List of tools to make available for the model.
            images (list, optional): List of images to include in the request.
-            model (Literal["small", "standard", "vision", "reasoning", "tools"], optional):
+            model (Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"], optional):
                 The model type to use. Defaults to "standard".
             temperature (float, optional): Temperature parameter for generation randomness.
                 Uses instance default if not provided.
@@ -472,6 +525,7 @@ class LLM:
         model = self._prepare_messages_and_model(
             query, user_input, context, messages, images, model
         )
+        print_red(model)
         temperature = temperature if temperature else self.options["temperature"]
         if think is None:
             think = self.think
@@ -482,6 +536,11 @@ class LLM:
         response = self._call_remote_api(
             model, tools, stream, options, format, headers, think=think
         )
+
+        # If using embeddings model, the response is an embed result (not a ChatResponse).
+        if model == self.get_model("embeddings"):
+            return response
+
         if stream:
             return self.read_stream(response)
         else:
@@ -554,12 +613,16 @@ class LLM:
         # First try with remote API
        if not force_local:
             try:
-                headers = self._build_headers(model, tools, think)
+                headers = self._build_headers(model)
                 options = self._get_options(temperature)
                 response = await self._call_remote_api_async(
                     model, tools, stream, options, format, headers, think=think
                 )
 
+                # If using embeddings model, return the embed response directly
+                if model == self.get_model("embeddings"):
+                    return response
+
                 if stream:
                     return self.read_stream(response)
                 else:
@@ -672,11 +735,13 @@ class LLM:
         import base64
 
         base64_images = []
+        # base64 heuristic: only valid characters, proper padding, and a length divisible by 4
        base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
 
         for image in images:
             if isinstance(image, str):
-                if base64_pattern.match(image):
+                # If it looks like base64, just pass it through
+                if base64_pattern.match(image) and len(image) % 4 == 0:
                     base64_images.append(image)
                 else:
                     with open(image, "rb") as image_file:
@@ -688,8 +753,8 @@ class LLM:
             else:
                 print_red("Invalid image type")
 
-            message["images"] = base64_images
-            return message
+        message["images"] = base64_images
+        return message
 
 
 if __name__ == "__main__":
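Reviewer note, not part of the patch: a minimal usage sketch of the two behaviours this diff adds, the dedicated embeddings path and the on_vpn host selection. It assumes the `_llm` package is importable, that the public method whose Literal[...] signature is widened above is named `generate` (the method name is not visible in the hunks), and that `LLM_URL`, `LLM_PORT`, and `LLM_MODEL_EMBEDDINGS` are set as the new code expects.

# Usage sketch under the assumptions noted above; adjust names to the real API.
from _llm.llm import LLM

# on_vpn=True builds the host from LLM_URL/LLM_PORT instead of LLM_API_URL.
llm = LLM(on_vpn=True)

# Passing model="embeddings" short-circuits before the chat call and returns
# the raw result of ollama's Client.embed() rather than a ChatResponse.
response = llm.generate(query="The quick brown fox", model="embeddings")

# Recent ollama clients return an EmbedResponse model, older ones a dict,
# so read the vectors defensively.
vectors = getattr(response, "embeddings", None) or response["embeddings"]
print(f"embedding dimensions: {len(vectors[0])}")

Whatever the entry point is called, callers now have to branch on the embeddings case themselves, since this is the only model alias whose return value is not a chat response or stream.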
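One more observation on the stricter base64 check in prepare_images: the character-class regex alone matches many ordinary strings (plain words and slash-separated paths use only base64 characters), and the added `len(image) % 4 == 0` test only trims some of those false positives. A small self-contained sketch of the same heuristic, useful for deciding whether an additional dot-or-separator check is wanted:

import re

# Same heuristic as the patched prepare_images(): base64 alphabet, up to two
# '=' of padding, and a total length that is a multiple of 4.
base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")


def looks_like_base64(value: str) -> bool:
    return bool(base64_pattern.match(value)) and len(value) % 4 == 0


# "QUJDRA==" is base64 for "ABCD"; "cat" fails the length check; "photo.png"
# fails the regex; "images/photo" still slips through (12 chars, all valid).
for candidate in ("QUJDRA==", "cat", "photo.png", "images/photo"):
    print(candidate, looks_like_base64(candidate))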