Enhance LLM class: add local availability and timeout parameters, update model handling, and refactor header building

Branch: legacy
Author: lasseedfast (8 months ago)
Parent: a5a7034bbc
Commit: 12b7c5ba4d

1 changed file: _llm/llm.py (87 changes)

@@ -59,6 +59,8 @@ class LLM:
         chosen_backend: str = None,
         tools: list = None,
         think: bool = False,
+        timeout: int = 240,
+        local_available: bool = False,
     ) -> None:
         """
         Initialize the assistant with the given parameters.
@@ -86,6 +88,9 @@ class LLM:
         self.messages = messages or [{"role": "system", "content": self.system_message}]
         self.max_length_answer = max_length_answer
         self.chat = chat
+        self.think = think
+        self.tools = tools or []
+        self.local_available = local_available
         self.chosen_backend = chosen_backend
@@ -97,7 +102,7 @@ class LLM:
             headers["X-Chosen-Backend"] = self.chosen_backend
         self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/")
-        self.client: Client = Client(host=self.host_url, headers=headers, timeout=120)
+        self.client: Client = Client(host=self.host_url, headers=headers, timeout=timeout)
         self.async_client: AsyncClient = AsyncClient()

     def get_credentials(self):
@@ -150,12 +155,21 @@ class LLM:
             message = self.prepare_images(images, message)
             model = self.get_model("vision")
         else:
-            model = self.get_model(model)
+            if model in [
+                "small",
+                "standard",
+                "standard_64k",
+                "reasoning",
+                "tools",
+            ]:
+                model = self.get_model(model)

         self.messages.append(message)
         return model

-    def _build_headers(self, model, tools, think):
+    def _build_headers(self, model):
         """Build HTTP headers for API requests, including auth and backend/model info."""
         headers = {"Authorization": f"Basic {self.get_credentials()}"}
         if self.chosen_backend and model not in [
@@ -179,16 +193,16 @@ class LLM:
         )
         return options

-    @backoff.on_exception(
-        backoff.expo,
-        (ResponseError, TimeoutError),
-        max_tries=3,
-        factor=2,
-        base=10,
-        on_backoff=lambda details: print_yellow(
-            f"Retrying due to error: {details['exception']}"
-        )
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     (ResponseError, TimeoutError),
+    #     max_tries=3,
+    #     factor=2,
+    #     base=10,
+    #     on_backoff=lambda details: print_yellow(
+    #         f"Retrying due to error: {details['exception']}"
+    #     )
+    # )
     def _call_remote_api(
         self, model, tools, stream, options, format, headers, think=False
     ):
@@ -308,16 +322,16 @@ class LLM:
                 self.messages = [self.messages[0]]
             return response_obj.message

-    @backoff.on_exception(
-        backoff.expo,
-        (ResponseError, TimeoutError),
-        max_tries=3,
-        factor=2,
-        base=10,
-        on_backoff=lambda details: print_yellow(
-            f"Retrying due to error: {details['exception']}"
-        )
-    )
+    # @backoff.on_exception(
+    #     backoff.expo,
+    #     (ResponseError, TimeoutError),
+    #     max_tries=3,
+    #     factor=2,
+    #     base=10,
+    #     on_backoff=lambda details: print_yellow(
+    #         f"Retrying due to error: {details['exception']}"
+    #     )
+    # )
     async def _call_local_ollama_async(self, model, stream, temperature, think=False):
         """Call the local Ollama instance asynchronously (using a thread pool)."""
         import ollama
@@ -411,11 +425,11 @@ class LLM:
         images: list = None,
         model: Optional[
             Literal["small", "standard", "vision", "reasoning", "tools"]
-        ] = "standard",
+        ] = None,
         temperature: float = None,
         messages: list[dict] = None,
         format=None,
-        think=False,
+        think=None,
         force_local: bool = False,
     ):
         """
@@ -435,9 +449,10 @@ class LLM:
                 Uses instance default if not provided.
             messages (list[dict], optional): Pre-formatted message history.
             format (optional): Response format specification.
-            think (bool, optional): Whether to enable thinking mode. Defaults to False.
+            think (bool, optional): Whether to enable thinking mode. Defaults to None.
             force_local (bool, optional): Force use of local Ollama instead of remote API.
                 Defaults to False.
+            local_available (bool, optional): Whether local Ollama is available.

         Returns:
             The generated response. Type varies based on stream parameter and success:
@@ -450,13 +465,19 @@ class LLM:
             Prints stack trace for exceptions but doesn't propagate them, instead
             returning error messages or attempting fallback to local processing.
         """
+        if model is None and self.model:
+            model = self.model
+        elif model is None:
+            model = "standard"
         model = self._prepare_messages_and_model(
             query, user_input, context, messages, images, model
         )
         temperature = temperature if temperature else self.options["temperature"]
+        if think is None:
+            think = self.think
         if not force_local:
             try:
-                headers = self._build_headers(model, tools, think)
+                headers = self._build_headers(model)
                 options = self._get_options(temperature)
                 response = self._call_remote_api(
                     model, tools, stream, options, format, headers, think=think
@@ -480,11 +501,13 @@ class LLM:
                    return "An error occurred."
            except Exception as e:
                traceback.print_exc()
-                try:
-                    return self._call_local_ollama(model, stream, temperature, think=think)
-                except Exception as e:
-                    traceback.print_exc()
-                    return "Both remote API and local Ollama failed. An error occurred."
+                if self.local_available:
+                    try:
+                        return self._call_local_ollama(model, stream, temperature, think=think)
+                    except Exception as e:
+                        traceback.print_exc()
+                        return "Both remote API and local Ollama failed. An error occurred."

     async def async_generate(
         self,
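
For context, a minimal usage sketch of the interface after this commit. The import path and the surrounding call style are assumptions (they are not part of the diff); the parameter names, defaults, and fallback behaviour are taken from the hunks above.

# Hypothetical usage sketch -- module path and keyword choices are assumptions;
# only the parameter names/defaults come from this commit.
from _llm.llm import LLM  # assumed import path for _llm/llm.py

llm = LLM(
    think=False,           # stored as self.think; used when generate(..., think=None)
    timeout=240,           # now forwarded to Client(...) instead of the hard-coded 120
    local_available=True,  # enables the local Ollama fallback when the remote API fails
)

# model=None now resolves to self.model when set, otherwise "standard";
# think=None falls back to the instance-level self.think configured above.
answer = llm.generate(query="Summarise the release notes.", model=None)
print(answer)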
