diff --git a/tests/test_client_multimodal_types.py b/tests/test_client_multimodal_types.py index d51c38262..22a6e8c6f 100644 --- a/tests/test_client_multimodal_types.py +++ b/tests/test_client_multimodal_types.py @@ -181,6 +181,33 @@ async def test_anthropic_from_native_response_extracts_usage(): assert response.usage.reasoning_tokens == 0 +@pytest.mark.asyncio +async def test_anthropic_from_native_response_extracts_cache_usage(): + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + client = AnthropicMessagesClient(object()) + native_response = SimpleNamespace( + id="msg_cache", + model="claude-haiku-4-5", + stop_reason="end_turn", + content=[SimpleNamespace(type="text", text="Hello!")], + usage=SimpleNamespace( + input_tokens=42, + output_tokens=17, + cache_creation_input_tokens=8, + cache_read_input_tokens=100, + ), + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 50 + assert response.usage.completion_tokens == 17 + assert response.usage.cached_input_tokens == 100 + assert response.usage.total_tokens == 67 + + @pytest.mark.asyncio async def test_anthropic_from_native_response_always_parses_reasoning(): pytest.importorskip("anthropic") diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py new file mode 100644 index 000000000..7f9560aa9 --- /dev/null +++ b/tests/test_prompt_cache_utils.py @@ -0,0 +1,26 @@ +from verifiers.types import ClientConfig +from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_kwargs + + +def test_anthropic_cache_control_hint_is_default_only(): + extra_kwargs = apply_prompt_cache_to_kwargs( + config=ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com/v1", + ), + sampling_args={"max_tokens": 16}, + extra_kwargs={}, + ) + + assert extra_kwargs == {"cache_control": {"type": "ephemeral"}} + + extra_kwargs = apply_prompt_cache_to_kwargs( + config=ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com/v1", + ), + sampling_args={"cache_control": {"type": "custom"}}, + extra_kwargs={}, + ) + + assert extra_kwargs == {} diff --git a/verifiers/clients/anthropic_messages_client.py b/verifiers/clients/anthropic_messages_client.py index 31611d81d..71258822b 100644 --- a/verifiers/clients/anthropic_messages_client.py +++ b/verifiers/clients/anthropic_messages_client.py @@ -468,6 +468,18 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens + cached_input_tokens = getattr(response.usage, "cache_read_input_tokens", None) + cache_creation_input_tokens = getattr( + response.usage, "cache_creation_input_tokens", None + ) + if isinstance(cache_creation_input_tokens, int) and not isinstance( + cache_creation_input_tokens, bool + ): + input_tokens += cache_creation_input_tokens + if not isinstance(cached_input_tokens, int) or isinstance( + cached_input_tokens, bool + ): + cached_input_tokens = None return Response( id=response.id, @@ -478,6 +490,7 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: completion_tokens=output_tokens, reasoning_tokens=0, total_tokens=input_tokens + output_tokens, + cached_input_tokens=cached_input_tokens, ), message=ResponseMessage( content=content, diff --git a/verifiers/clients/client.py b/verifiers/clients/client.py index 7991670e8..b98806faf 100644 --- a/verifiers/clients/client.py +++ b/verifiers/clients/client.py @@ -19,6 +19,7 @@ SamplingArgs, Tool, ) +from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_kwargs if TYPE_CHECKING: pass @@ -126,6 +127,11 @@ async def get_response( native_prompt, extra_kwargs = await self.to_native_prompt(prompt) native_tools = await self.to_native_tools(tools) + extra_kwargs = apply_prompt_cache_to_kwargs( + config=self._config, + sampling_args=sampling_args, + extra_kwargs=extra_kwargs, + ) native_response = await self.get_native_response( native_prompt, model, diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index d932d32d3..a8915285d 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -423,13 +423,29 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: completion_tokens, int ): return None + prompt_details = get_usage_field(usage, "prompt_tokens_details") + if prompt_details is None: + prompt_details = get_usage_field(usage, "input_tokens_details") + cached_tokens = None + if prompt_details is not None: + reported_cached_tokens = get_usage_field( + prompt_details, "cached_tokens" + ) + if isinstance(reported_cached_tokens, int) and not isinstance( + reported_cached_tokens, bool + ): + cached_tokens = reported_cached_tokens + prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens + elif cached_tokens is not None: + total_tokens = max(0, total_tokens - cached_tokens) return Usage( prompt_tokens=prompt_tokens, reasoning_tokens=0, completion_tokens=completion_tokens, total_tokens=total_tokens, + cached_input_tokens=cached_tokens, ) def parse_is_truncated(response: OpenAIChatResponse) -> bool: diff --git a/verifiers/clients/openai_responses_client.py b/verifiers/clients/openai_responses_client.py index eb599d5fc..8555a09dd 100644 --- a/verifiers/clients/openai_responses_client.py +++ b/verifiers/clients/openai_responses_client.py @@ -385,8 +385,21 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: completion_tokens, int ): return None + input_details = get_usage_field(usage, "input_tokens_details") + if input_details is None: + input_details = get_usage_field(usage, "prompt_tokens_details") + cached_tokens = None + if input_details is not None: + reported_cached_tokens = get_usage_field(input_details, "cached_tokens") + if isinstance(reported_cached_tokens, int) and not isinstance( + reported_cached_tokens, bool + ): + cached_tokens = reported_cached_tokens + prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens + elif cached_tokens is not None: + total_tokens = max(0, total_tokens - cached_tokens) if not isinstance(reasoning_tokens, int): reasoning_tokens = 0 return Usage( @@ -394,6 +407,7 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: reasoning_tokens=reasoning_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cached_input_tokens=cached_tokens, ) def parse_is_truncated(response: OpenAIResponsesNativeResponse) -> bool: diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index ed379f086..81a02bdb9 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -480,10 +480,15 @@ def get_state_usage(self, state: State) -> TokenUsage | None: usage = state.get("usage") if isinstance(usage, Mapping): try: - return { + out: TokenUsage = { "input_tokens": float(usage.get("input_tokens", 0.0)), "output_tokens": float(usage.get("output_tokens", 0.0)), } + for key in ("cached_input_tokens",): + value = usage.get(key) + if value is not None: + out[key] = float(value) + return out except (TypeError, ValueError): return None return None diff --git a/verifiers/types.py b/verifiers/types.py index 0e4f63c19..91ec44ccb 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -181,6 +181,7 @@ class Usage(CustomBaseModel): reasoning_tokens: int completion_tokens: int total_tokens: int + cached_input_tokens: int | None = None class RoutedExpertsPayload(TypedDict): @@ -249,6 +250,7 @@ class TrajectoryStepTokens(TypedDict): class TokenUsage(TypedDict): input_tokens: float output_tokens: float + cached_input_tokens: NotRequired[float] final_input_tokens: NotRequired[float] final_output_tokens: NotRequired[float] diff --git a/verifiers/utils/eval_display.py b/verifiers/utils/eval_display.py index 0f2c34c18..2d74ad326 100644 --- a/verifiers/utils/eval_display.py +++ b/verifiers/utils/eval_display.py @@ -410,6 +410,9 @@ def _make_tokens_row( "input": format_numeric(usage.get("input_tokens", 0.0)), "output": format_numeric(usage.get("output_tokens", 0.0)), } + cached = usage.get("cached_input_tokens") + if cached is not None: + kv["cached_input"] = format_numeric(cached) inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") if inp is not None: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 5e7c8651d..7795327bf 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -817,6 +817,8 @@ def print_usage(results: GenerateOutputs): usage_count = 0 input_total = 0.0 output_total = 0.0 + cached_input_total = 0.0 + cached_input_count = 0 final_input_total = 0.0 final_output_total = 0.0 context_count = 0 @@ -827,6 +829,10 @@ def print_usage(results: GenerateOutputs): usage_count += 1 input_total += float(token_usage.get("input_tokens", 0.0)) output_total += float(token_usage.get("output_tokens", 0.0)) + cached = token_usage.get("cached_input_tokens") + if cached is not None: + cached_input_total += float(cached) + cached_input_count += 1 inp = token_usage.get("final_input_tokens") out = token_usage.get("final_output_tokens") if inp is not None and out is not None: @@ -840,6 +846,8 @@ def print_usage(results: GenerateOutputs): input_tokens=input_total / usage_count, output_tokens=output_total / usage_count, ) + if cached_input_count > 0: + usage["cached_input_tokens"] = cached_input_total / cached_input_count if context_count > 0: usage["final_input_tokens"] = final_input_total / context_count usage["final_output_tokens"] = final_output_total / context_count @@ -851,6 +859,9 @@ def print_usage(results: GenerateOutputs): print("Usage:") print(f"input_tokens (avg): {float(usage.get('input_tokens', 0.0)):.3f}") + cached = usage.get("cached_input_tokens") + if cached is not None: + print(f"cached_input_tokens (avg): {float(cached):.3f}") print(f"output_tokens (avg): {float(usage.get('output_tokens', 0.0)):.3f}") inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") diff --git a/verifiers/utils/interception_utils.py b/verifiers/utils/interception_utils.py index 47897f069..991c999d4 100644 --- a/verifiers/utils/interception_utils.py +++ b/verifiers/utils/interception_utils.py @@ -868,6 +868,8 @@ def serialize_anthropic_message_response(response: Response) -> dict[str, Any]: "input_tokens": response.usage.prompt_tokens, "output_tokens": response.usage.completion_tokens, } + if response.usage.cached_input_tokens is not None: + usage["cache_read_input_tokens"] = response.usage.cached_input_tokens return { "id": response.id, "type": "message", diff --git a/verifiers/utils/metric_utils.py b/verifiers/utils/metric_utils.py index 6c8b543f6..45ffff29e 100644 --- a/verifiers/utils/metric_utils.py +++ b/verifiers/utils/metric_utils.py @@ -92,6 +92,12 @@ class OutputTokensMetric(TokenUsageKeyMetric): _key = "output_tokens" +class CachedInputTokensMetric(TokenUsageKeyMetric): + """Mean cached_input_tokens per output.""" + + _key = "cached_input_tokens" + + class FinalInputTokensMetric(TokenUsageKeyMetric): """Mean final_input_tokens (non-completion context tokens) per output.""" diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py new file mode 100644 index 000000000..822216a88 --- /dev/null +++ b/verifiers/utils/prompt_cache_utils.py @@ -0,0 +1,51 @@ +from collections.abc import Mapping +from typing import Any +from urllib.parse import urlsplit + +from verifiers.types import ClientConfig + +ANTHROPIC_ORIGINS = frozenset({"https://api.anthropic.com"}) + + +def endpoint_origin(api_base_url: str) -> str | None: + parsed = urlsplit(api_base_url) + if not parsed.scheme or not parsed.hostname: + return None + scheme = parsed.scheme.lower() + host = parsed.hostname.lower() + port = parsed.port + netloc = host + if ":" in host: + netloc = f"[{host}]" + if port is not None and not ( + (scheme == "https" and port == 443) or (scheme == "http" and port == 80) + ): + netloc = f"{netloc}:{port}" + return f"{scheme}://{netloc}" + + +def uses_official_anthropic_messages(config: ClientConfig | None) -> bool: + return ( + config is not None + and config.client_type == "anthropic_messages" + and endpoint_origin(config.api_base_url) in ANTHROPIC_ORIGINS + ) + + +def _cache_control_payload() -> dict[str, str]: + return {"type": "ephemeral"} + + +def apply_prompt_cache_to_kwargs( + *, + config: ClientConfig | None, + sampling_args: Mapping[str, Any], + extra_kwargs: Mapping[str, Any], +) -> dict[str, Any]: + updated_extra_kwargs = dict(extra_kwargs) + if ( + uses_official_anthropic_messages(config) + and "cache_control" not in sampling_args + ): + updated_extra_kwargs.setdefault("cache_control", _cache_control_payload()) + return updated_extra_kwargs diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 0e34690c1..a70e86611 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -29,6 +29,7 @@ serialize_messages_for_output, ) from verifiers.utils.metric_utils import ( + CachedInputTokensMetric, EnvMetrics, ErrorRateMetric, FinalInputTokensMetric, @@ -128,6 +129,13 @@ def _token_usage_from_mapping(value: object, context: str) -> TokenUsage | None: for key in ("final_input_tokens", "final_output_tokens"): if key in mapping_value and mapping_value[key] is not None: usage[key] = _token_count(mapping_value[key], f"{context}.{key}") + if ( + "cached_input_tokens" in mapping_value + and mapping_value["cached_input_tokens"] is not None + ): + usage["cached_input_tokens"] = _token_count( + mapping_value["cached_input_tokens"], f"{context}.cached_input_tokens" + ) return usage @@ -136,6 +144,7 @@ def _token_usage_from_trajectory(trajectory: object) -> TokenUsage | None: return None input_tokens = 0 output_tokens = 0 + cached_input_tokens = 0 usage_seen = False for index, step in enumerate(trajectory): if not isinstance(step, Mapping): @@ -150,12 +159,16 @@ def _token_usage_from_trajectory(trajectory: object) -> TokenUsage | None: step_input_tokens, step_output_tokens = response_usage_tokens(response) input_tokens += step_input_tokens output_tokens += step_output_tokens + cached_input_tokens += response.usage.cached_input_tokens or 0 if not usage_seen: return None - return TokenUsage( + usage = TokenUsage( input_tokens=float(input_tokens), output_tokens=float(output_tokens), ) + if cached_input_tokens > 0: + usage["cached_input_tokens"] = float(cached_input_tokens) + return usage def _extract_state_token_usage(state: State) -> TokenUsage | None: @@ -230,6 +243,8 @@ def state_to_output( "input_tokens": usage.get("input_tokens", 0.0), "output_tokens": usage.get("output_tokens", 0.0), } + if usage.get("cached_input_tokens") is not None: + token_usage["cached_input_tokens"] = usage["cached_input_tokens"] # Add context token metrics from trajectory trajectory = state.get("trajectory", []) if isinstance(trajectory, list): @@ -567,6 +582,7 @@ def __init__( self.env_metrics = EnvMetrics() self.input_tokens = InputTokensMetric() self.output_tokens = OutputTokensMetric() + self.cached_input_tokens = CachedInputTokensMetric() self.final_input_tokens = FinalInputTokensMetric() self.final_output_tokens = FinalOutputTokensMetric() self.pass_at_k = PassAtKMetric(rollouts_per_example, threshold=pass_threshold) @@ -617,6 +633,7 @@ def add_outputs(self, new_outputs: list[RolloutOutput]) -> None: self.env_metrics.add_outputs(new_outputs) self.input_tokens.add_outputs(new_outputs) self.output_tokens.add_outputs(new_outputs) + self.cached_input_tokens.add_outputs(new_outputs) self.final_input_tokens.add_outputs(new_outputs) self.final_output_tokens.add_outputs(new_outputs) self.pass_at_k.add_outputs(new_outputs) @@ -641,6 +658,8 @@ def build_metadata(self) -> GenerateMetadata: input_tokens=self.input_tokens.compute(), output_tokens=self.output_tokens.compute(), ) + if self.cached_input_tokens.count > 0: + usage["cached_input_tokens"] = self.cached_input_tokens.compute() if self.final_input_tokens.count > 0: usage["final_input_tokens"] = self.final_input_tokens.compute() usage["final_output_tokens"] = self.final_output_tokens.compute() diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index c8e3921b0..d6ffe5f3d 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -39,30 +39,45 @@ def increment( input_tokens: int | float = 0, output_tokens: int | float = 0, *, + cached_input_tokens: int | float | None = None, mark_seen: bool = True, ) -> None: input_delta = float(input_tokens or 0.0) output_delta = float(output_tokens or 0.0) - if input_delta < 0 or output_delta < 0: + cached_input_delta = float(cached_input_tokens or 0.0) + if input_delta < 0 or output_delta < 0 or cached_input_delta < 0: raise ValueError("Token usage increments must be non-negative.") if mark_seen: self._usage_seen = True self._usage_totals["input_tokens"] += input_delta self._usage_totals["output_tokens"] += output_delta + if cached_input_tokens is not None: + self._usage_totals["cached_input_tokens"] = ( + self._usage_totals.get("cached_input_tokens", 0.0) + cached_input_delta + ) def increment_from_response(self, response: Response) -> None: if response.usage is None: return input_tokens, output_tokens = response_usage_tokens(response) - self.increment(input_tokens, output_tokens, mark_seen=True) + self.increment( + input_tokens, + output_tokens, + cached_input_tokens=response.usage.cached_input_tokens, + mark_seen=True, + ) def snapshot(self) -> TokenUsage | None: if not self._usage_seen: return None - return { + usage: TokenUsage = { "input_tokens": self._usage_totals["input_tokens"], "output_tokens": self._usage_totals["output_tokens"], } + cached_input_tokens = self._usage_totals.get("cached_input_tokens", 0.0) + if cached_input_tokens > 0: + usage["cached_input_tokens"] = cached_input_tokens + return usage def compute_context_token_metrics( @@ -95,7 +110,11 @@ def compute_context_token_metrics( if not isinstance(response, Response) or response.usage is None: continue prompt_tokens, completion_tokens = response_usage_tokens(response) - last_step_total = prompt_tokens + completion_tokens + last_step_total = ( + prompt_tokens + + (response.usage.cached_input_tokens or 0) + + completion_tokens + ) found = True break