diff --git a/_llm/llm.py b/_llm/llm.py
index 6f78725..6b8e6d8 100644
--- a/_llm/llm.py
+++ b/_llm/llm.py
@@ -49,8 +49,62 @@ except ImportError:
env_manager.set_env()
+
tokenizer = tiktoken.get_encoding("cl100k_base")
+_QUIRKS_FILE = os.path.join(os.path.dirname(__file__), "provider_quirks.json")
+
+# Maps 400-error substrings to the retry action to take.
+# "drop:" removes that kwarg from the request and retries.
+# thought_signature errors are fixed at the message-history level in chat.py,
+# not via a parameter drop, so they are not listed here.
+_ERROR_RETRY_RULES: list[tuple[str, str]] = [
+ ("Penalty is not enabled", "drop:presence_penalty"),
+]
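+ # Additional rules follow the same pattern, e.g. a hypothetical ("does not support tools", "drop:tools").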
+
+
+class _ProviderQuirks:
+ """Persists per-model param drops discovered at runtime so later runs can skip them from the first request."""
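+ # On-disk format (see provider_quirks.json): {"<model-name>": ["drop:<param>", ...]}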
+
+ def __init__(self):
+ self._data: dict[str, list[str]] = {} # model -> list of actions to apply
+ self._load()
+
+ def _load(self):
+ try:
+ with open(_QUIRKS_FILE) as f:
+ self._data = json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError):
+ self._data = {}
+
+ def _save(self):
+ try:
+ with open(_QUIRKS_FILE, "w") as f:
+ json.dump(self._data, f, indent=2)
+ except Exception:
+ pass
+
+ def get_actions(self, model: str) -> list[str]:
+ return list(self._data.get(model, []))
+
+ def record(self, model: str, action: str):
+ if action not in self._data.get(model, []):
+ self._data.setdefault(model, []).append(action)
+ self._save()
+
+
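+ # Module-level singleton: quirks recorded by one LLM instance apply to every instance in this process.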
+_quirks = _ProviderQuirks()
+
+
+def _apply_quirk_action(action: str, kwargs: dict, silent: bool = False):
+ """Mutate kwargs in-place according to a quirk action string."""
+ if action.startswith("drop:"):
+ param = action[5:]
+ if param in kwargs:
+ kwargs.pop(param)
+ if not silent:
+ print_yellow(f"Skipping unsupported param '{param}' for this model.")
+
def _strip_think_tags(text: str) -> str:
"""
@@ -112,7 +166,7 @@ class LLM:
system_message: str = "You are an assistant.",
temperature: float = 0.01,
model: Optional[str] = None,
- max_length_answer: int = 8000,
+ max_length_answer: int = 3000,
messages: Optional[list[dict]] = None,
chat: bool = True,
tools: Optional[list] = None,
@@ -123,6 +177,8 @@ class LLM:
presence_penalty: float = 0.3,
top_p: float = 0.9,
extra_body: Optional[Dict[str, Any]] = None,
+ base_url: Optional[str] = None,
+ api_key: Optional[str] = None,
) -> None:
"""
Initialize the assistant wrapper.
@@ -142,6 +198,7 @@ class LLM:
presence_penalty: Default presence penalty forwarded with each request.
top_p: Default nucleus sampling value used for sampling (see vLLM sampling params).
extra_body: Additional sampler payload sent via `extra_body` (e.g. repetition penalties).
+ base_url: Override the host URL directly (takes precedence over on_vpn/env vars).
+ api_key: API key / Bearer token sent with requests; takes precedence over the LLM_BEARER env var.
"""
self.model = self.get_model(model)
self.call_model = self.model
@@ -160,30 +217,41 @@ class LLM:
self.think = think
self.tools = tools or []
self.silent = silent
- headers = {
- "Authorization": f"Basic {self.get_credentials()}",
- }
self.on_vpn = on_vpn
- if self.on_vpn:
- self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/vllm/v1"
+ # Auth is sent as a Bearer token by the OpenAI client itself, taken from api_key
+ # (falling back to LLM_BEARER). Don't also set Authorization in default_headers:
+ # a second Authorization header there conflicts with the one the client adds
+ # and makes vllm-sr return 404.
+ self._api_key = api_key # user-supplied key; takes precedence over LLM_BEARER env var
+ headers = {}
+ if base_url:
+ self.host_url = base_url
+ elif self.on_vpn:
+ # On VPN: connect directly to the local vLLM server (no auth needed)
+ self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/v1"
else:
- self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/") + "/vllm/v1"
+ # Remote: use the public HTTPS endpoint, auth is via Bearer token
+ self.host_url = os.getenv("LLM_API_URL").rstrip("/") + "/v1"
self._make_clients(headers, timeout)
def _make_clients(self, headers=None, timeout=300):
session_id = uuid.uuid4()
- headers = headers or {"Authorization": f"Basic {self.get_credentials()}", "X-Session-ID": str(session_id)}
+ if headers is None:
+ headers = {"X-Session-ID": str(session_id)}
+ # Use caller-supplied key if present, otherwise fall back to env.
+ # This supports per-request user keys for external providers (Berget, OpenAI).
+ api_key = self._api_key or os.getenv("LLM_BEARER", "NONE")
# Sync client
self.client: OpenAI = OpenAI(
base_url=self.host_url,
- api_key="NONE",
+ api_key=api_key,
default_headers=headers,
timeout=timeout,
)
- # Async client - fix the double /v1 issue
+ # Async client
self.async_client = AsyncOpenAI(
base_url=self.host_url,
- api_key="NONE",
+ api_key=api_key,
default_headers=headers,
timeout=timeout,
)
@@ -201,7 +269,7 @@ class LLM:
sensible defaults while still accepting fully qualified model names.
"""
default_model = os.getenv(
- "LLM_MODEL_VLLM", "qwen3-14b"
+ "LLM_MODEL_VLLM", "smart"
)
if model_alias in {None, "", "vllm"}:
return default_model
@@ -220,6 +288,68 @@ class LLM:
num_tokens += len(tokens)
return int(num_tokens)
+ def _count_messages_tokens(self, messages: List[Dict[str, Any]]) -> int:
+ num_tokens = 0
+ for msg in messages:
+ content = msg.get("content", "")
+ if not isinstance(content, str):
+ content = str(content)
+ num_tokens += len(tokenizer.encode(content))
+ return num_tokens
+
+ def _tokenize_message(self, model: str, text: str) -> int:
+ """Ask the vLLM /tokenize endpoint for the exact token count of a string."""
+ import urllib.request
+ # Derive base URL by stripping the trailing /v1
+ base = self.host_url.rstrip("/")
+ if base.endswith("/v1"):
+ base = base[:-3]
+ url = f"{base}/tokenize"
+ payload = json.dumps({"model": model, "prompt": text}).encode()
+ headers = {"Content-Type": "application/json"}
+ # Send the Bearer token when one is configured; the remote endpoint may require it.
+ token = self._api_key or os.getenv("LLM_BEARER")
+ if token and token != "NONE":
+ headers["Authorization"] = f"Bearer {token}"
+ req = urllib.request.Request(url, data=payload, headers=headers)
+ try:
+ with urllib.request.urlopen(req, timeout=10) as resp:
+ data = json.loads(resp.read())
+ return len(data.get("tokens", []))
+ except Exception:
+ # Fall back to tiktoken estimate if the endpoint is unavailable
+ return len(tokenizer.encode(text))
+
+ def _trim_messages_to_fit_exact(self, messages: List[Dict[str, Any]], model: str, max_tokens: int) -> List[Dict[str, Any]]:
+ """Use the vLLM /tokenize endpoint to drop oldest non-system messages until they fit."""
+ context_window = MODEL_CONTEXT_WINDOWS.get(model, DEFAULT_CONTEXT_WINDOW)
+ # Reserve ~5 tokens per message for chat-template overhead (<|im_start|>role\n...<|im_end|>\n)
+ # that the /tokenize endpoint doesn't count when given raw content strings.
+ template_overhead = 5 * len(messages)
+ budget = context_window - max_tokens - template_overhead
+ if budget <= 0:
+ return messages
+
+ def _total_tokens(msgs):
+ total = 0
+ for m in msgs:
+ content = m.get("content", "") or ""
+ if not isinstance(content, str):
+ content = str(content)
+ total += self._tokenize_message(model, content)
+ return total
+
+ while True:
+ current = _total_tokens(messages)
+ if current <= budget:
+ break
+ drop_idx = next(
+ (i for i, m in enumerate(messages) if m.get("role") != "system"),
+ None,
+ )
+ if drop_idx is None:
+ break
+ if not self.silent:
+ print_yellow(f"Context too long ({current} > {budget} tokens), dropping message at index {drop_idx}.")
+ messages = messages[:drop_idx] + messages[drop_idx + 1:]
+ return messages
+
def _prepare_messages_and_model(
self,
query,
@@ -258,7 +388,14 @@ class LLM:
def _build_headers(self, model):
"""Build HTTP headers for API requests, including auth and optional backend hints."""
- headers = {"Authorization": f"Basic {self.get_credentials()}"}
+ if self.on_vpn:
+ # On VPN the local server doesn't require auth — use Basic as a soft hint only
+ headers = {"Authorization": f"Basic {self.get_credentials()}"}
+ else:
+ # Remote public endpoint requires Bearer token auth.
+ # The OpenAI SDK sets 'Authorization: Bearer <token>' automatically from api_key,
+ # so we don't put Authorization here to avoid conflicting headers.
+ headers = {}
if model == self.get_model("embeddings"):
headers["X-Model-Type"] = "embeddings"
return headers
@@ -451,15 +588,64 @@ class LLM:
)
return chat
+ def _api_call_with_backoff(self, kwargs: dict) -> "ChatCompletion":
+ """Call self.client.chat.completions.create with exponential backoff for transient errors."""
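+ # Retryable: 429 (rate limited), 502/503 (bad gateway/unavailable), 529 (overloaded; used by some providers).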
+ _retryable = {429, 502, 503, 529}
+ _max_retries = 6
+ _base_delay = 2.0
+ for _attempt in range(_max_retries + 1):
+ try:
+ return self.client.chat.completions.create(**kwargs)
+ except Exception as _e:
+ _status = getattr(_e, "status_code", None)
+ if _status not in _retryable or _attempt >= _max_retries:
+ raise
+ _headers = getattr(getattr(_e, "response", None), "headers", {}) or {}
+ _ra = _headers.get("retry-after") or _headers.get("Retry-After")
+ _delay = float(_ra) if _ra else _base_delay * (2 ** _attempt)
+ if not self.silent:
+ print_yellow(
+ f"Rate-limited (HTTP {_status}). Retrying in {_delay:.1f}s "
+ f"(attempt {_attempt + 1}/{_max_retries})…"
+ )
+ time.sleep(_delay)
+ raise RuntimeError("Exhausted retries due to rate limiting.")
+
def _call_remote_api(
- self, model, tools, stream, options, format, headers, think=False
+ self, model, tools, stream, options, format, headers, think=None
) -> ChatCompletion:
"""Call the remote vLLM-backed API using the OpenAI-compatible client."""
self.call_model = model
self.messages = self._sanitize_messages(self.messages)
+
+ # Resolve thinking flag: per-call overrides instance default.
+ # When thinking is off, inject chat_template_kwargs so vLLM skips CoT tokens entirely.
+ # chat_template_kwargs is vLLM-specific; skip it when a user-supplied key is present
+ # (which indicates an external provider like Berget or OpenAI).
+ effective_think = think if think is not None else self.think
+ if not effective_think and not self._api_key:
+ body = dict(options.get("extra_body") or {})
+ ctk = dict(body.get("chat_template_kwargs") or {})
+ ctk["enable_thinking"] = False
+ body["chat_template_kwargs"] = ctk
+ options = {**options, "extra_body": body}
+
+ # Strip vLLM-specific extra_body keys (repetition_penalty, chat_template_kwargs)
+ # when using an external provider — they are not part of the OpenAI spec and will
+ # cause 4xx errors on providers that reject unknown fields.
+ if self._api_key:
+ body = dict(options.get("extra_body") or {})
+ body.pop("repetition_penalty", None)
+ body.pop("chat_template_kwargs", None)
+ options = {**options, "extra_body": body or None}
+
+ # For remote calls the OpenAI SDK handles auth by sending
+ # 'Authorization: Bearer <token>'. On VPN no token is needed.
+ # self._api_key (user-supplied) takes precedence over the LLM_BEARER env var.
+ api_key = self._api_key or (os.getenv("LLM_BEARER", "NONE") if not self.on_vpn else "NONE")
self.client = OpenAI(
base_url=f"{self.host_url}",
- api_key="ollama",
+ api_key=api_key,
default_headers=headers,
timeout=300,
)
@@ -479,17 +665,40 @@ class LLM:
m["role"] = "user"
m["content"] = f"Tool output:\n{m.get('content','')}"
- # print('FORMAT', format)
- response: ParsedResponse = self.client.responses.parse(
- input=messages,
+ # vLLM does not support the /responses endpoint, so we use
+ # chat.completions.create with response_format (JSON schema) instead.
+ json_schema_response = self.client.chat.completions.create(
model=model,
- top_p=options["top_p"],
+ messages=messages,
temperature=options["temperature"],
+ top_p=options["top_p"],
+ response_format={
+ "type": "json_schema",
+ "json_schema": {
+ "name": format.__name__,
+ "schema": format.model_json_schema(),
+ },
+ },
extra_body=options["extra_body"],
- stream=stream,
- text_format=format,
+ max_tokens=options["max_tokens"],
+ )
+ content_text = json_schema_response.choices[0].message.content
+ parsed_instance = format.model_validate_json(content_text)
+ message = ChatCompletionMessage.model_construct(
+ role="assistant",
+ content=parsed_instance,
+ )
+ setattr(message, "content_text", content_text)
+ setattr(message, "parsed", parsed_instance)
+ setattr(message, "parsed_dict", parsed_instance.model_dump())
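+ # Callers can read .parsed (Pydantic instance), .parsed_dict, or .content_text (raw JSON string).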
+ choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
+ response: ChatCompletion = ChatCompletion.model_construct(
+ id=json_schema_response.id,
+ choices=[choice],
+ created=json_schema_response.created,
+ model=json_schema_response.model,
+ object="chat.completion",
)
- response: ChatCompletion = self._parsed_content_normalizer(response, model_cls=format)
# Call the OpenAI API
else:
@@ -499,32 +708,63 @@ class LLM:
message["content"] = (
f"Tool output:\n{message.get('content','')}" # TODO Works? Other tools?
)
+ # OpenAI o-series and newer models use max_completion_tokens instead of max_tokens.
+ _token_kwarg = (
+ {"max_completion_tokens": options["max_tokens"]}
+ if "openai.com" in self.host_url
+ else {"max_tokens": options["max_tokens"]}
+ )
+ # frequency_penalty removed: repetition_penalty in extra_body already handles
+ # this. Having both causes compounding effects that push the model into token loops.
+ kwargs = {
+ "model": model,
+ "messages": self.messages,
+ "stream": stream,
+ "temperature": options["temperature"],
+ "presence_penalty": options["presence_penalty"],
+ "top_p": options["top_p"],
+ "tools": tools,
+ "extra_body": options["extra_body"],
+ **_token_kwarg,
+ }
+ # Apply any previously-discovered quirks for this model upfront.
+ for _action in _quirks.get_actions(model):
+ _apply_quirk_action(_action, kwargs, self.silent)
try:
- response: ChatCompletion = self.client.chat.completions.create(
- model=model,
- messages=self.messages,
- # frequency_penalty removed: repetition_penalty in extra_body
- # already handles this. Having both causes compounding effects
- # that push the model into token loops.
- stream=stream,
- temperature=options["temperature"],
- presence_penalty=options["presence_penalty"],
- top_p=options["top_p"],
- tools=tools,
- extra_body=options["extra_body"],
- max_tokens=options["max_tokens"],
- )
+ response: ChatCompletion = self._api_call_with_backoff(kwargs)
except Exception as e:
- import traceback
- traceback.print_exc()
- print_red(f"Error calling remote API: {e}")
- print('---MESSAGES---')
- print_rainbow(self.messages, single_line=True)
- print()
- print('TOOLS')
- print_rainbow(tools, single_line=True)
- # Re-raise the exception to inform the caller of the API failure
- raise
+ err_str = str(e)
+ if (
+ getattr(e, "status_code", None) == 400
+ and "maximum context length" in err_str
+ ):
+ # Use /tokenize for exact counts and drop messages until they fit
+ if not self.silent:
+ print_yellow("Server rejected: context too long. Trimming with exact token counts and retrying.")
+ self.messages = self._trim_messages_to_fit_exact(
+ self.messages, model, options.get("max_tokens", self.max_length_answer)
+ )
+ kwargs["messages"] = self.messages
+ response = self._api_call_with_backoff(kwargs)
+ elif (
+ getattr(e, "status_code", None) == 400
+ and (_matched_action := next((a for s, a in _ERROR_RETRY_RULES if s in err_str), None))
+ ):
+ # A known provider quirk from _ERROR_RETRY_RULES (e.g. "Penalty is not enabled").
+ if not self.silent:
+ print_yellow(f"Provider rejected a parameter. Retrying with quirk '{_matched_action}'.")
+ _apply_quirk_action(_matched_action, kwargs, self.silent)
+ _quirks.record(model, _matched_action)
+ response = self._api_call_with_backoff(kwargs)
+ else:
+ import traceback
+ traceback.print_exc()
+ print_red(f"Error calling remote API: {e}")
+ print('---MESSAGES---')
+ print_rainbow(self.messages, single_line=True)
+ print()
+ print('TOOLS')
+ print_rainbow(tools, single_line=True)
+ raise
# Try to extract backend information if available
try:
@@ -538,10 +778,20 @@ class LLM:
return response
async def _call_remote_api_async(
- self, model, tools, stream, options, format, headers, think=False
+ self, model, tools, stream, options, format, headers, think=None
):
"""Call the remote API asynchronously using OpenAI async client."""
self.messages = self._sanitize_messages(self.messages)
+
+ # Resolve thinking flag: per-call overrides instance default.
+ effective_think = think if think is not None else self.think
+ # chat_template_kwargs is vLLM-specific; skip it when a user-supplied api_key
+ # indicates an external provider (mirrors the sync path).
+ if not effective_think and not self._api_key:
+ body = dict(options.get("extra_body") or {})
+ ctk = dict(body.get("chat_template_kwargs") or {})
+ ctk["enable_thinking"] = False
+ body["chat_template_kwargs"] = ctk
+ options = {**options, "extra_body": body}
+
# Update the async client with the latest headers
self.async_client = AsyncOpenAI(
base_url=self.host_url, # Remove the extra /v1
@@ -594,19 +844,59 @@ class LLM:
if m.get("role") not in {"user", "assistant"}:
m["role"] = "user"
m["content"] = f"Tool output:\n{m.get('content','')}"
- response = await self.async_client.responses.parse(
- input=messages,
+ # vLLM does not support the /responses endpoint, so we use
+ # chat.completions.create with response_format (JSON schema) instead.
+ json_schema_response = await self.async_client.chat.completions.create(
model=model,
- top_p=options["top_p"],
+ messages=messages,
temperature=options["temperature"],
- extra_body=options["extra_body"],
- stream=stream,
- text_format=format,
+ top_p=options["top_p"],
+ response_format={
+ "type": "json_schema",
+ "json_schema": {
+ "name": format.__name__,
+ "schema": format.model_json_schema(),
+ },
+ },
+ extra_body=options.get("extra_body"),
+ max_tokens=options.get("max_tokens"),
+ )
+ content_text = json_schema_response.choices[0].message.content
+ parsed_instance = format.model_validate_json(content_text)
+ message = ChatCompletionMessage.model_construct(
+ role="assistant",
+ content=parsed_instance,
+ )
+ setattr(message, "content_text", content_text)
+ setattr(message, "parsed", parsed_instance)
+ setattr(message, "parsed_dict", parsed_instance.model_dump())
+ choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
+ return ChatCompletion.model_construct(
+ id=json_schema_response.id,
+ choices=[choice],
+ created=json_schema_response.created,
+ model=json_schema_response.model,
+ object="chat.completion",
)
- return self._parsed_content_normalizer(response, model_cls=format)
if tools:
kwargs["tools"] = tools
- response = await self.async_client.chat.completions.create(**kwargs)
+ # Apply any previously-discovered quirks for this model upfront.
+ for _action in _quirks.get_actions(model):
+ _apply_quirk_action(_action, kwargs, self.silent)
+ try:
+ response = await self.async_client.chat.completions.create(**kwargs)
+ except Exception as e:
+ err_str = str(e)
+ if (
+ getattr(e, "status_code", None) == 400
+ and (_matched_action := next((a for s, a in _ERROR_RETRY_RULES if s in err_str), None))
+ ):
+ if not self.silent:
+ print_yellow(f"Provider rejected a parameter. Retrying with quirk '{_matched_action}'.")
+ _apply_quirk_action(_matched_action, kwargs, self.silent)
+ _quirks.record(model, _matched_action)
+ response = await self.async_client.chat.completions.create(**kwargs)
+ else:
+ raise
return response
def generate(
@@ -626,6 +916,7 @@ class LLM:
top_p: Optional[float] = None,
extra_body: Optional[Dict[str, Any]] = None,
max_tokens: Optional[int] = None,
+ auto_execute_tools: bool = True, # if False, return tool calls to the caller instead of executing them here
) -> ChatCompletionMessage:
"""
Generate a response through the remote API.
@@ -668,7 +959,7 @@ class LLM:
choice = response.choices[0]
message: ChatCompletionMessage = choice.message
- if hasattr(message, "tool_calls") and message.tool_calls:
+ if auto_execute_tools and hasattr(message, "tool_calls") and message.tool_calls:
# Handle multiple tool calls sequentially
for tool_call in message.tool_calls:
try:
@@ -738,7 +1029,7 @@ class LLM:
}
)
# fallback: older SDKs / shapes:
- if hasattr(message, "function_call") and message.function_call:
+ if auto_execute_tools and hasattr(message, "function_call") and message.function_call:
fc = message.function_call
func_name = getattr(fc, "name", None) or (
fc.get("name") if isinstance(fc, dict) else None
@@ -775,12 +1066,15 @@ class LLM:
if hasattr(message, "content_text"):
result: str = message.content_text
# Qwen3 and some other models include <think>...</think> blocks directly
# in the content field instead of (or in addition to) using reasoning_content.
# Strip those blocks so they never leak to the user.
if isinstance(result, str):
result = _strip_think_tags(result)
- message.content = result
+ # Only overwrite message.content if it is still a plain string.
+ # When format= is used, message.content already holds the parsed
+ # Pydantic instance and must not be replaced with its JSON text form.
+ if not isinstance(message.content, BaseModel):
+ message.content = result
# Save to the message history (without tool calls, for a clean history)
self.messages.append({"role": "assistant", "content": result})
@@ -1184,13 +1478,32 @@ if __name__ == "__main__":
final_answer: float
explanation: str
- llm = LLM()
+ # on_vpn=False → uses LLM_API_URL (https://llm.edfast.se) with Bearer token auth
+ llm = LLM(on_vpn=False)
+ llm.host_url = "http://192.168.1.12:8897/v1" # Override for local testing
print('LLM URL:', llm.host_url)
- # base_url must be just the /v1 root — the OpenAI SDK appends the
- # correct path itself (/v1/responses or /v1/chat/completions).
- # Setting it to the full endpoint path causes the SDK to produce
- # invalid URLs like /v1/chat/completions/responses.
- llm.host_url = "http://192.168.1.12:8000/v1"
+ response = llm.generate(
+ query="Tell me about sweden?",
+ model="big-smart",
+ think=True,
+ stream=True,
+ )
+ for chunk_type, chunk_content in response:
+ if chunk_type == "content":
+ print_blue(f"Content chunk: {chunk_content}")
+ elif chunk_type == "thinking":
+ print_yellow(f"Thinking chunk: {chunk_content}")
+ elif chunk_type == "thinking_complete":
+ print_green(f"Complete thinking: {chunk_content}")
+ response = llm.generate(
+ query="What is the capital of Norway?",
+ model="smart",
+ think=False
+ )
+ print_blue(response)
+ exit()
+
+
response = llm.generate(
query="""Create a simple math problem solution in JSON format with this structure:
{
diff --git a/_llm/provider_quirks.json b/_llm/provider_quirks.json
new file mode 100644
index 0000000..8853cf1
--- /dev/null
+++ b/_llm/provider_quirks.json
@@ -0,0 +1,8 @@
+{
+ "models/gemini-flash-latest": [
+ "drop:presence_penalty"
+ ],
+ "gemini-3-flash-preview": [
+ "drop:presence_penalty"
+ ]
+}
\ No newline at end of file