diff --git a/CHANGELOG.md b/CHANGELOG.md
index af97ef3..b1c90ac 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,7 +8,7 @@ See [keep a changelog] for information about writing changes to this log.
 
 ## [Unreleased]
 
-* —
+* Updated the litellm message-trimming guardrail to support tool calls, plus minor bug fixes.
 
 ## [0.4.0-rc.0] - 2026-05-06
 
diff --git a/applications/litellm/templates/message-trimming-config.yaml b/applications/litellm/templates/message-trimming-config.yaml
index f1ef517..6e63220 100644
--- a/applications/litellm/templates/message-trimming-config.yaml
+++ b/applications/litellm/templates/message-trimming-config.yaml
@@ -8,15 +8,11 @@ data:
   message_overflow.py: |-
     from typing import Literal, Optional, Union
     from litellm.utils import trim_messages, get_max_tokens, token_counter
-    from litellm.litellm_core_utils.prompt_templates.common_utils import (
-        get_completion_messages,
-    )
     from litellm.integrations.custom_guardrail import CustomGuardrail
     from litellm.proxy._types import UserAPIKeyAuth
     from litellm.caching.caching import DualCache
     import logging
     import yaml
-    from pathlib import Path
 
     logger = logging.getLogger(__name__)
 
@@ -41,6 +37,27 @@ data:
         self.safety_buffer = default_config.get("safety_buffer", 500)
         self.debug = default_config.get("debug", False)
 
+        # Context-window resolution: per-model map wins, then litellm's
+        # built-in get_max_tokens, then this global default.
+        self.default_max_context_tokens = default_config.get(
+            "default_max_context_tokens", 8192
+        )
+        self.max_context_tokens_by_model = default_config.get(
+            "max_context_tokens_by_model", {}
+        ) or {}
+
+        # Trailing role:tool handling. Default is to PRESERVE trailing tool
+        # messages — the agent-loop shape `User -> Asst{tool_calls} -> Tool`
+        # is the normal way to ask the model to reason from tool results.
+        # Only enable for upstream chat templates that explicitly reject
+        # tool-terminal conversations (e.g. strict HF Mistral v0.3 template).
+        self.pop_trailing_tool_messages = bool(
+            default_config.get("pop_trailing_tool_messages", False)
+        )
+        self.pop_trailing_tool_messages_by_model = default_config.get(
+            "pop_trailing_tool_messages_by_model", {}
+        ) or {}
+
     def _load_config(self):
         with open("config.yaml", "r") as file:
             config = yaml.safe_load(file)
@@ -64,6 +81,37 @@ data:
         if self.debug:
             print(f"[GUARDRAIL] {message}")
 
+    def _resolve_max_context_tokens(self, model: Optional[str]) -> int:
+        """Resolve the model's context-window size.
+
+        Order: per-model override map -> litellm.get_max_tokens -> global default.
+        litellm's `model_prices_and_context_window.json` doesn't cover every
+        proxied model name (vLLM, Bedrock variants, custom deployments...),
+        so falling through to a single hardcoded number is a footgun on a
+        fleet with mixed 8k/32k/128k models.
+        """
+        if model and model in self.max_context_tokens_by_model:
+            return int(self.max_context_tokens_by_model[model])
+        try:
+            resolved = get_max_tokens(model)
+            if resolved:
+                return int(resolved)
+        except Exception as e:
+            logger.warning(
+                "get_max_tokens(%r) failed (%s); falling back to "
+                "default_max_context_tokens=%d",
+                model,
+                e,
+                self.default_max_context_tokens,
+            )
+        return int(self.default_max_context_tokens)
+
+    def _resolve_pop_trailing_tools(self, model: Optional[str]) -> bool:
+        """Per-model override > global default. See `pop_trailing_tool_messages`."""
+        if model and model in self.pop_trailing_tool_messages_by_model:
+            return bool(self.pop_trailing_tool_messages_by_model[model])
+        return self.pop_trailing_tool_messages
+
     def _calculate_safe_completion_tokens(
         self,
         max_context_tokens: int,
@@ -120,6 +168,110 @@ data:
             )
             data["max_tokens"] = safe_completion_tokens
 
+    def _repair_tool_call_pairings(self, messages: list) -> list:
+        """Strip orphan tool messages and orphan tool_calls from assistant messages.
+
+        After `trim_messages` truncates history, tool-call/tool-response pairings
+        can end up broken. Mistral/vLLM rejects such conversations. We keep only
+        tool_calls that have a matching later `role: tool` response, and keep only
+        `role: tool` messages whose `tool_call_id` was advertised by a surviving
+        assistant `tool_calls` entry.
+        """
+        satisfied_ids = {
+            m.get("tool_call_id")
+            for m in messages
+            if m.get("role") == "tool" and m.get("tool_call_id")
+        }
+
+        result = []
+        advertised_ids = set()
+        for msg in messages:
+            role = msg.get("role")
+            if role == "assistant":
+                tcs = msg.get("tool_calls") or []
+                if tcs:
+                    kept_tcs = [tc for tc in tcs if tc.get("id") in satisfied_ids]
+                    content = msg.get("content")
+                    if not kept_tcs and not (content or "").strip():
+                        self._log_debug(
+                            "Dropping assistant message with no content and all tool_calls orphaned"
+                        )
+                        continue
+                    new_msg = dict(msg)
+                    if kept_tcs:
+                        new_msg["tool_calls"] = kept_tcs
+                        advertised_ids.update(tc["id"] for tc in kept_tcs)
+                    else:
+                        new_msg.pop("tool_calls", None)
+                        self._log_debug(
+                            "Stripped all orphan tool_calls from assistant message"
+                        )
+                    result.append(new_msg)
+                else:
+                    result.append(msg)
+            elif role == "tool":
+                if msg.get("tool_call_id") in advertised_ids:
+                    result.append(msg)
+                else:
+                    self._log_debug(
+                        f"Dropping orphan tool message (tool_call_id={msg.get('tool_call_id')})"
+                    )
+            else:
+                result.append(msg)
+        return result
+
+    def _ensure_last_is_user(
+        self, messages: list, pop_trailing_tools: bool = False
+    ) -> list:
+        """Guarantee the conversation does not end on an assistant message.
+
+        Always appends a user continuation message if the terminus is an
+        assistant message (some chat templates reject assistant-terminal
+        conversations under `add_generation_prompt=True`).
+
+        When `pop_trailing_tools=True`, also pops any trailing `role: tool`
+        messages first — needed for strict templates (e.g. HF Mistral v0.3)
+        that reject tool-role messages outright. Note: popping can break
+        tool-call/tool-response pairings; callers should re-run
+        `_repair_tool_call_pairings` after to clean up.
+        """
+        if not messages:
+            return messages
+        result = list(messages)
+        if pop_trailing_tools:
+            while result and result[-1].get("role") == "tool":
+                self._log_debug("Popping trailing tool message from terminus")
+                result.pop()
+        if result and result[-1].get("role") == "assistant":
+            self._log_debug("Appending user continue message after assistant terminus")
+            result.append({"role": "user", "content": "Please continue"})
+        return result
+
+    def _sanitize_messages(self, messages: list, model: Optional[str] = None) -> list:
+        """Repair tool-call pairings and fix the terminus. Safe to call on
+        every request.
+
+        Order matters: the second repair has to run *before* we decide
+        whether to append a user-continue, because popping a trailing tool
+        may expose an empty assistant whose only `tool_calls` were just
+        orphaned — repair will drop that assistant entirely, and we don't
+        want to append `"Please continue"` on top of a now-defunct terminus.
+ """ + if not messages: + return messages + pop = self._resolve_pop_trailing_tools(model) + result = self._repair_tool_call_pairings(messages) + if pop: + while result and result[-1].get("role") == "tool": + self._log_debug("Popping trailing tool message from terminus") + result.pop() + # Pop may have orphaned tool_calls on the now-terminal assistant. + result = self._repair_tool_call_pairings(result) + if result and result[-1].get("role") == "assistant": + self._log_debug("Appending user continue message after assistant terminus") + result.append({"role": "user", "content": "Please continue"}) + return result + async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, @@ -140,11 +292,8 @@ data: if "messages" in data and data["messages"]: model = data.get("model") - # Get model's context window size - try: - max_context_tokens = get_max_tokens(model) - except: - max_context_tokens = 8192 # Default fallback + # Get model's context window size (per-model map -> litellm -> default). + max_context_tokens = self._resolve_max_context_tokens(model) self._log_debug(f"Model: {model}") self._log_debug(f"Max context tokens: {max_context_tokens}") @@ -179,7 +328,7 @@ data: self._log_debug(f"Safe completion tokens: {safe_completion_tokens}") self._log_debug( - f"Calculation: min({requested_completion}, max(512, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.90))" + f"Calculation: min({requested_completion}, max(256, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.75))" ) # Update completion tokens in the request @@ -204,55 +353,37 @@ data: self._log_debug( f"Input tokens ({current_tokens}) exceed limit ({max_input_tokens}), trimming messages..." ) - # Trim messages to fit data["messages"] = trim_messages( data["messages"], model=model, - max_tokens=int( - max_input_tokens * 0.90 - ), # Trim to 90% of already conservative max + max_tokens=int(max_input_tokens * 0.90), trim_ratio=self.trim_ratio, ) - - # Ensure "ensure_alternating_roles" is fixed for message after trim - data["messages"] = get_completion_messages( - messages=data["messages"], - assistant_continue_message={"role": "assistant", "content": ""}, - user_continue_message={ - "role": "user", - "content": "Please continue", - }, - ensure_alternating_roles=True, - ) - # Recount after trimming - try: - new_token_count = token_counter( - model=model, messages=data["messages"] - ) - self._log_debug(f"After trimming, input tokens: {new_token_count}") - - # RECALCULATE safe completion tokens based on actual trimmed input - safe_completion_tokens = self._calculate_safe_completion_tokens( - max_context_tokens, new_token_count, requested_completion - ) - self._log_debug( - f"Recalculated safe completion tokens after trim: {safe_completion_tokens}" - ) - - # Update the data with the recalculated values - self._update_completion_tokens( - data, safe_completion_tokens, has_max_tokens, has_max_completion - ) - - # Update for final logging - current_tokens = new_token_count - except Exception as e: - self._log_debug(f"Failed to recount tokens after trim: {e}") else: self._log_debug( f"No trimming needed, but current={current_tokens}, max={max_input_tokens}" ) + # Always sanitize: repair tool-call pairings and (optionally, per + # `pop_trailing_tool_messages` config) coerce the terminus into a + # user message so strict chat templates accept the request. 
+ data["messages"] = self._sanitize_messages(data["messages"], model=model) + + # Recount once after sanitize — messages may have grown (user continue + # appended) or shrunk (orphan tool / empty assistant dropped). + try: + final_tokens = token_counter(model=model, messages=data["messages"]) + self._log_debug(f"After sanitize, input tokens: {final_tokens}") + safe_completion_tokens = self._calculate_safe_completion_tokens( + max_context_tokens, final_tokens, requested_completion + ) + self._update_completion_tokens( + data, safe_completion_tokens, has_max_tokens, has_max_completion + ) + current_tokens = final_tokens + except Exception as e: + self._log_debug(f"Failed to recount tokens after sanitize: {e}") + self._log_debug( f"Expected total: input ~{int(current_tokens)} + completion {safe_completion_tokens} + buffer {self.safety_buffer} = ~{int(current_tokens) + safe_completion_tokens + self.safety_buffer}/{max_context_tokens}" )