Updated message trim to handle message better and support tool calling. #28
base: develop
@@ -8,15 +8,11 @@ data:
  message_overflow.py: |-
    from typing import Literal, Optional, Union
    from litellm.utils import trim_messages, get_max_tokens, token_counter
-    from litellm.litellm_core_utils.prompt_templates.common_utils import (
-        get_completion_messages,
-    )
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.caching.caching import DualCache
    import logging
    import yaml
    from pathlib import Path

    logger = logging.getLogger(__name__)
@@ -41,6 +37,27 @@ data:
            self.safety_buffer = default_config.get("safety_buffer", 500)
            self.debug = default_config.get("debug", False)

+            # Context-window resolution: per-model map wins, then litellm's
+            # built-in get_max_tokens, then this global default.
+            self.default_max_context_tokens = default_config.get(
+                "default_max_context_tokens", 8192
+            )
+            self.max_context_tokens_by_model = default_config.get(
+                "max_context_tokens_by_model", {}
+            ) or {}
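For illustration only (the model names here are made up, not from this PR), the parsed guardrail config these two keys expect could look like:

    # Hypothetical parsed config.yaml fragment -- names are illustrative.
    default_config = {
        "safety_buffer": 500,
        "debug": False,
        "default_max_context_tokens": 8192,   # last-resort fallback
        "max_context_tokens_by_model": {
            "vllm-llama3-8b": 8192,           # wins over get_max_tokens
            "bedrock-claude-custom": 200000,
        },
    }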
+
+            # Trailing role:tool handling. Default is to PRESERVE trailing tool
+            # messages — the agent-loop shape `User -> Asst{tool_calls} -> Tool`
+            # is the normal way to ask the model to reason from tool results.
+            # Only enable for upstream chat templates that explicitly reject
+            # tool-terminal conversations (e.g. strict HF Mistral v0.3 template).
+            self.pop_trailing_tool_messages = bool(
+                default_config.get("pop_trailing_tool_messages", False)
+            )
+            self.pop_trailing_tool_messages_by_model = default_config.get(
+                "pop_trailing_tool_messages_by_model", {}
+            ) or {}
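Likewise, a sketch of the trailing-tool overrides (again with a made-up model name):

    default_config.update({
        "pop_trailing_tool_messages": False,       # global default: preserve
        "pop_trailing_tool_messages_by_model": {
            "hf-mistral-v0.3-custom": True,        # strict template: pop
        },
    })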
+
        def _load_config(self):
            with open("config.yaml", "r") as file:
                config = yaml.safe_load(file)

@@ -64,6 +81,37 @@ data:
            if self.debug:
                print(f"[GUARDRAIL] {message}")
+
+        def _resolve_max_context_tokens(self, model: Optional[str]) -> int:
+            """Resolve the model's context-window size.
+
+            Order: per-model override map -> litellm.get_max_tokens -> global default.
+            litellm's `model_prices_and_context_window.json` doesn't cover every
+            proxied model name (vLLM, Bedrock variants, custom deployments...),
+            so falling through to a single hardcoded number is a footgun on a
+            fleet with mixed 8k/32k/128k models.
+            """
+            if model and model in self.max_context_tokens_by_model:
+                return int(self.max_context_tokens_by_model[model])
+            try:
+                resolved = get_max_tokens(model)
+                if resolved:
+                    return int(resolved)
+            except Exception as e:
+                logger.warning(
+                    "get_max_tokens(%r) failed (%s); falling back to "
+                    "default_max_context_tokens=%d",
+                    model,
+                    e,
+                    self.default_max_context_tokens,
+                )
+            return int(self.default_max_context_tokens)
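Assuming the hypothetical config above and a guardrail instance, the fallback order plays out roughly like this:

    guardrail._resolve_max_context_tokens("vllm-llama3-8b")    # 8192, from the per-model map
    guardrail._resolve_max_context_tokens("gpt-4o")            # whatever litellm's table reports
    guardrail._resolve_max_context_tokens("my-custom-deploy")  # 8192, the global default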
+
+        def _resolve_pop_trailing_tools(self, model: Optional[str]) -> bool:
+            """Per-model override > global default. See `pop_trailing_tool_messages`."""
+            if model and model in self.pop_trailing_tool_messages_by_model:
+                return bool(self.pop_trailing_tool_messages_by_model[model])
+            return self.pop_trailing_tool_messages
+
        def _calculate_safe_completion_tokens(
            self,
            max_context_tokens: int,

@@ -120,6 +168,110 @@ data:
            )
            data["max_tokens"] = safe_completion_tokens
+
+        def _repair_tool_call_pairings(self, messages: list) -> list:

Review comment: My gut tells me that we could make this at two levels flatter by going negative on the if conditions and using continue a bit more. Makes for a bit easier read.

Review comment: Also, maybe make use of the "new" match case instead of the if-else statements.

+            """Strip orphan tool messages and orphan tool_calls from assistant messages.
+
+            After `trim_messages` truncates history, tool-call/tool-response pairings
+            can end up broken. Mistral/vLLM rejects such conversations. We keep only
+            tool_calls that have a matching later `role: tool` response, and keep only
+            `role: tool` messages whose `tool_call_id` was advertised by a surviving
+            assistant `tool_calls` entry.
+            """
+            satisfied_ids = {
+                m.get("tool_call_id")
+                for m in messages
+                if m.get("role") == "tool" and m.get("tool_call_id")
+            }
+
+            result = []
+            advertised_ids = set()
+            for msg in messages:
+                role = msg.get("role")
+                if role == "assistant":
+                    tcs = msg.get("tool_calls") or []
+                    if tcs:
+                        kept_tcs = [tc for tc in tcs if tc.get("id") in satisfied_ids]
+                        content = msg.get("content")
+                        if not kept_tcs and not (content or "").strip():
+                            self._log_debug(
+                                "Dropping assistant message with no content and all tool_calls orphaned"
+                            )
+                            continue
+                        new_msg = dict(msg)
+                        if kept_tcs:
+                            new_msg["tool_calls"] = kept_tcs
+                            advertised_ids.update(tc["id"] for tc in kept_tcs)
+                        else:
+                            new_msg.pop("tool_calls", None)
+                            self._log_debug(
+                                "Stripped all orphan tool_calls from assistant message"
+                            )
+                        result.append(new_msg)
+                    else:
+                        result.append(msg)
+                elif role == "tool":
+                    if msg.get("tool_call_id") in advertised_ids:
+                        result.append(msg)
+                    else:
+                        self._log_debug(
+                            f"Dropping orphan tool message (tool_call_id={msg.get('tool_call_id')})"
+                        )
+                else:
+                    result.append(msg)
+            return result
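A worked example of the repair (IDs are made up): `call_a` has a surviving tool response, `call_b` was orphaned by trimming, and a stale tool message no longer has an advertising assistant:

    messages = [
        {"role": "user", "content": "What's the weather?"},
        {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_a", "type": "function",
             "function": {"name": "get_weather", "arguments": "{}"}},
            {"id": "call_b", "type": "function",
             "function": {"name": "get_time", "arguments": "{}"}},
        ]},
        {"role": "tool", "tool_call_id": "call_a", "content": "22C"},
        {"role": "tool", "tool_call_id": "call_zzz", "content": "stale"},
    ]
    repaired = guardrail._repair_tool_call_pairings(messages)
    # -> assistant keeps only call_a; the call_zzz tool message is dropped.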
+
+        def _ensure_last_is_user(

Review comment: I might be missing something, but are we even using this anywhere?

+            self, messages: list, pop_trailing_tools: bool = False
+        ) -> list:
+            """Guarantee the conversation does not end on an assistant message.
+
+            Always appends a user continuation message if the terminus is an
+            assistant message (some chat templates reject assistant-terminal
+            conversations under `add_generation_prompt=True`).
+
+            When `pop_trailing_tools=True`, also pops any trailing `role: tool`
+            messages first — needed for strict templates (e.g. HF Mistral v0.3)
+            that reject tool-role messages outright. Note: popping can break
+            tool-call/tool-response pairings; callers should re-run
+            `_repair_tool_call_pairings` after to clean up.
+            """
+            if not messages:
+                return messages
+            result = list(messages)
+            if pop_trailing_tools:
+                while result and result[-1].get("role") == "tool":
+                    self._log_debug("Popping trailing tool message from terminus")
+                    result.pop()
+            if result and result[-1].get("role") == "assistant":
+                self._log_debug("Appending user continue message after assistant terminus")
+                result.append({"role": "user", "content": "Please continue"})
+            return result
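For instance (a sketch, not a test from this PR):

    msgs = [
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    fixed = guardrail._ensure_last_is_user(msgs)
    # -> ends with {"role": "user", "content": "Please continue"}

    msgs.append({"role": "tool", "tool_call_id": "call_x", "content": "..."})
    fixed = guardrail._ensure_last_is_user(msgs, pop_trailing_tools=True)
    # -> trailing tool popped first, then the user continuation appended.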
+
+        def _sanitize_messages(self, messages: list, model: Optional[str] = None) -> list:
+            """Repair tool-call pairings and fix the terminus. Safe to call on
+            every request.
+
+            Order matters: the second repair has to run *before* we decide
+            whether to append a user-continue, because popping a trailing tool
+            may expose an empty assistant whose only `tool_calls` were just
+            orphaned — repair will drop that assistant entirely, and we don't
+            want to append `"Please continue"` on top of a now-defunct terminus.
+            """
+            if not messages:
+                return messages
+            pop = self._resolve_pop_trailing_tools(model)
+            result = self._repair_tool_call_pairings(messages)
+            if pop:
+                while result and result[-1].get("role") == "tool":
+                    self._log_debug("Popping trailing tool message from terminus")
+                    result.pop()
+                # Pop may have orphaned tool_calls on the now-terminal assistant.
+                result = self._repair_tool_call_pairings(result)
+            if result and result[-1].get("role") == "assistant":
+                self._log_debug("Appending user continue message after assistant terminus")
+                result.append({"role": "user", "content": "Please continue"})
+            return result
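The ordering is visible in a case like this hypothetical one, where popping the only tool response orphans the assistant's `tool_calls`:

    msgs = [
        {"role": "user", "content": "What's the weather?"},
        {"role": "assistant", "content": "", "tool_calls": [
            {"id": "call_a", "type": "function",
             "function": {"name": "get_weather", "arguments": "{}"}},
        ]},
        {"role": "tool", "tool_call_id": "call_a", "content": "22C"},
    ]
    # With pop enabled for this model: the tool message is popped, call_a is
    # now orphaned, the empty assistant gets dropped by the second repair, and
    # the terminus is already the user message -- no "Please continue" needed.
    guardrail._sanitize_messages(msgs, model="hf-mistral-v0.3-custom")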
+
        async def async_pre_call_hook(
            self,
            user_api_key_dict: UserAPIKeyAuth,

@@ -140,11 +292,8 @@ data:
            if "messages" in data and data["messages"]:
                model = data.get("model")

-                # Get model's context window size
-                try:
-                    max_context_tokens = get_max_tokens(model)
-                except:
-                    max_context_tokens = 8192  # Default fallback
+                # Get model's context window size (per-model map -> litellm -> default).
+                max_context_tokens = self._resolve_max_context_tokens(model)

                self._log_debug(f"Model: {model}")
                self._log_debug(f"Max context tokens: {max_context_tokens}")
@@ -179,7 +328,7 @@ data:

                self._log_debug(f"Safe completion tokens: {safe_completion_tokens}")
                self._log_debug(
-                    f"Calculation: min({requested_completion}, max(512, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.90))"
+                    f"Calculation: min({requested_completion}, max(256, ({max_context_tokens} - {int(current_tokens)} - {self.safety_buffer}) * 0.75))"

Review comment: I do not understand this change? Why?

                )
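As a worked example of the new formula (assuming a 32,768-token context, 20,000 input tokens, the default 500-token buffer, and a 16,000-token completion request):

    # min(16000, max(256, (32768 - 20000 - 500) * 0.75))
    # = min(16000, max(256, 12268 * 0.75))
    # = min(16000, 9201)
    # = 9201 safe completion tokens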

@@ -204,55 +353,37 @@ data:
                    self._log_debug(
                        f"Input tokens ({current_tokens}) exceed limit ({max_input_tokens}), trimming messages..."
                    )
-                    # Trim messages to fit
                    data["messages"] = trim_messages(
                        data["messages"],
                        model=model,
-                        max_tokens=int(
-                            max_input_tokens * 0.90
-                        ),  # Trim to 90% of already conservative max
+                        max_tokens=int(max_input_tokens * 0.90),
                        trim_ratio=self.trim_ratio,
                    )
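Standalone, litellm's `trim_messages` helper can be exercised like this (numbers are illustrative):

    from litellm.utils import trim_messages

    trimmed = trim_messages(
        messages,
        model="gpt-4o",
        max_tokens=int(7000 * 0.90),  # 90% of an already conservative input cap
        trim_ratio=0.75,
    )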

-                    # Ensure "ensure_alternating_roles" is fixed for message after trim
-                    data["messages"] = get_completion_messages(
-                        messages=data["messages"],
-                        assistant_continue_message={"role": "assistant", "content": ""},
-                        user_continue_message={
-                            "role": "user",
-                            "content": "Please continue",
-                        },
-                        ensure_alternating_roles=True,
-                    )
-                    # Recount after trimming
-                    try:
-                        new_token_count = token_counter(
-                            model=model, messages=data["messages"]
-                        )
-                        self._log_debug(f"After trimming, input tokens: {new_token_count}")
-
-                        # RECALCULATE safe completion tokens based on actual trimmed input
-                        safe_completion_tokens = self._calculate_safe_completion_tokens(
-                            max_context_tokens, new_token_count, requested_completion
-                        )
-                        self._log_debug(
-                            f"Recalculated safe completion tokens after trim: {safe_completion_tokens}"
-                        )
-
-                        # Update the data with the recalculated values
-                        self._update_completion_tokens(
-                            data, safe_completion_tokens, has_max_tokens, has_max_completion
-                        )
-
-                        # Update for final logging
-                        current_tokens = new_token_count
-                    except Exception as e:
-                        self._log_debug(f"Failed to recount tokens after trim: {e}")
                else:
                    self._log_debug(
                        f"No trimming needed, but current={current_tokens}, max={max_input_tokens}"
                    )

+                # Always sanitize: repair tool-call pairings and (optionally, per
+                # `pop_trailing_tool_messages` config) coerce the terminus into a
+                # user message so strict chat templates accept the request.
+                data["messages"] = self._sanitize_messages(data["messages"], model=model)
+
+                # Recount once after sanitize — messages may have grown (user continue
+                # appended) or shrunk (orphan tool / empty assistant dropped).
+                try:
+                    final_tokens = token_counter(model=model, messages=data["messages"])
+                    self._log_debug(f"After sanitize, input tokens: {final_tokens}")
+                    safe_completion_tokens = self._calculate_safe_completion_tokens(
+                        max_context_tokens, final_tokens, requested_completion
+                    )
+                    self._update_completion_tokens(
+                        data, safe_completion_tokens, has_max_tokens, has_max_completion
+                    )
+                    current_tokens = final_tokens
+                except Exception as e:

Review comment: We can get more specific here it seems: https://github.com/BerriAI/litellm/blob/144279eb57edb6cc0a97ad47c9da33b910f70dfa/litellm/litellm_core_utils/token_counter.py#L360

+                    self._log_debug(f"Failed to recount tokens after sanitize: {e}")

                self._log_debug(
                    f"Expected total: input ~{int(current_tokens)} + completion {safe_completion_tokens} + buffer {self.safety_buffer} = ~{int(current_tokens) + safe_completion_tokens + self.safety_buffer}/{max_context_tokens}"
                )

Review comment: Can we document inline the reason for this magic number?