diff --git a/_llm/llm.py b/_llm/llm.py
index 6f78725..6b8e6d8 100644
--- a/_llm/llm.py
+++ b/_llm/llm.py
@@ -49,8 +49,62 @@ except ImportError:
     env_manager.set_env()
 
 tokenizer = tiktoken.get_encoding("cl100k_base")
+
+_QUIRKS_FILE = os.path.join(os.path.dirname(__file__), "provider_quirks.json")
+
+# Maps 400-error substrings to the retry action to take.
+# "drop:<param>" removes that kwarg from the request and retries.
+# thought_signature errors are fixed at the message-history level in chat.py,
+# not via a parameter drop, so they are not listed here.
+_ERROR_RETRY_RULES: list[tuple[str, str]] = [
+    ("Penalty is not enabled", "drop:presence_penalty"),
+]
+
+
+class _ProviderQuirks:
+    """Persists per-model param drops discovered at runtime so we skip them from the first call."""
+
+    def __init__(self):
+        self._data: dict[str, list[str]] = {}  # model -> list of actions to apply
+        self._load()
+
+    def _load(self):
+        try:
+            with open(_QUIRKS_FILE) as f:
+                self._data = json.load(f)
+        except (FileNotFoundError, json.JSONDecodeError):
+            self._data = {}
+
+    def _save(self):
+        try:
+            with open(_QUIRKS_FILE, "w") as f:
+                json.dump(self._data, f, indent=2)
+        except Exception:
+            pass
+
+    def get_actions(self, model: str) -> list[str]:
+        return list(self._data.get(model, []))
+
+    def record(self, model: str, action: str):
+        if action not in self._data.get(model, []):
+            self._data.setdefault(model, []).append(action)
+            self._save()
+
+
+_quirks = _ProviderQuirks()
+
+
+def _apply_quirk_action(action: str, kwargs: dict, silent: bool = False):
+    """Mutate kwargs in-place according to a quirk action string."""
+    if action.startswith("drop:"):
+        param = action[5:]
+        if param in kwargs:
+            kwargs.pop(param)
+            if not silent:
+                print_yellow(f"Skipping unsupported param '{param}' for this model.")
+
 
 def _strip_think_tags(text: str) -> str:
     """
@@ -112,7 +166,7 @@ class LLM:
         system_message: str = "You are an assistant.",
         temperature: float = 0.01,
         model: Optional[str] = None,
-        max_length_answer: int = 8000,
+        max_length_answer: int = 3000,
         messages: Optional[list[dict]] = None,
         chat: bool = True,
         tools: Optional[list] = None,
@@ -123,6 +177,8 @@ class LLM:
         presence_penalty: float = 0.3,
         top_p: float = 0.9,
         extra_body: Optional[Dict[str, Any]] = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
     ) -> None:
         """
         Initialize the assistant wrapper.
@@ -142,6 +198,7 @@ class LLM:
             presence_penalty: Default presence penalty forwarded with each request.
             top_p: Default nucleus sampling value used for sampling (see vLLM sampling params).
             extra_body: Additional sampler payload sent via `extra_body` (e.g. repetition penalties).
+            base_url: Override the host URL directly (takes precedence over on_vpn/env vars).
+            api_key: Caller-supplied Bearer token; takes precedence over the LLM_BEARER env var.
         """
         self.model = self.get_model(model)
         self.call_model = self.model
@@ -160,30 +217,41 @@ class LLM:
         self.think = think
         self.tools = tools or []
         self.silent = silent
-        headers = {
-            "Authorization": f"Basic {self.get_credentials()}",
-        }
         self.on_vpn = on_vpn
+        # Remote auth is via Bearer token (LLM_BEARER); the local VPN server needs none.
+        # Don't set Authorization in default_headers — the OpenAI client adds it
+        # automatically from api_key, and a second Authorization header in
+        # default_headers would cause a conflict that makes vllm-sr return 404.
-        if self.on_vpn:
-            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/vllm/v1"
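+        # Example (illustrative values only):
+        #   LLM(base_url="http://localhost:8000/v1", api_key="sk-...")
+        # sends every request to a custom OpenAI-compatible endpoint with that key.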
+        self._api_key = api_key  # user-supplied key; takes precedence over LLM_BEARER env var
+        headers = {}
+        if base_url:
+            self.host_url = base_url
+        elif self.on_vpn:
+            # On VPN: connect directly to the local vLLM server (no auth needed)
+            self.host_url = f"{os.getenv('LLM_URL')}:{os.getenv('LLM_PORT')}/v1"
         else:
-            self.host_url = os.getenv("LLM_API_URL").rstrip("/api/chat/") + "/vllm/v1"
+            # Remote: use the public HTTPS endpoint, auth is via Bearer token
+            self.host_url = os.getenv("LLM_API_URL").rstrip("/") + "/v1"
         self._make_clients(headers, timeout)
 
     def _make_clients(self, headers=None, timeout=300):
         session_id = uuid.uuid4()
-        headers = headers or {"Authorization": f"Basic {self.get_credentials()}", "X-Session-ID": str(session_id)}
+        if headers is None:
+            headers = {}
+        # Merge instead of replacing, so the session id also survives caller-supplied headers.
+        headers.setdefault("X-Session-ID", str(session_id))
+        # Use caller-supplied key if present, otherwise fall back to env.
+        # This supports per-request user keys for external providers (Berget, OpenAI).
+        api_key = self._api_key or os.getenv("LLM_BEARER", "NONE")
         # Sync client
         self.client: OpenAI = OpenAI(
             base_url=self.host_url,
-            api_key="NONE",
+            api_key=api_key,
             default_headers=headers,
             timeout=timeout,
         )
-        # Async client - fix the double /v1 issue
+        # Async client
         self.async_client = AsyncOpenAI(
             base_url=self.host_url,
-            api_key="NONE",
+            api_key=api_key,
             default_headers=headers,
             timeout=timeout,
         )
@@ -201,7 +269,7 @@ class LLM:
         sensible defaults while still accepting fully qualified model names.
         """
         default_model = os.getenv(
-            "LLM_MODEL_VLLM", "qwen3-14b"
+            "LLM_MODEL_VLLM", "smart"
         )
         if model_alias in {None, "", "vllm"}:
             return default_model
@@ -220,6 +288,68 @@ class LLM:
             num_tokens += len(tokens)
         return int(num_tokens)
 
+    def _count_messages_tokens(self, messages: List[Dict[str, Any]]) -> int:
+        num_tokens = 0
+        for msg in messages:
+            for k, v in msg.items():
+                if k == "content":
+                    if not isinstance(v, str):
+                        v = str(v)
+                    num_tokens += len(tokenizer.encode(v))
+        return num_tokens
+
+    def _tokenize_message(self, model: str, text: str) -> int:
+        """Ask the vLLM /tokenize endpoint for the exact token count of a string."""
+        import urllib.request
+
+        # Derive base URL by stripping the trailing /v1
+        base = self.host_url.rstrip("/")
+        if base.endswith("/v1"):
+            base = base[:-3]
+        url = f"{base}/tokenize"
+        payload = json.dumps({"model": model, "prompt": text}).encode()
+        req = urllib.request.Request(url, data=payload, headers={"Content-Type": "application/json"})
+        try:
+            with urllib.request.urlopen(req, timeout=10) as resp:
+                data = json.loads(resp.read())
+            return len(data.get("tokens", []))
+        except Exception:
+            # Fall back to tiktoken estimate if the endpoint is unavailable
+            return len(tokenizer.encode(text))
+
+    def _trim_messages_to_fit_exact(self, messages: List[Dict[str, Any]], model: str, max_tokens: int) -> List[Dict[str, Any]]:
+        """Use the vLLM /tokenize endpoint to drop the oldest non-system messages until they fit."""
+        context_window = MODEL_CONTEXT_WINDOWS.get(model, DEFAULT_CONTEXT_WINDOW)
+        # Reserve ~5 tokens per message for chat-template overhead (<|im_start|>role\n...<|im_end|>\n)
+        # that the /tokenize endpoint doesn't count when given raw content strings.
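+        # Rough arithmetic (template-dependent): "<|im_start|>user\n" is ~3 tokens
+        # and "<|im_end|>\n" ~2, so 5 per message; 40 messages reserve ~200 tokens.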
+        template_overhead = 5 * len(messages)
+        budget = context_window - max_tokens - template_overhead
+        if budget <= 0:
+            return messages
+
+        def _total_tokens(msgs):
+            total = 0
+            for m in msgs:
+                content = m.get("content", "") or ""
+                if not isinstance(content, str):
+                    content = str(content)
+                total += self._tokenize_message(model, content)
+            return total
+
+        while True:
+            current = _total_tokens(messages)
+            if current <= budget:
+                break
+            drop_idx = next(
+                (i for i, m in enumerate(messages) if m.get("role") != "system"),
+                None,
+            )
+            if drop_idx is None:
+                break
+            if not self.silent:
+                print_yellow(f"Context too long ({current} > {budget} tokens), dropping message at index {drop_idx}.")
+            messages = messages[:drop_idx] + messages[drop_idx + 1:]
+        return messages
+
     def _prepare_messages_and_model(
         self,
         query,
@@ -258,7 +388,14 @@ class LLM:
 
     def _build_headers(self, model):
         """Build HTTP headers for API requests, including auth and optional backend hints."""
-        headers = {"Authorization": f"Basic {self.get_credentials()}"}
+        if self.on_vpn:
+            # On VPN the local server doesn't require auth — use Basic as a soft hint only
+            headers = {"Authorization": f"Basic {self.get_credentials()}"}
+        else:
+            # Remote public endpoint requires Bearer token auth.
+            # The OpenAI SDK sets 'Authorization: Bearer <token>' automatically,
+            # so we don't put Authorization here to avoid conflicting headers.
+            headers = {}
         if model == self.get_model("embeddings"):
             headers["X-Model-Type"] = "embeddings"
         return headers
@@ -451,15 +588,64 @@ class LLM:
         )
         return chat
 
+    def _api_call_with_backoff(self, kwargs: dict) -> "ChatCompletion":
+        """Call self.client.chat.completions.create with exponential backoff for transient errors."""
+        _retryable = {429, 502, 503, 529}
+        _max_retries = 6
+        _base_delay = 2.0
+        for _attempt in range(_max_retries + 1):
+            try:
+                return self.client.chat.completions.create(**kwargs)
+            except Exception as _e:
+                _status = getattr(_e, "status_code", None)
+                if _status not in _retryable or _attempt >= _max_retries:
+                    raise
+                _headers = getattr(getattr(_e, "response", None), "headers", {}) or {}
+                _ra = _headers.get("retry-after") or _headers.get("Retry-After")
+                _delay = float(_ra) if _ra else _base_delay * (2 ** _attempt)
+                if not self.silent:
+                    print_yellow(
+                        f"Rate-limited (HTTP {_status}). Retrying in {_delay:.1f}s "
+                        f"(attempt {_attempt + 1}/{_max_retries})…"
+                    )
+                time.sleep(_delay)
+        raise RuntimeError("Exhausted retries due to rate limiting.")
+
     def _call_remote_api(
-        self, model, tools, stream, options, format, headers, think=False
+        self, model, tools, stream, options, format, headers, think=None
     ) -> ChatCompletion:
         """Call the remote vLLM-backed API using the OpenAI-compatible client."""
         self.call_model = model
         self.messages = self._sanitize_messages(self.messages)
+
+        # Resolve thinking flag: per-call overrides instance default.
+        # When thinking is off, inject chat_template_kwargs so vLLM skips CoT tokens entirely.
+        # chat_template_kwargs is vLLM-specific; skip it when a user-supplied key is present
+        # (which indicates an external provider like Berget or OpenAI).
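+        # When thinking is disabled, the request then carries, as a sketch:
+        #   extra_body={"chat_template_kwargs": {"enable_thinking": False}}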
+        effective_think = think if think is not None else self.think
+        if not effective_think and not self._api_key:
+            body = dict(options.get("extra_body") or {})
+            ctk = dict(body.get("chat_template_kwargs") or {})
+            ctk["enable_thinking"] = False
+            body["chat_template_kwargs"] = ctk
+            options = {**options, "extra_body": body}
+
+        # Strip vLLM-specific extra_body keys (repetition_penalty, chat_template_kwargs)
+        # when using an external provider — they are not part of the OpenAI spec and will
+        # cause 4xx errors on providers that reject unknown fields.
+        if self._api_key:
+            body = dict(options.get("extra_body") or {})
+            body.pop("repetition_penalty", None)
+            body.pop("chat_template_kwargs", None)
+            options = {**options, "extra_body": body or None}
+
+        # For remote calls the OpenAI SDK handles auth by sending
+        # 'Authorization: Bearer <token>'. On VPN no token is needed.
+        # self._api_key (user-supplied) takes precedence over the LLM_BEARER env var.
+        api_key = self._api_key or (os.getenv("LLM_BEARER", "NONE") if not self.on_vpn else "NONE")
         self.client = OpenAI(
             base_url=f"{self.host_url}",
-            api_key="ollama",
+            api_key=api_key,
             default_headers=headers,
             timeout=300,
         )
@@ -479,17 +665,40 @@ class LLM:
                     m["role"] = "user"
                     m["content"] = f"Tool output:\n{m.get('content','')}"
 
-            # print('FORMAT', format)
-            response: ParsedResponse = self.client.responses.parse(
-                input=messages,
+            # vLLM does not support the /responses endpoint, so we use
+            # chat.completions.create with response_format (JSON schema) instead.
+            json_schema_response = self.client.chat.completions.create(
                 model=model,
-                top_p=options["top_p"],
+                messages=messages,
                 temperature=options["temperature"],
+                top_p=options["top_p"],
+                response_format={
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": format.__name__,
+                        "schema": format.model_json_schema(),
+                    },
+                },
                 extra_body=options["extra_body"],
-                stream=stream,
-                text_format=format,
+                max_tokens=options["max_tokens"],
+            )
+            content_text = json_schema_response.choices[0].message.content
+            parsed_instance = format.model_validate_json(content_text)
+            message = ChatCompletionMessage.model_construct(
+                role="assistant",
+                content=parsed_instance,
+            )
+            setattr(message, "content_text", content_text)
+            setattr(message, "parsed", parsed_instance)
+            setattr(message, "parsed_dict", parsed_instance.model_dump())
+            choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
+            response: ChatCompletion = ChatCompletion.model_construct(
+                id=json_schema_response.id,
+                choices=[choice],
+                created=json_schema_response.created,
+                model=json_schema_response.model,
+                object="chat.completion",
             )
-            response: ChatCompletion = self._parsed_content_normalizer(response, model_cls=format)
 
         # Call the OpenAI API
         else:
@@ -499,32 +708,63 @@ class LLM:
                     message["content"] = (
                         f"Tool output:\n{message.get('content','')}"  # TODO Works? Other tools?
                     )
+            # OpenAI o-series and newer models use max_completion_tokens instead of max_tokens.
+            _token_kwarg = (
+                {"max_completion_tokens": options["max_tokens"]}
+                if "openai.com" in self.host_url
+                else {"max_tokens": options["max_tokens"]}
+            )
+            # frequency_penalty removed: repetition_penalty in extra_body already handles
+            # this. Having both causes compounding effects that push the model into token loops.
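+            # Quirk actions recorded in provider_quirks.json have the form
+            # "drop:<param>" (e.g. "drop:presence_penalty") and are replayed
+            # against this kwargs dict before the request goes out.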
+            kwargs = {
+                "model": model,
+                "messages": self.messages,
+                "stream": stream,
+                "temperature": options["temperature"],
+                "presence_penalty": options["presence_penalty"],
+                "top_p": options["top_p"],
+                "tools": tools,
+                "extra_body": options["extra_body"],
+                **_token_kwarg,
+            }
+            # Apply any previously-discovered quirks for this model upfront.
+            for _action in _quirks.get_actions(model):
+                _apply_quirk_action(_action, kwargs, self.silent)
             try:
-                response: ChatCompletion = self.client.chat.completions.create(
-                    model=model,
-                    messages=self.messages,
-                    # frequency_penalty removed: repetition_penalty in extra_body
-                    # already handles this. Having both causes compounding effects
-                    # that push the model into token loops.
-                    stream=stream,
-                    temperature=options["temperature"],
-                    presence_penalty=options["presence_penalty"],
-                    top_p=options["top_p"],
-                    tools=tools,
-                    extra_body=options["extra_body"],
-                    max_tokens=options["max_tokens"],
-                )
+                response: ChatCompletion = self._api_call_with_backoff(kwargs)
             except Exception as e:
-                import traceback
-                traceback.print_exc()
-                print_red(f"Error calling remote API: {e}")
-                print('---MESSAGES---')
-                print_rainbow(self.messages, single_line=True)
-                print()
-                print('TOOLS')
-                print_rainbow(tools, single_line=True)
-                # Re-raise the exception to inform the caller of the API failure
-                raise
+                err_str = str(e)
+                if (
+                    getattr(e, "status_code", None) == 400
+                    and "maximum context length" in err_str
+                ):
+                    # Use /tokenize for exact counts and drop messages until they fit
+                    if not self.silent:
+                        print_yellow("Server rejected: context too long. Trimming with exact token counts and retrying.")
+                    self.messages = self._trim_messages_to_fit_exact(
+                        self.messages, model, options.get("max_tokens", self.max_length_answer)
+                    )
+                    kwargs["messages"] = self.messages
+                    response = self._api_call_with_backoff(kwargs)
+                elif (
+                    getattr(e, "status_code", None) == 400
+                    and "Penalty is not enabled" in err_str
+                ):
+                    if not self.silent:
+                        print_yellow("Provider rejected penalty params. Retrying without presence_penalty.")
+                    _apply_quirk_action("drop:presence_penalty", kwargs, self.silent)
+                    _quirks.record(model, "drop:presence_penalty")
+                    response = self._api_call_with_backoff(kwargs)
+                else:
+                    import traceback
+                    traceback.print_exc()
+                    print_red(f"Error calling remote API: {e}")
+                    print('---MESSAGES---')
+                    print_rainbow(self.messages, single_line=True)
+                    print()
+                    print('TOOLS')
+                    print_rainbow(tools, single_line=True)
+                    raise
 
         # Try to extract backend information if available
         try:
@@ -538,10 +778,20 @@ class LLM:
         return response
 
     async def _call_remote_api_async(
-        self, model, tools, stream, options, format, headers, think=False
+        self, model, tools, stream, options, format, headers, think=None
     ):
         """Call the remote API asynchronously using OpenAI async client."""
         self.messages = self._sanitize_messages(self.messages)
+
+        # Resolve thinking flag: per-call overrides instance default.
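+        # As in the sync path, chat_template_kwargs is vLLM-specific; a caller-supplied
+        # api_key implies an external provider, so it is skipped there as well.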
+        effective_think = think if think is not None else self.think
+        if not effective_think and not self._api_key:
+            body = dict(options.get("extra_body") or {})
+            ctk = dict(body.get("chat_template_kwargs") or {})
+            ctk["enable_thinking"] = False
+            body["chat_template_kwargs"] = ctk
+            options = {**options, "extra_body": body}
+
         # Update the async client with the latest headers
         self.async_client = AsyncOpenAI(
             base_url=self.host_url,
@@ -594,19 +844,59 @@ class LLM:
             if m.get("role") not in {"user", "assistant"}:
                 m["role"] = "user"
                 m["content"] = f"Tool output:\n{m.get('content','')}"
-            response = await self.async_client.responses.parse(
-                input=messages,
+            # vLLM does not support the /responses endpoint, so we use
+            # chat.completions.create with response_format (JSON schema) instead.
+            json_schema_response = await self.async_client.chat.completions.create(
                 model=model,
-                top_p=options["top_p"],
+                messages=messages,
                 temperature=options["temperature"],
-                extra_body=options["extra_body"],
-                stream=stream,
-                text_format=format,
+                top_p=options["top_p"],
+                response_format={
+                    "type": "json_schema",
+                    "json_schema": {
+                        "name": format.__name__,
+                        "schema": format.model_json_schema(),
+                    },
+                },
+                extra_body=options.get("extra_body"),
+                max_tokens=options.get("max_tokens"),
+            )
+            content_text = json_schema_response.choices[0].message.content
+            parsed_instance = format.model_validate_json(content_text)
+            message = ChatCompletionMessage.model_construct(
+                role="assistant",
+                content=parsed_instance,
+            )
+            setattr(message, "content_text", content_text)
+            setattr(message, "parsed", parsed_instance)
+            setattr(message, "parsed_dict", parsed_instance.model_dump())
+            choice = Choice.model_construct(index=0, finish_reason="stop", message=message)
+            return ChatCompletion.model_construct(
+                id=json_schema_response.id,
+                choices=[choice],
+                created=json_schema_response.created,
+                model=json_schema_response.model,
+                object="chat.completion",
             )
-            return self._parsed_content_normalizer(response, model_cls=format)
 
         if tools:
             kwargs["tools"] = tools
-        response = await self.async_client.chat.completions.create(**kwargs)
+        # Apply any previously-discovered quirks for this model upfront.
+        for _action in _quirks.get_actions(model):
+            _apply_quirk_action(_action, kwargs, self.silent)
+        try:
+            response = await self.async_client.chat.completions.create(**kwargs)
+        except Exception as e:
+            if (
+                getattr(e, "status_code", None) == 400
+                and "Penalty is not enabled" in str(e)
+            ):
+                if not self.silent:
+                    print_yellow("Provider rejected penalty params. Retrying without presence_penalty.")
+                _apply_quirk_action("drop:presence_penalty", kwargs, self.silent)
+                _quirks.record(model, "drop:presence_penalty")
+                response = await self.async_client.chat.completions.create(**kwargs)
+            else:
+                raise
         return response
 
     def generate(
         self,
         query,
         model: Optional[str] = None,
         think: Optional[bool] = None,
         format=None,
         stream: bool = False,
         tools: Optional[list] = None,
         silent: Optional[bool] = None,
         temperature: Optional[float] = None,
         presence_penalty: Optional[float] = None,
         top_p: Optional[float] = None,
         extra_body: Optional[Dict[str, Any]] = None,
         max_tokens: Optional[int] = None,
+        auto_execute_tools: bool = True,
     ) -> ChatCompletionMessage:
         """
         Generate a response through the remote API.
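+
+        auto_execute_tools: when False, tool calls returned by the model are not
+            executed locally; the raw tool_call message is returned to the caller.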
@@ -668,7 +959,7 @@ class LLM:
         choice = response.choices[0]
         message: ChatCompletionMessage = choice.message
 
-        if hasattr(message, "tool_calls") and message.tool_calls:
+        if auto_execute_tools and hasattr(message, "tool_calls") and message.tool_calls:
             # Handle multiple tool calls sequentially
             for tool_call in message.tool_calls:
                 try:
@@ -738,7 +1029,7 @@ class LLM:
                 }
             )
         # fallback: older SDKs / shapes:
-        if hasattr(message, "function_call") and message.function_call:
+        if auto_execute_tools and hasattr(message, "function_call") and message.function_call:
             fc = message.function_call
             func_name = getattr(fc, "name", None) or (
                 fc.get("name") if isinstance(fc, dict) else None
             )
@@ -775,12 +1066,15 @@ class LLM:
         if hasattr(message, "content_text"):
             result: str = message.content_text
-
         # Qwen3 and some other models include <think>...</think> blocks directly
         # in the content field instead of (or in addition to) using reasoning_content.
         # Strip those blocks so they never leak to the user.
         if isinstance(result, str):
             result = _strip_think_tags(result)
-        message.content = result
+        # Only overwrite message.content if it is still a plain string.
+        # When format= is used, message.content already holds the parsed
+        # Pydantic instance and must not be replaced with its JSON text form.
+        if not isinstance(message.content, BaseModel):
+            message.content = result
 
         # Save to message history (without tool calls, to keep the history clean)
         self.messages.append({"role": "assistant", "content": result})
@@ -1184,13 +1478,32 @@ if __name__ == "__main__":
 
         final_answer: float
         explanation: str
 
-    llm = LLM()
+    # on_vpn=False → uses LLM_API_URL (https://llm.edfast.se) with Bearer token auth
+    llm = LLM(on_vpn=False)
+    llm.host_url = "http://192.168.1.12:8897/v1"  # Override for local testing
     print('LLM URL:', llm.host_url)
-    # base_url must be just the /v1 root — the OpenAI SDK appends the
-    # correct path itself (/v1/responses or /v1/chat/completions).
-    # Setting it to the full endpoint path causes the SDK to produce
-    # invalid URLs like /v1/chat/completions/responses.
-    llm.host_url = "http://192.168.1.12:8000/v1"
+    response = llm.generate(
+        query="Tell me about Sweden?",
+        model="big-smart",
+        think=True,
+        stream=True,
+    )
+    for chunk_type, chunk_content in response:
+        if chunk_type == "content":
+            print_blue(f"Content chunk: {chunk_content}")
+        elif chunk_type == "thinking":
+            print_yellow(f"Thinking chunk: {chunk_content}")
+        elif chunk_type == "thinking_complete":
+            print_green(f"Complete thinking: {chunk_content}")
+    response = llm.generate(
+        query="What is the capital of Norway?",
+        model="smart",
+        think=False
+    )
+    print_blue(response)
+    exit()
+
+
     response = llm.generate(
         query="""Create a simple math problem solution in JSON format with this structure:
 {
diff --git a/_llm/provider_quirks.json b/_llm/provider_quirks.json
new file mode 100644
index 0000000..8853cf1
--- /dev/null
+++ b/_llm/provider_quirks.json
@@ -0,0 +1,8 @@
+{
+    "models/gemini-flash-latest": [
+        "drop:presence_penalty"
+    ],
+    "gemini-3-flash-preview": [
+        "drop:presence_penalty"
+    ]
+}
\ No newline at end of file