Implement provider quirks for handling model-specific parameter drops

main
Lasse Server 2 weeks ago
parent 1950c02bd8
commit 8eec27e9e3
  1. _llm/llm.py (419 changed lines)
  2. _llm/provider_quirks.json (8 changed lines)

@ -49,8 +49,62 @@ except ImportError:
env_manager.set_env()
tokenizer = tiktoken.get_encoding("cl100k_base")
_QUIRKS_FILE = os.path.join(os.path.dirname(__file__), "provider_quirks.json")
# Maps 400-error substrings to the retry action to take.
# "drop:<param>" removes that kwarg from the request and retries.
# thought_signature errors are fixed at the message-history level in chat.py,
# not via a parameter drop, so they are not listed here.
_ERROR_RETRY_RULES: list[tuple[str, str]] = [
("Penalty is not enabled", "drop:presence_penalty"),
]
class _ProviderQuirks:
    """Persists per-model param drops discovered at runtime so we skip them from the first call."""

    def __init__(self):
        self._data: dict[str, list[str]] = {}  # model -> list of actions to apply
        self._load()

    def _load(self):
        try:
            with open(_QUIRKS_FILE) as f:
                self._data = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            self._data = {}

    def _save(self):
        try:
            with open(_QUIRKS_FILE, "w") as f:
                json.dump(self._data, f, indent=2)
        except Exception:
            pass

    def get_actions(self, model: str) -> list[str]:
        return list(self._data.get(model, []))

    def record(self, model: str, action: str):
        if action not in self._data.get(model, []):
            self._data.setdefault(model, []).append(action)
            self._save()


_quirks = _ProviderQuirks()
def _apply_quirk_action(action: str, kwargs: dict, silent: bool = False):
    """Mutate kwargs in-place according to a quirk action string."""
    if action.startswith("drop:"):
        param = action[5:]
        if param in kwargs:
            kwargs.pop(param)
            if not silent:
                print_yellow(f"Skipping unsupported param '{param}' for this model.")
def _strip_think_tags(text: str) -> str:
    """
@ -112,7 +166,7 @@ class LLM:
        system_message: str = "You are an assistant.",
        temperature: float = 0.01,
        model: Optional[str] = None,
        max_length_answer: int = 8000,
        max_length_answer: int = 3000,
        messages: Optional[list[dict]] = None,
        chat: bool = True,
        tools: Optional[list] = None,
@ -123,6 +177,8 @@ class LLM:
        presence_penalty: float = 0.3,
        top_p: float = 0.9,
        extra_body: Optional[Dict[str, Any]] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
    ) -> None:
        """
        Initialize the assistant wrapper.
@ -142,6 +198,7 @@ class LLM:
            presence_penalty: Default presence penalty forwarded with each request.
            top_p: Default nucleus sampling value used for sampling (see vLLM sampling params).
            extra_body: Additional sampler payload sent via `extra_body` (e.g. repetition penalties).
            base_url: Override the host URL directly (takes precedence over on_vpn/env vars).
        """
        self.model = self.get_model(model)
        self.call_model = self.model
@ -160,30 +217,41 @@ class LLM:
        self.think = think
        self.tools = tools or []
        self.silent = silent
        headers = {
            "Authorization": f"Basic {self.get_credentials()}",
        }
        self.on_vpn = on_vpn
        if self.on_vpn:
            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/vllm/v1"
        # Auth is always via Bearer token (LLM_BEARER), whether remote or local.
        # Don't set Authorization in default_headers — the OpenAI client adds it
        # automatically from api_key, and a second Authorization header in
        # default_headers would cause a conflict that makes vllm-sr return 404.
        self._api_key = api_key  # user-supplied key; takes precedence over LLM_BEARER env var
        headers = {}
        if base_url:
            self.host_url = base_url
        elif self.on_vpn:
            # On VPN: connect directly to the local vLLM server (no auth needed)
            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/v1"
        else:
            self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/") + "/vllm/v1"
            # Remote: use the public HTTPS endpoint, auth is via Bearer token
            self.host_url = os.getenv("LLM_API_URL").rstrip("/") + "/v1"
        self._make_clients(headers, timeout)

    def _make_clients(self, headers=None, timeout=300):
        session_id = uuid.uuid4()
        headers = headers or {"Authorization": f"Basic {self.get_credentials()}", "X-Session-ID": str(session_id)}
        if headers is None:
            headers = {"X-Session-ID": str(session_id)}
        # Use caller-supplied key if present, otherwise fall back to env.
        # This supports per-request user keys for external providers (Berget, OpenAI).
        api_key = self._api_key or os.getenv("LLM_BEARER", "NONE")
        # Sync client
        self.client: OpenAI = OpenAI(
            base_url=self.host_url,
            api_key="NONE",
            api_key=api_key,
            default_headers=headers,
            timeout=timeout,
        )
        # Async client - fix the double /v1 issue
        # Async client
        self.async_client = AsyncOpenAI(
            base_url=self.host_url,
            api_key="NONE",
            api_key=api_key,
            default_headers=headers,
            timeout=timeout,
        )
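
    # Example (illustrative; the base_url and api_key values are placeholders, not real endpoints):
    #
    #     LLM(on_vpn=True)                              # local vLLM via LLM_URL/LLM_PORT, no token needed
    #     LLM(on_vpn=False)                             # public endpoint from LLM_API_URL + LLM_BEARER token
    #     LLM(base_url="https://example.com/v1",        # explicit base_url wins over on_vpn/env vars
    #         api_key="sk-user-key")                    # user-supplied key wins over LLM_BEARER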
@ -201,7 +269,7 @@ class LLM:
        sensible defaults while still accepting fully qualified model names.
        """
        default_model = os.getenv(
            "LLM_MODEL_VLLM", "qwen3-14b"
            "LLM_MODEL_VLLM", "smart"
        )
        if model_alias in {None, "", "vllm"}:
            return default_model
@ -220,6 +288,68 @@ class LLM:
            num_tokens += len(tokens)
        return int(num_tokens)

    def _count_messages_tokens(self, messages: List[Dict[str, Any]]) -> int:
        num_tokens = 0
        for msg in messages:
            for k, v in msg.items():
                if k == "content":
                    if not isinstance(v, str):
                        v = str(v)
                    num_tokens += len(tokenizer.encode(v))
        return num_tokens

    def _tokenize_message(self, model: str, text: str) -> int:
        """Ask the vLLM /tokenize endpoint for the exact token count of a string."""
        import urllib.request

        # Derive base URL by stripping the trailing /v1
        base = self.host_url.rstrip("/")
        if base.endswith("/v1"):
            base = base[:-3]
        url = f"{base}/tokenize"
        payload = json.dumps({"model": model, "prompt": text}).encode()
        req = urllib.request.Request(url, data=payload, headers={"Content-Type": "application/json"})
        try:
            with urllib.request.urlopen(req, timeout=10) as resp:
                data = json.loads(resp.read())
                return len(data.get("tokens", []))
        except Exception:
            # Fall back to tiktoken estimate if the endpoint is unavailable
            return len(tokenizer.encode(text))

    def _trim_messages_to_fit_exact(self, messages: List[Dict[str, Any]], model: str, max_tokens: int) -> List[Dict[str, Any]]:
        """Use the vLLM /tokenize endpoint to drop oldest non-system messages until they fit."""
        context_window = MODEL_CONTEXT_WINDOWS.get(model, DEFAULT_CONTEXT_WINDOW)
        # Reserve ~5 tokens per message for chat-template overhead (<|im_start|>role\n...<|im_end|>\n)
        # that the /tokenize endpoint doesn't count when given raw content strings.
        template_overhead = 5 * len(messages)
        budget = context_window - max_tokens - template_overhead
        if budget <= 0:
            return messages

        def _total_tokens(msgs):
            total = 0
            for m in msgs:
                content = m.get("content", "") or ""
                if not isinstance(content, str):
                    content = str(content)
                total += self._tokenize_message(model, content)
            return total

        while True:
            current = _total_tokens(messages)
            if current <= budget:
                break
            drop_idx = next(
                (i for i, m in enumerate(messages) if m.get("role") != "system"),
                None,
            )
            if drop_idx is None:
                break
            if not self.silent:
                print_yellow(f"Context too long ({current} > {budget} tokens), dropping message at index {drop_idx}.")
            messages = messages[:drop_idx] + messages[drop_idx + 1:]
        return messages
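
    # Worked example (illustrative; assumes a 32,768-token context window for the model):
    # with max_tokens=3000 and 10 messages, budget = 32768 - 3000 - 5 * 10 = 29718 tokens,
    # and the oldest non-system message is dropped until the /tokenize total fits under it.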
    def _prepare_messages_and_model(
        self,
        query,
@ -258,7 +388,14 @@ class LLM:
    def _build_headers(self, model):
        """Build HTTP headers for API requests, including auth and optional backend hints."""
        if self.on_vpn:
            # On VPN the local server doesn't require auth — use Basic as a soft hint only
            headers = {"Authorization": f"Basic {self.get_credentials()}"}
        else:
            # Remote public endpoint requires Bearer token auth.
            # The OpenAI SDK sets 'Authorization: Bearer <api_key>' automatically,
            # so we don't put Authorization here to avoid conflicting headers.
            headers = {}
        if model == self.get_model("embeddings"):
            headers["X-Model-Type"] = "embeddings"
        return headers
@ -451,15 +588,64 @@ class LLM:
        )
        return chat

    def _api_call_with_backoff(self, kwargs: dict) -> "ChatCompletion":
        """Call self.client.chat.completions.create with exponential backoff for transient errors."""
        _retryable = {429, 502, 503, 529}
        _max_retries = 6
        _base_delay = 2.0
        for _attempt in range(_max_retries + 1):
            try:
                return self.client.chat.completions.create(**kwargs)
            except Exception as _e:
                _status = getattr(_e, "status_code", None)
                if _status not in _retryable or _attempt >= _max_retries:
                    raise
                _headers = getattr(getattr(_e, "response", None), "headers", {}) or {}
                _ra = _headers.get("retry-after") or _headers.get("Retry-After")
                _delay = float(_ra) if _ra else _base_delay * (2 ** _attempt)
                if not self.silent:
                    print_yellow(
                        f"Rate-limited (HTTP {_status}). Retrying in {_delay:.1f}s "
                        f"(attempt {_attempt + 1}/{_max_retries})…"
                    )
                time.sleep(_delay)
        raise RuntimeError("Exhausted retries due to rate limiting.")
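
    # Example (illustrative): with _base_delay = 2.0 and no Retry-After header, the six
    # possible retries wait 2, 4, 8, 16, 32 and 64 seconds; a numeric Retry-After header,
    # when present, takes precedence over the computed exponential delay.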
    def _call_remote_api(
        self, model, tools, stream, options, format, headers, think=False
        self, model, tools, stream, options, format, headers, think=None
    ) -> ChatCompletion:
        """Call the remote vLLM-backed API using the OpenAI-compatible client."""
        self.call_model = model
        self.messages = self._sanitize_messages(self.messages)

        # Resolve thinking flag: per-call overrides instance default.
        # When thinking is off, inject chat_template_kwargs so vLLM skips CoT tokens entirely.
        # chat_template_kwargs is vLLM-specific; skip it when a user-supplied key is present
        # (which indicates an external provider like Berget or OpenAI).
        effective_think = think if think is not None else self.think
        if not effective_think and not self._api_key:
            body = dict(options.get("extra_body") or {})
            ctk = dict(body.get("chat_template_kwargs") or {})
            ctk["enable_thinking"] = False
            body["chat_template_kwargs"] = ctk
            options = {**options, "extra_body": body}

        # Strip vLLM-specific extra_body keys (repetition_penalty, chat_template_kwargs)
        # when using an external provider — they are not part of the OpenAI spec and will
        # cause 4xx errors on providers that reject unknown fields.
        if self._api_key:
            body = dict(options.get("extra_body") or {})
            body.pop("repetition_penalty", None)
            body.pop("chat_template_kwargs", None)
            options = {**options, "extra_body": body or None}

        # For remote calls the OpenAI SDK handles auth by sending
        # 'Authorization: Bearer <api_key>'. On VPN no token is needed.
        # self._api_key (user-supplied) takes precedence over the LLM_BEARER env var.
        api_key = self._api_key or (os.getenv("LLM_BEARER", "NONE") if not self.on_vpn else "NONE")
        self.client = OpenAI(
            base_url=f"{self.host_url}",
            api_key="ollama",
            api_key=api_key,
            default_headers=headers,
            timeout=300,
        )
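
        # Example (illustrative): against local vLLM with think=False the request's extra_body
        # ends up containing {"chat_template_kwargs": {"enable_thinking": False}} alongside any
        # sampler keys, while a user-supplied api_key strips both vLLM-specific keys instead.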
@ -479,17 +665,40 @@ class LLM:
m["role"] = "user"
m["content"] = f"Tool output:\n{m.get('content','')}"
# print('FORMAT', format)
response: ParsedResponse = self.client.responses.parse(
input=messages,
# vLLM does not support the /responses endpoint, so we use
# chat.completions.create with response_format (JSON schema) instead.
json_schema_response = self.client.chat.completions.create(
model=model,
top_p=options["top_p"],
messages=messages,
temperature=options["temperature"],
top_p=options["top_p"],
response_format={
"type": "json_schema",
"json_schema": {
"name": format.__name__,
"schema": format.model_json_schema(),
},
},
extra_body=options["extra_body"],
stream=stream,
text_format=format,
max_tokens=options["max_tokens"],
)
content_text = json_schema_response.choices[0].message.content
parsed_instance = format.model_validate_json(content_text)
message = ChatCompletionMessage.model_construct(
role="assistant",
content=parsed_instance,
)
setattr(message, "content_text", content_text)
setattr(message, "parsed", parsed_instance)
setattr(message, "parsed_dict", parsed_instance.model_dump())
choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
response: ChatCompletion = ChatCompletion.model_construct(
id=json_schema_response.id,
choices=[choice],
created=json_schema_response.created,
model=json_schema_response.model,
object="chat.completion",
)
response: ChatCompletion = self._parsed_content_normalizer(response, model_cls=format)
# Call the OpenAI API
else:
@ -499,22 +708,54 @@ class LLM:
message["content"] = (
f"Tool output:\n{message.get('content','')}" # TODO Works? Other tools?
)
try:
response: ChatCompletion = self.client.chat.completions.create(
model=model,
messages=self.messages,
# frequency_penalty removed: repetition_penalty in extra_body
# already handles this. Having both causes compounding effects
# that push the model into token loops.
stream=stream,
temperature=options["temperature"],
presence_penalty=options["presence_penalty"],
top_p=options["top_p"],
tools=tools,
extra_body=options["extra_body"],
max_tokens=options["max_tokens"],
# OpenAI o-series and newer models use max_completion_tokens instead of max_tokens.
_token_kwarg = (
{"max_completion_tokens": options["max_tokens"]}
if "openai.com" in self.host_url
else {"max_tokens": options["max_tokens"]}
)
# frequency_penalty removed: repetition_penalty in extra_body already handles
# this. Having both causes compounding effects that push the model into token loops.
kwargs = {
"model": model,
"messages": self.messages,
"stream": stream,
"temperature": options["temperature"],
"presence_penalty": options["presence_penalty"],
"top_p": options["top_p"],
"tools": tools,
"extra_body": options["extra_body"],
**_token_kwarg,
}
# Apply any previously-discovered quirks for this model upfront.
for _action in _quirks.get_actions(model):
_apply_quirk_action(_action, kwargs, self.silent)
try:
response: ChatCompletion = self._api_call_with_backoff(kwargs)
except Exception as e:
err_str = str(e)
if (
getattr(e, "status_code", None) == 400
and "maximum context length" in err_str
):
# Use /tokenize for exact counts and drop messages until they fit
if not self.silent:
print_yellow("Server rejected: context too long. Trimming with exact token counts and retrying.")
self.messages = self._trim_messages_to_fit_exact(
self.messages, model, options.get("max_tokens", self.max_length_answer)
)
kwargs["messages"] = self.messages
response = self._api_call_with_backoff(kwargs)
elif (
getattr(e, "status_code", None) == 400
and "Penalty is not enabled" in err_str
):
if not self.silent:
print_yellow("Provider rejected penalty params. Retrying without presence_penalty.")
_apply_quirk_action("drop:presence_penalty", kwargs, self.silent)
_quirks.record(model, "drop:presence_penalty")
response = self._api_call_with_backoff(kwargs)
else:
import traceback
traceback.print_exc()
print_red(f"Error calling remote API: {e}")
@ -523,7 +764,6 @@ class LLM:
                    print()
                    print('TOOLS')
                    print_rainbow(tools, single_line=True)
                    # Re-raise the exception to inform the caller of the API failure
                    raise
        # Try to extract backend information if available
@ -538,10 +778,20 @@ class LLM:
        return response

    async def _call_remote_api_async(
        self, model, tools, stream, options, format, headers, think=False
        self, model, tools, stream, options, format, headers, think=None
    ):
        """Call the remote API asynchronously using OpenAI async client."""
        self.messages = self._sanitize_messages(self.messages)

        # Resolve thinking flag: per-call overrides instance default.
        effective_think = think if think is not None else self.think
        if not effective_think:
            body = dict(options.get("extra_body") or {})
            ctk = dict(body.get("chat_template_kwargs") or {})
            ctk["enable_thinking"] = False
            body["chat_template_kwargs"] = ctk
            options = {**options, "extra_body": body}

        # Update the async client with the latest headers
        self.async_client = AsyncOpenAI(
            base_url=self.host_url,  # Remove the extra /v1
@ -594,19 +844,59 @@ class LLM:
if m.get("role") not in {"user", "assistant"}:
m["role"] = "user"
m["content"] = f"Tool output:\n{m.get('content','')}"
response = await self.async_client.responses.parse(
input=messages,
# vLLM does not support the /responses endpoint, so we use
# chat.completions.create with response_format (JSON schema) instead.
json_schema_response = await self.async_client.chat.completions.create(
model=model,
top_p=options["top_p"],
messages=messages,
temperature=options["temperature"],
extra_body=options["extra_body"],
stream=stream,
text_format=format,
top_p=options["top_p"],
response_format={
"type": "json_schema",
"json_schema": {
"name": format.__name__,
"schema": format.model_json_schema(),
},
},
extra_body=options.get("extra_body"),
max_tokens=options.get("max_tokens"),
)
content_text = json_schema_response.choices[0].message.content
parsed_instance = format.model_validate_json(content_text)
message = ChatCompletionMessage.model_construct(
role="assistant",
content=parsed_instance,
)
setattr(message, "content_text", content_text)
setattr(message, "parsed", parsed_instance)
setattr(message, "parsed_dict", parsed_instance.model_dump())
choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
return ChatCompletion.model_construct(
id=json_schema_response.id,
choices=[choice],
created=json_schema_response.created,
model=json_schema_response.model,
object="chat.completion",
)
return self._parsed_content_normalizer(response, model_cls=format)
if tools:
kwargs["tools"] = tools
# Apply any previously-discovered quirks for this model upfront.
for _action in _quirks.get_actions(model):
_apply_quirk_action(_action, kwargs, self.silent)
try:
response = await self.async_client.chat.completions.create(**kwargs)
except Exception as e:
if (
getattr(e, "status_code", None) == 400
and "Penalty is not enabled" in str(e)
):
if not self.silent:
print_yellow("Provider rejected penalty params. Retrying without presence_penalty.")
_apply_quirk_action("drop:presence_penalty", kwargs, self.silent)
_quirks.record(model, "drop:presence_penalty")
response = await self.async_client.chat.completions.create(**kwargs)
else:
raise
return response
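
    # Example (illustrative; MathAnswer is a hypothetical Pydantic model, and this assumes
    # generate() forwards format= down to the structured-output path shown above):
    #
    #     class MathAnswer(BaseModel):
    #         final_answer: float
    #         explanation: str
    #
    #     msg = LLM().generate(query="What is 2 + 2?", format=MathAnswer)
    #     msg.parsed          # validated MathAnswer instance
    #     msg.content_text    # raw JSON string produced by the model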
    def generate(
@ -626,6 +916,7 @@ class LLM:
        top_p: Optional[float] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        max_tokens: Optional[int] = None,
        auto_execute_tools: bool = True,
    ) -> ChatCompletionMessage:
        """
        Generate a response through the remote API.
@ -668,7 +959,7 @@ class LLM:
        choice = response.choices[0]
        message: ChatCompletionMessage = choice.message
        if hasattr(message, "tool_calls") and message.tool_calls:
        if auto_execute_tools and hasattr(message, "tool_calls") and message.tool_calls:
            # Handle multiple tool calls sequentially
            for tool_call in message.tool_calls:
                try:
@ -738,7 +1029,7 @@ class LLM:
                        }
                    )
        # fallback: older SDKs / shapes:
        if hasattr(message, "function_call") and message.function_call:
        if auto_execute_tools and hasattr(message, "function_call") and message.function_call:
            fc = message.function_call
            func_name = getattr(fc, "name", None) or (
                fc.get("name") if isinstance(fc, dict) else None
@ -775,11 +1066,14 @@ class LLM:
if hasattr(message, "content_text"):
result: str = message.content_text
# Qwen3 and some other models include <think>...</think> blocks directly
# in the content field instead of (or in addition to) using reasoning_content.
# Strip those blocks so they never leak to the user.
if isinstance(result, str):
result = _strip_think_tags(result)
# Only overwrite message.content if it is still a plain string.
# When format= is used, message.content already holds the parsed
# Pydantic instance and must not be replaced with its JSON text form.
if not isinstance(message.content, BaseModel):
message.content = result
# Spara i meddelandehistorik (utan verktygsanrop för ren historik)
@ -1184,13 +1478,32 @@ if __name__ == "__main__":
        final_answer: float
        explanation: str

    llm = LLM()
    # on_vpn=False → uses LLM_API_URL (https://llm.edfast.se) with Bearer token auth
    llm = LLM(on_vpn=False)
    llm.host_url = "http://192.168.1.12:8897/v1"  # Override for local testing
    print('LLM URL:', llm.host_url)
    # base_url must be just the /v1 root — the OpenAI SDK appends the
    # correct path itself (/v1/responses or /v1/chat/completions).
    # Setting it to the full endpoint path causes the SDK to produce
    # invalid URLs like /v1/chat/completions/responses.
    llm.host_url = "http://192.168.1.12:8000/v1"
    response = llm.generate(
        query="Tell me about sweden?",
        model="big-smart",
        think=True,
        stream=True,
    )
    for chunk_type, chunk_content in response:
        if chunk_type == "content":
            print_blue(f"Content chunk: {chunk_content}")
        elif chunk_type == "thinking":
            print_yellow(f"Thinking chunk: {chunk_content}")
        elif chunk_type == "thinking_complete":
            print_green(f"Complete thinking: {chunk_content}")

    response = llm.generate(
        query="What is the capital of Norway?",
        model="smart",
        think=False,
    )
    print_blue(response)
    exit()

    response = llm.generate(
        query="""Create a simple math problem solution in JSON format with this structure:
        {

@ -0,0 +1,8 @@
{
  "models/gemini-flash-latest": [
    "drop:presence_penalty"
  ],
  "gemini-3-flash-preview": [
    "drop:presence_penalty"
  ]
}
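
The seeded entries above take effect on the very first request. A minimal sketch of the lookup path, assuming the module imports as _llm.llm and is loaded after this file exists:

    from _llm.llm import _quirks
    _quirks.get_actions("gemini-3-flash-preview")   # -> ["drop:presence_penalty"]
    # _call_remote_api applies this before sending, so presence_penalty is never included.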