2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -8,7 +8,7 @@ See [keep a changelog] for information about writing changes to this log.

## [Unreleased]

* —
* Updated the message-trimming litellm guardrail to support tool calls, plus minor bug fixes.

## [0.4.0-rc.0] - 2026-05-06

229 changes: 180 additions & 49 deletions applications/litellm/templates/message-trimming-config.yaml
@@ -8,15 +8,11 @@ data:
message_overflow.py: |-
from typing import Literal, Optional, Union
from litellm.utils import trim_messages, get_max_tokens, token_counter
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_completion_messages,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching.caching import DualCache
import logging
import yaml
from pathlib import Path

logger = logging.getLogger(__name__)

@@ -41,6 +37,27 @@ data:
self.safety_buffer = default_config.get("safety_buffer", 500)
self.debug = default_config.get("debug", False)

# Context-window resolution: per-model map wins, then litellm's
# built-in get_max_tokens, then this global default.
self.default_max_context_tokens = default_config.get(
"default_max_context_tokens", 8192
Can we document inline the reason for this magic number?

)
self.max_context_tokens_by_model = default_config.get(
"max_context_tokens_by_model", {}
) or {}
The or {} seems redundant?


# Trailing role:tool handling. Default is to PRESERVE trailing tool
# messages — the agent-loop shape `User -> Asst{tool_calls} -> Tool`
# is the normal way to ask the model to reason from tool results.
# Only enable for upstream chat templates that explicitly reject
# tool-terminal conversations (e.g. strict HF Mistral v0.3 template).
self.pop_trailing_tool_messages = bool(
default_config.get("pop_trailing_tool_messages", False)
)
self.pop_trailing_tool_messages_by_model = default_config.get(
"pop_trailing_tool_messages_by_model", {}
) or {}
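
For orientation, a minimal sketch of the default_config shape these .get() calls expect, written out as the dict yaml.safe_load would return. Only the key names come from the code above; every value and model name is hypothetical.

# Hypothetical config shape; key names match the .get() calls above,
# values and model names are illustrative only.
default_config = {
    "safety_buffer": 500,
    "debug": False,
    "default_max_context_tokens": 8192,
    "max_context_tokens_by_model": {
        "my-vllm-llama-8b": 8192,
        "my-bedrock-claude": 200000,
    },
    "pop_trailing_tool_messages": False,
    "pop_trailing_tool_messages_by_model": {
        "my-hf-mistral-v0.3": True,  # strict template: drop trailing tool messages
    },
}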

def _load_config(self):
with open("config.yaml", "r") as file:
config = yaml.safe_load(file)
@@ -64,6 +81,37 @@ data:
if self.debug:
print(f"[GUARDRAIL] {message}")

def _resolve_max_context_tokens(self, model: Optional[str]) -> int:
"""Resolve the model's context-window size.

Order: per-model override map -> litellm.get_max_tokens -> global default.
litellm's `model_prices_and_context_window.json` doesn't cover every
proxied model name (vLLM, Bedrock variants, custom deployments...),
so falling through to a single hardcoded number is a footgun on a
fleet with mixed 8k/32k/128k models.
"""
if model and model in self.max_context_tokens_by_model:
return int(self.max_context_tokens_by_model[model])
try:
resolved = get_max_tokens(model)
if resolved:
return int(resolved)
except Exception as e:
logger.warning(
"get_max_tokens(%r) failed (%s); falling back to "
"default_max_context_tokens=%d",
model,
e,
self.default_max_context_tokens,
)
return int(self.default_max_context_tokens)
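
As a usage sketch of the resolution order (map entries and model names below are hypothetical):

# With max_context_tokens_by_model = {"my-vllm-model": 32768} and
# default_max_context_tokens = 8192 (both hypothetical):
#   _resolve_max_context_tokens("my-vllm-model")    -> 32768  (per-model override wins)
#   _resolve_max_context_tokens("gpt-4o")           -> whatever litellm's get_max_tokens reports
#   _resolve_max_context_tokens("my-custom-deploy") -> 8192   (unknown everywhere, so the global default)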

def _resolve_pop_trailing_tools(self, model: Optional[str]) -> bool:
"""Per-model override > global default. See `pop_trailing_tool_messages`."""
if model and model in self.pop_trailing_tool_messages_by_model:
return bool(self.pop_trailing_tool_messages_by_model[model])
return self.pop_trailing_tool_messages

def _calculate_safe_completion_tokens(
self,
max_context_tokens: int,
@@ -120,6 +168,110 @@ data:
)
data["max_tokens"] = safe_completion_tokens

def _repair_tool_call_pairings(self, messages: list) -> list:
My gut tells me we could make this two levels flatter by inverting the if conditions and using continue a bit more. Makes for an easier read.

Also, maybe make use of the "new" match/case instead of the if/else statements.

"""Strip orphan tool messages and orphan tool_calls from assistant messages.

After `trim_messages` truncates history, tool-call/tool-response pairings
can end up broken. Mistral/vLLM rejects such conversations. We keep only
tool_calls that have a matching later `role: tool` response, and keep only
`role: tool` messages whose `tool_call_id` was advertised by a surviving
assistant `tool_calls` entry.
"""
satisfied_ids = {
m.get("tool_call_id")
for m in messages
if m.get("role") == "tool" and m.get("tool_call_id")
}

result = []
advertised_ids = set()
for msg in messages:
role = msg.get("role")
if role == "assistant":
tcs = msg.get("tool_calls") or []
if tcs:
kept_tcs = [tc for tc in tcs if tc.get("id") in satisfied_ids]
content = msg.get("content")
if not kept_tcs and not (content or "").strip():
content_empty = not (content or "").strip()
if not kept_tcs and content_empty:

self._log_debug(
"Dropping assistant message with no content and all tool_calls orphaned"
)
continue
new_msg = dict(msg)
msg.copy() better reflects intent.

if kept_tcs:
new_msg["tool_calls"] = kept_tcs
advertised_ids.update(tc["id"] for tc in kept_tcs)
else:
new_msg.pop("tool_calls", None)
self._log_debug(
"Stripped all orphan tool_calls from assistant message"
)
result.append(new_msg)
else:
result.append(msg)
elif role == "tool":
if msg.get("tool_call_id") in advertised_ids:
result.append(msg)
else:
self._log_debug(
f"Dropping orphan tool message (tool_call_id={msg.get('tool_call_id')})"
)
else:
result.append(msg)
return result
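
For illustration, a rough sketch of the flatter, early-continue shape the two review comments above seem to be asking for. It reuses satisfied_ids from the top of the method, mirrors the same logic, and is untested; not part of this change.

result = []
advertised_ids = set()
for msg in messages:
    role = msg.get("role")
    if role == "tool":
        # Keep only tool responses whose call id survived on a kept assistant message.
        if msg.get("tool_call_id") in advertised_ids:
            result.append(msg)
        else:
            self._log_debug(
                f"Dropping orphan tool message (tool_call_id={msg.get('tool_call_id')})"
            )
        continue
    tcs = msg.get("tool_calls") or []
    if role != "assistant" or not tcs:
        result.append(msg)
        continue
    kept_tcs = [tc for tc in tcs if tc.get("id") in satisfied_ids]
    if not kept_tcs and not (msg.get("content") or "").strip():
        self._log_debug(
            "Dropping assistant message with no content and all tool_calls orphaned"
        )
        continue
    new_msg = msg.copy()  # per the msg.copy() suggestion above
    if kept_tcs:
        new_msg["tool_calls"] = kept_tcs
        advertised_ids.update(tc["id"] for tc in kept_tcs)
    else:
        new_msg.pop("tool_calls", None)
        self._log_debug("Stripped all orphan tool_calls from assistant message")
    result.append(new_msg)
return result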

def _ensure_last_is_user(
I might be missing something, but are we even using this anywhere?

self, messages: list, pop_trailing_tools: bool = False
) -> list:
"""Guarantee the conversation does not end on an assistant message.

Always appends a user continuation message if the terminus is an
assistant message (some chat templates reject assistant-terminal
conversations under `add_generation_prompt=True`).

When `pop_trailing_tools=True`, also pops any trailing `role: tool`
messages first — needed for strict templates (e.g. HF Mistral v0.3)
that reject tool-role messages outright. Note: popping can break
tool-call/tool-response pairings; callers should re-run
`_repair_tool_call_pairings` after to clean up.
"""
if not messages:
return messages
result = list(messages)
if pop_trailing_tools:
while result and result[-1].get("role") == "tool":
self._log_debug("Popping trailing tool message from terminus")
result.pop()
if result and result[-1].get("role") == "assistant":
self._log_debug("Appending user continue message after assistant terminus")
result.append({"role": "user", "content": "Please continue"})
return result

def _sanitize_messages(self, messages: list, model: Optional[str] = None) -> list:
"""Repair tool-call pairings and fix the terminus. Safe to call on
every request.

Order matters: the second repair has to run *before* we decide
whether to append a user-continue, because popping a trailing tool
may expose an empty assistant whose only `tool_calls` were just
orphaned — repair will drop that assistant entirely, and we don't
want to append `"Please continue"` on top of a now-defunct terminus.
"""
if not messages:
return messages
pop = self._resolve_pop_trailing_tools(model)
result = self._repair_tool_call_pairings(messages)
if pop:
while result and result[-1].get("role") == "tool":
self._log_debug("Popping trailing tool message from terminus")
result.pop()
# Pop may have orphaned tool_calls on the now-terminal assistant.
result = self._repair_tool_call_pairings(result)
if result and result[-1].get("role") == "assistant":
self._log_debug("Appending user continue message after assistant terminus")
result.append({"role": "user", "content": "Please continue"})
return result
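
A quick before/after sketch of the default behaviour (pop_trailing_tool_messages left at False); message contents and tool names are made up:

# Hypothetical agent-loop tail that asks the model to reason from a tool result:
msgs = [
    {"role": "user", "content": "What's the weather in Oslo?"},
    {"role": "assistant", "content": "",
     "tool_calls": [{"id": "call_1", "type": "function",
                     "function": {"name": "get_weather", "arguments": "{}"}}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "Sunny, 21C"},
]
# With the default config, _sanitize_messages(msgs) returns this list unchanged:
# the call/response pair is intact and the trailing tool message is preserved.
# If trim_messages had dropped the tool response instead, the orphaned tool_calls
# entry would be stripped, the now-empty assistant message dropped, and only the
# user message would remain.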

async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
@@ -140,11 +292,8 @@
if "messages" in data and data["messages"]:
model = data.get("model")

# Get model's context window size
try:
max_context_tokens = get_max_tokens(model)
except:
max_context_tokens = 8192 # Default fallback
# Get model's context window size (per-model map -> litellm -> default).
max_context_tokens = self._resolve_max_context_tokens(model)

self._log_debug(f"Model: {model}")
self._log_debug(f"Max context tokens: {max_context_tokens}")
@@ -179,7 +328,7 @@

self._log_debug(f"Safe completion tokens: {safe_completion_tokens}")
self._log_debug(
f"Calculation: min({requested_completion}, max(512, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.90))"
f"Calculation: min({requested_completion}, max(256, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.75))"
I do not understand this change. Why?

)
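
For concreteness, assuming the log string above mirrors the actual calculation in _calculate_safe_completion_tokens (its body isn't shown in this hunk), take hypothetical values max_context_tokens=8192, current_tokens=6000, safety_buffer=500 and requested_completion=2048. The old expression gives min(2048, max(512, (8192 - 6000 - 500) * 0.90)) = min(2048, 1522.8) = 1522.8, while the new one gives min(2048, max(256, 1692 * 0.75)) = min(2048, 1269.0) = 1269.0: the floor on the reserved completion budget is halved and the fraction of remaining context handed to the completion drops from 90% to 75%.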

# Update completion tokens in the request
@@ -204,55 +353,37 @@ data:
self._log_debug(
f"Input tokens ({current_tokens}) exceed limit ({max_input_tokens}), trimming messages..."
)
# Trim messages to fit
data["messages"] = trim_messages(
data["messages"],
model=model,
max_tokens=int(
max_input_tokens * 0.90
), # Trim to 90% of already conservative max
max_tokens=int(max_input_tokens * 0.90),
trim_ratio=self.trim_ratio,
)

# Ensure "ensure_alternating_roles" is fixed for message after trim
data["messages"] = get_completion_messages(
messages=data["messages"],
assistant_continue_message={"role": "assistant", "content": ""},
user_continue_message={
"role": "user",
"content": "Please continue",
},
ensure_alternating_roles=True,
)
# Recount after trimming
try:
new_token_count = token_counter(
model=model, messages=data["messages"]
)
self._log_debug(f"After trimming, input tokens: {new_token_count}")

# RECALCULATE safe completion tokens based on actual trimmed input
safe_completion_tokens = self._calculate_safe_completion_tokens(
max_context_tokens, new_token_count, requested_completion
)
self._log_debug(
f"Recalculated safe completion tokens after trim: {safe_completion_tokens}"
)

# Update the data with the recalculated values
self._update_completion_tokens(
data, safe_completion_tokens, has_max_tokens, has_max_completion
)

# Update for final logging
current_tokens = new_token_count
except Exception as e:
self._log_debug(f"Failed to recount tokens after trim: {e}")
else:
self._log_debug(
f"No trimming needed, but current={current_tokens}, max={max_input_tokens}"
)

# Always sanitize: repair tool-call pairings and (optionally, per
# `pop_trailing_tool_messages` config) coerce the terminus into a
# user message so strict chat templates accept the request.
data["messages"] = self._sanitize_messages(data["messages"], model=model)

# Recount once after sanitize — messages may have grown (user continue
# appended) or shrunk (orphan tool / empty assistant dropped).
try:
final_tokens = token_counter(model=model, messages=data["messages"])
self._log_debug(f"After sanitize, input tokens: {final_tokens}")
safe_completion_tokens = self._calculate_safe_completion_tokens(
max_context_tokens, final_tokens, requested_completion
)
self._update_completion_tokens(
data, safe_completion_tokens, has_max_tokens, has_max_completion
)
current_tokens = final_tokens
except Exception as e:
self._log_debug(f"Failed to recount tokens after sanitize: {e}")

self._log_debug(
f"Expected total: input ~{int(current_tokens)} + completion {safe_completion_tokens} + buffer {self.safety_buffer} = ~{int(current_tokens) + safe_completion_tokens + self.safety_buffer}/{max_context_tokens}"
)