Enhance LLM class: add support for embeddings model, update model options, and improve VPN handling

legacy
lasseedfast 5 months ago
parent 12b7c5ba4d
commit 4567ed2752
1 changed file: _llm/llm.py (95 lines changed)
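
For orientation, the new constructor options touched by this commit (the "embeddings" model alias and the on_vpn flag) could be exercised roughly as below. This is a minimal sketch, not code from the commit: the import path is assumed from the file location _llm/llm.py, and the environment variables are expected to be provided by env_manager as in the module itself.

# Sketch only; assumes the package is importable as _llm.llm and that
# LLM_URL, LLM_PORT, LLM_API_URL and the LLM_MODEL_* variables are set.
from _llm.llm import LLM

# Route requests over the local/VPN path added in this commit.
llm_vpn = LLM(model="standard", on_vpn=True)

# Select the new embeddings model alias.
llm_embed = LLM(model="embeddings")
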

@@ -13,7 +13,10 @@ from ollama import (
 )
 
 import backoff
 import env_manager
-from colorprinter.print_color import *
+try:
+    from colorprinter.colorprinter.print_color import *
+except ImportError:
+    from colorprinter.print_color import *
 
 env_manager.set_env()
@@ -51,9 +54,9 @@ class LLM:
         system_message: str = "You are an assistant.",
         temperature: float = 0.01,
         model: Optional[
-            Literal["small", "standard", "vision", "reasoning", "tools"]
+            Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]
         ] = "standard",
-        max_length_answer: int = 4096,
+        max_length_answer: int = 8000,
         messages: list[dict] = None,
         chat: bool = True,
         chosen_backend: str = None,
@@ -61,6 +64,7 @@ class LLM:
         think: bool = False,
         timeout: int = 240,
         local_available: bool = False,
+        on_vpn: bool = False,
     ) -> None:
         """
         Initialize the assistant with the given parameters.
@@ -68,12 +72,13 @@ class LLM:
         Args:
             system_message (str): The initial system message for the assistant. Defaults to "You are an assistant.".
             temperature (float): The temperature setting for the model, affecting randomness. Defaults to 0.01.
-            model (Optional[Literal["small", "standard", "vision", "reasoning"]]): The model type to use. Defaults to "standard".
+            model (Optional[Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]]): The model type to use. Defaults to "standard".
             max_length_answer (int): The maximum length of the generated answer. Defaults to 4096.
             messages (list[dict], optional): A list of initial messages. Defaults to None.
             chat (bool): Whether the assistant is in chat mode. Defaults to True.
             chosen_backend (str, optional): The backend server to use. If not provided, the least connected server is chosen.
             think (bool): Whether to use thinking mode for reasoning models. Defaults to False.
+            on_vpn (bool): Whether the connection is over VPN and a local path to server can be used. Defaults to False.
 
         Returns:
             None
@@ -91,7 +96,6 @@ class LLM:
         self.think = think
         self.tools = tools or []
         self.local_available = local_available
-
         self.chosen_backend = chosen_backend
 
         headers = {
@@ -101,7 +105,12 @@ class LLM:
         if self.chosen_backend:
             headers["X-Chosen-Backend"] = self.chosen_backend
 
-        self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
+        # If connected over VPN
+        self.on_vpn = True
+        if on_vpn:
+            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}"
+        else:
+            self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
 
         self.client: Client = Client(host=self.host_url, headers=headers, timeout=timeout)
         self.async_client: AsyncClient = AsyncClient()
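
The host selection above reduces to choosing between two sets of environment variables. Below is a standalone sketch of that resolution, assuming LLM_URL/LLM_PORT point at the VPN-reachable Ollama host and LLM_API_URL at the public proxy endpoint; it is not part of the commit.

import os

def resolve_host_url(on_vpn: bool) -> str:
    # Mirrors the branch added in __init__ above.
    if on_vpn:
        # e.g. LLM_URL=http://10.0.0.5 and LLM_PORT=11434 (illustrative values)
        return f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}"
    # Note: str.rstrip("/api/chat/") trims trailing characters from the set
    # {/, a, p, i, c, h, t} rather than a literal suffix; this mirrors the committed code.
    return os.getenv("LLM_API_URL").rstrip("/api/chat/")
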
@@ -118,9 +127,9 @@ class LLM:
             "standard_64k": "LLM_MODEL_LARGE",
             "reasoning": "LLM_MODEL_REASONING",
             "tools": "LLM_MODEL_TOOLS",
+            "embeddings": "LLM_MODEL_EMBEDDINGS",
         }
         model = os.getenv(models.get(model_alias, "LLM_MODEL"))
-        print_purple(f"Using model: {model}")
         return model
 
     def count_tokens(self):
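
The alias table in get_model resolves to a concrete model name through environment variables. A small self-contained illustration of that lookup; the env values here are invented for the example.

import os

# Invented values; real deployments set these via env_manager.
os.environ.setdefault("LLM_MODEL_EMBEDDINGS", "nomic-embed-text")
os.environ.setdefault("LLM_MODEL", "llama3.1")

models = {"embeddings": "LLM_MODEL_EMBEDDINGS"}  # other aliases omitted

def resolve(alias: str) -> str:
    # Unknown aliases fall back to LLM_MODEL, as in get_model().
    return os.getenv(models.get(alias, "LLM_MODEL"))

print(resolve("embeddings"))  # nomic-embed-text
print(resolve("standard"))    # llama3.1 (fallback)
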
@@ -135,9 +144,16 @@ class LLM:
         return int(num_tokens)
 
     def _prepare_messages_and_model(
-        self, query, user_input, context, messages, images, model
+        self, query, user_input, context, messages, images, model, tools=None
     ):
         """Prepare messages and select the appropriate model, handling images if present."""
+
+        if model == "embeddings":
+            self.messages = [{"role": "user", "content": query}]
+            model = self.get_model("embeddings")
+            print_red(f"Using embeddings model: {model}")
+            return model
+
         if messages:
             messages = [
                 {"role": i["role"], "content": re.sub(r"\s*\n\s*", "\n", i["content"])}
@@ -154,6 +170,7 @@ class LLM:
         if images:
             message = self.prepare_images(images, message)
             model = self.get_model("vision")
+            print_blue(f"Using vision model: {model}")
         else:
             if model in [
                 "small",
@@ -182,6 +199,8 @@ class LLM:
             headers["X-Model-Type"] = "small"
         if model == self.get_model("tools"):
             headers["X-Model-Type"] = "tools"
+        if model == self.get_model("embeddings"):
+            headers["X-Model-Type"] = "embeddings"
 
         # No longer need to modify message content for thinking - handled by native API
         return headers
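
The X-Model-Type header lets the proxy route the request to the right backend pool. For an embeddings request, the headers built here end up looking roughly like this; the values are illustrative, and the full header set built by _build_headers is not shown in the diff.

headers = {
    "X-Chosen-Backend": "gpu-node-1",  # only present when chosen_backend was given
    "X-Model-Type": "embeddings",      # set by the new branch above
}
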
@@ -209,7 +228,26 @@ class LLM:
         """Call the remote Ollama API synchronously."""
         self.call_model = model
         self.client: Client = Client(host=self.host_url, headers=headers, timeout=300)
-        print_yellow(f"🤖 Generating using {model} (remote)...")
+        if self.on_vpn:
+            print_yellow(f"🤖 Generating using {model} (remote, on VPN)...")
+        else:
+            print_yellow(f"🤖 Generating using {model} (remote)...")
+
+        # If this is an embeddings model, call the embed endpoint instead of chat.
+        if model == self.get_model("embeddings"):
+            # Find the last user message content to embed
+            input_text = ""
+            for m in reversed(self.messages):
+                if m.get("role") == "user" and m.get("content"):
+                    input_text = m["content"]
+                    break
+
+            if not input_text and self.messages:
+                input_text = self.messages[-1].get("content", "")
+            # Use the embed API (synchronous)
+            response = self.client.embed(model=model, input=input_text, keep_alive=3600 * 24 * 7)
+            return response
+
         response = self.client.chat(
             model=model,
             messages=self.messages,
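
For reference, Client.embed in the ollama Python package returns a response whose "embeddings" field holds one vector per input. A minimal standalone sketch against a local Ollama server; the host and model name are assumptions, not values from this repo.

from ollama import Client

client = Client(host="http://localhost:11434")  # assumed local server
resp = client.embed(model="nomic-embed-text", input="How tall is the Eiffel Tower?")
vector = resp["embeddings"][0]  # list[float]; dimensionality depends on the model
print(len(vector))
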
@@ -237,6 +275,20 @@ class LLM:
     ):
         """Call the remote Ollama API asynchronously."""
         print_yellow(f"🤖 Generating using {model} (remote, async)...")
+
+        # If embedding model, use async embed endpoint
+        if model == self.get_model("embeddings"):
+            input_text = ""
+            for m in reversed(self.messages):
+                if m.get("role") == "user" and m.get("content"):
+                    input_text = m["content"]
+                    break
+
+            if not input_text and self.messages:
+                input_text = self.messages[-1].get("content", "")
+            response = await self.async_client.embed(model=model, input=input_text)
+            return response
+
         response = await self.async_client.chat(
             model=model,
             messages=self.messages,
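
The async path is the same call on AsyncClient; a brief sketch, with the host and model name again assumed.

import asyncio
from ollama import AsyncClient

async def main() -> None:
    client = AsyncClient(host="http://localhost:11434")  # assumed local server
    resp = await client.embed(model="nomic-embed-text", input="hello world")
    print(len(resp["embeddings"][0]))

asyncio.run(main())
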
@@ -332,6 +384,7 @@ class LLM:
     #             f"Retrying due to error: {details['exception']}"
     #         )
     #     )
+
     async def _call_local_ollama_async(self, model, stream, temperature, think=False):
         """Call the local Ollama instance asynchronously (using a thread pool)."""
         import ollama
@@ -424,7 +477,7 @@ class LLM:
         tools: list = None,
         images: list = None,
         model: Optional[
-            Literal["small", "standard", "vision", "reasoning", "tools"]
+            Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"]
         ] = None,
         temperature: float = None,
         messages: list[dict] = None,
@@ -443,7 +496,7 @@ class LLM:
             stream (bool, optional): Whether to stream the response. Defaults to False.
             tools (list, optional): List of tools to make available for the model.
             images (list, optional): List of images to include in the request.
-            model (Literal["small", "standard", "vision", "reasoning", "tools"], optional):
+            model (Literal["small", "standard", "vision", "reasoning", "tools", "embeddings"], optional):
                 The model type to use. Defaults to "standard".
             temperature (float, optional): Temperature parameter for generation randomness.
                 Uses instance default if not provided.
@@ -472,6 +525,7 @@ class LLM:
         model = self._prepare_messages_and_model(
             query, user_input, context, messages, images, model
         )
+        print_red(model)
         temperature = temperature if temperature else self.options["temperature"]
         if think is None:
             think = self.think
@@ -482,6 +536,11 @@ class LLM:
         response = self._call_remote_api(
             model, tools, stream, options, format, headers, think=think
         )
+
+        # If using embeddings model, the response is an embed result (not a ChatResponse).
+        if model == self.get_model("embeddings"):
+            return response
+
         if stream:
             return self.read_stream(response)
         else:
@@ -554,12 +613,16 @@ class LLM:
 
         # First try with remote API
         if not force_local:
             try:
-                headers = self._build_headers(model, tools, think)
+                headers = self._build_headers(model)
                 options = self._get_options(temperature)
                 response = await self._call_remote_api_async(
                     model, tools, stream, options, format, headers, think=think
                 )
+
+                # If using embeddings model, return the embed response directly
+                if model == self.get_model("embeddings"):
+                    return response
                 if stream:
                     return self.read_stream(response)
                 else:
@@ -672,11 +735,13 @@ class LLM:
 
         import base64
 
         base64_images = []
+        # base64 pattern: must be divisible by 4, only valid chars, and proper padding
         base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")
         for image in images:
             if isinstance(image, str):
-                if base64_pattern.match(image):
+                # If it looks like base64, just pass it through
+                if base64_pattern.match(image) and len(image) % 4 == 0:
                     base64_images.append(image)
                 else:
                     with open(image, "rb") as image_file:
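
The tightened check treats a string as base64 only when it matches the character-class pattern and its length is a multiple of four; anything else is treated as a file path. A quick standalone illustration:

import re

base64_pattern = re.compile(r"^[A-Za-z0-9+/]+={0,2}$")

def looks_like_base64(s: str) -> bool:
    # Same heuristic as above: valid alphabet, optional '=' padding, length % 4 == 0.
    return bool(base64_pattern.match(s)) and len(s) % 4 == 0

print(looks_like_base64("aGVsbG8="))   # True  ("hello" base64-encoded)
print(looks_like_base64("photo.png"))  # False (treated as a file path)
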
@@ -688,8 +753,8 @@ class LLM:
             else:
                 print_red("Invalid image type")
 
         message["images"] = base64_images
         return message
 
 
 if __name__ == "__main__":
