From fce4d3da93147c31339de909b76900be0a7d7efa Mon Sep 17 00:00:00 2001 From: William Brown Date: Sat, 16 May 2026 12:29:24 -0500 Subject: [PATCH 1/9] Add prompt cache handling and token accounting --- docs/evaluation.md | 2 + docs/reference.md | 6 +- tests/test_endpoint_registry.py | 32 ++ tests/test_eval_cli.py | 48 ++ tests/test_prompt_cache_utils.py | 498 ++++++++++++++++++ .../clients/anthropic_messages_client.py | 12 + verifiers/clients/client.py | 15 + .../clients/openai_chat_completions_client.py | 25 + .../clients/openai_completions_client.py | 8 + verifiers/clients/openai_responses_client.py | 8 + verifiers/envs/environment.py | 7 +- verifiers/scripts/eval.py | 18 + verifiers/scripts/tui.py | 24 + verifiers/types.py | 7 + verifiers/utils/eval_display.py | 6 + verifiers/utils/eval_utils.py | 31 ++ verifiers/utils/interception_utils.py | 13 +- verifiers/utils/metric_utils.py | 12 + verifiers/utils/prompt_cache_utils.py | 220 ++++++++ verifiers/utils/save_utils.py | 52 +- verifiers/utils/usage_utils.py | 144 ++++- verifiers/v1/env.py | 22 +- 22 files changed, 1167 insertions(+), 43 deletions(-) create mode 100644 tests/test_prompt_cache_utils.py create mode 100644 verifiers/utils/prompt_cache_utils.py diff --git a/docs/evaluation.md b/docs/evaluation.md index 6e09d35c4..c9c75ef1e 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -166,6 +166,8 @@ In `[[eval]]` TOML configs you can set extra headers as `headers = { ... }` and/ For per-request headers that need to vary per rollout (e.g. sticky DP-aware routing keyed off `example_id` or `trajectory_id`), use `headers_from_state = { "X-Name" = "state_key" }` and/or `header_from_state = ["X-Name: state_key", ...]` (same form as repeated `--header-from-state`). The value for each request is resolved at send time as `state[state_key]`. If unset, `X-Session-ID` defaults to `example_id`. +Prompt caching is automatic for supported official providers inferred from `url` and `api_client_type`: OpenAI (`https://api.openai.com`), Anthropic (`https://api.anthropic.com`), and OpenRouter (`https://openrouter.ai`). Unsupported providers run unchanged. Set `prompt_cache = false` on an endpoint row or `[[eval]]` only when you need to disable this behavior for a specific run. + To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. Then use the alias directly: diff --git a/docs/reference.md b/docs/reference.md index 0586d9af3..7f0c36438 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -243,14 +243,18 @@ Derivations: class TokenUsage(TypedDict, total=False): input_tokens: float output_tokens: float + cached_input_tokens: float + cache_write_input_tokens: float final_input_tokens: float final_output_tokens: float ``` | Field | Description | |-------|-------------| -| `input_tokens` | Sum of prompt tokens across all turns. Shared context is counted each time it appears in a prompt. | +| `input_tokens` | Sum of non-cache-hit prompt tokens across all turns. Shared uncached context is counted each time it appears in a prompt. | | `output_tokens` | Sum of completion tokens across all turns. | +| `cached_input_tokens` | Sum of prompt tokens served from provider prompt cache, when reported by the provider. | +| `cache_write_input_tokens` | Sum of prompt tokens written to provider prompt cache, when reported by the provider. | | `final_input_tokens` | Non-completion tokens in the final turn's context (system prompts, user messages, tool results, etc.). | | `final_output_tokens` | Completion tokens in the final turn's context. Equals `output_tokens` for single-turn rollouts. | diff --git a/tests/test_endpoint_registry.py b/tests/test_endpoint_registry.py index 6a97eed5e..2ee61b0f3 100644 --- a/tests/test_endpoint_registry.py +++ b/tests/test_endpoint_registry.py @@ -243,6 +243,38 @@ def test_load_endpoints_toml_accepts_extra_headers_alias(tmp_path: Path): assert endpoints["proxy"][0]["extra_headers"] == {"X-A": "a"} +def test_load_endpoints_toml_accepts_prompt_cache_opt_out(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "openai"\n' + 'model = "m"\n' + 'url = "https://api.openai.com/v1"\n' + 'key = "OPENAI_API_KEY"\n' + "prompt_cache = false\n", + encoding="utf-8", + ) + + endpoints = load_endpoints(str(registry_path)) + + assert endpoints["openai"][0]["prompt_cache"] is False + + +def test_load_endpoints_toml_rejects_non_bool_prompt_cache(tmp_path: Path): + registry_path = tmp_path / "endpoints.toml" + registry_path.write_text( + "[[endpoint]]\n" + 'endpoint_id = "openai"\n' + 'model = "m"\n' + 'url = "https://api.openai.com/v1"\n' + 'key = "OPENAI_API_KEY"\n' + 'prompt_cache = "yes"\n', + encoding="utf-8", + ) + + assert load_endpoints(str(registry_path)) == {} + + def test_load_endpoints_toml_rejects_headers_and_extra_headers_together( tmp_path: Path, ): diff --git a/tests/test_eval_cli.py b/tests/test_eval_cli.py index ad49a0b49..40675fd1e 100644 --- a/tests/test_eval_cli.py +++ b/tests/test_eval_cli.py @@ -319,6 +319,54 @@ def test_cli_registry_headers_merged_with_eval_toml(tmp_path, monkeypatch, run_c } +def test_cli_registry_prompt_cache_opt_out_flows_to_client_config( + monkeypatch, run_cli +): + captured = run_cli( + monkeypatch, + { + "model": "openai", + "api_base_url": None, + "api_key_var": None, + }, + endpoints={ + "openai": [ + { + "model": "gpt-5.4-mini", + "key": "OPENAI_API_KEY", + "url": "https://api.openai.com/v1", + "prompt_cache": False, + } + ] + }, + ) + + assert captured["configs"][0].client_config.prompt_cache is False + + +def test_cli_toml_prompt_cache_opt_out_overrides_registry(monkeypatch, run_cli): + captured = run_cli( + monkeypatch, + { + "model": "openai", + "api_base_url": None, + "api_key_var": None, + "prompt_cache": False, + }, + endpoints={ + "openai": [ + { + "model": "gpt-5.4-mini", + "key": "OPENAI_API_KEY", + "url": "https://api.openai.com/v1", + } + ] + }, + ) + + assert captured["configs"][0].client_config.prompt_cache is False + + def test_cli_multi_variant_preserves_per_row_registry_headers(monkeypatch, run_cli): captured = run_cli( monkeypatch, diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py new file mode 100644 index 000000000..07ee505c7 --- /dev/null +++ b/tests/test_prompt_cache_utils.py @@ -0,0 +1,498 @@ +from __future__ import annotations + +import asyncio +from types import SimpleNamespace + +import pytest + +import verifiers.v1 as vf +from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient +from verifiers.clients.client import Client +from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient +from verifiers.types import ClientConfig, Response, ResponseMessage, Usage +from verifiers.utils.prompt_cache_utils import ( + EndpointIdentity, + apply_prompt_cache_to_request, + resolve_prompt_cache_policy, + should_prefire_prompt_cache_group, +) +from verifiers.utils.save_utils import state_to_output +from verifiers.utils.usage_utils import extract_usage_token_details + + +class RecordingClient(Client): + def __init__(self, config: ClientConfig): + super().__init__(config) + self.request = {} + + def setup_client(self, config): + return object() + + async def to_native_tool(self, tool): + return tool + + async def to_native_prompt(self, messages): + return messages, {} + + async def get_native_response(self, prompt, model, sampling_args, tools=None, **kwargs): + self.request = { + "prompt": prompt, + "model": model, + "sampling_args": sampling_args, + "tools": tools, + "kwargs": kwargs, + } + return object() + + async def raise_from_native_response(self, response): + _ = response + + async def from_native_response(self, response): + _ = response + return Response( + id="resp", + created=0, + model="model", + usage=None, + message=ResponseMessage( + content="ok", + finish_reason="stop", + is_truncated=False, + ), + ) + + async def close(self) -> None: + pass + + +class GroupOrderClient(Client): + def __init__(self, config: ClientConfig): + super().__init__(config) + self.first_active = False + self.started_during_first: list[int] = [] + + def setup_client(self, config): + return object() + + async def get_response(self, prompt, model, sampling_args, tools=None, **kwargs): + _ = prompt, model, sampling_args, tools + state = kwargs["state"] + rollout_index = int(state["rollout_index"]) + if rollout_index == 0: + self.first_active = True + await asyncio.sleep(0.01) + self.first_active = False + elif self.first_active: + self.started_during_first.append(rollout_index) + return Response( + id=f"resp-{rollout_index}", + created=0, + model="model", + usage=Usage( + prompt_tokens=1, + reasoning_tokens=0, + completion_tokens=1, + total_tokens=2, + ), + message=ResponseMessage( + content="ok", + finish_reason="stop", + is_truncated=False, + ), + ) + + async def to_native_tool(self, tool): + return tool + + async def to_native_prompt(self, messages): + return messages, {} + + async def get_native_response(self, prompt, model, sampling_args, tools=None, **kwargs): + raise AssertionError("get_response is implemented directly") + + async def raise_from_native_response(self, response): + _ = response + + async def from_native_response(self, response): + _ = response + + async def close(self) -> None: + pass + + +class IndexedTaskset(vf.Taskset): + async def init_group(self, task, num_rollouts): + tasks, states = await super().init_group(task, num_rollouts) + for index, state in enumerate(states): + state["rollout_index"] = index + return tasks, states + + +def test_endpoint_identity_normalizes_official_origins(): + identity = EndpointIdentity.from_url( + "https://api.openai.com/v1", "openai_chat_completions" + ) + + assert identity is not None + assert identity.origin == "https://api.openai.com" + assert identity.host == "api.openai.com" + assert identity.path == "/v1" + + +def test_prompt_cache_policy_is_inferred_from_url_and_type(): + assert ( + resolve_prompt_cache_policy( + ClientConfig( + client_type="openai_responses", + api_base_url="https://api.openai.com/v1", + ), + "gpt-5.4-mini", + ).mode + == "implicit" + ) + assert ( + resolve_prompt_cache_policy( + ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com", + ), + "claude-sonnet-4-5", + ).mode + == "anthropic_top_level" + ) + assert ( + resolve_prompt_cache_policy( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://openrouter.ai/api/v1", + ), + "anthropic/claude-sonnet-4.5", + ).mode + == "openrouter_anthropic_top_level" + ) + assert ( + resolve_prompt_cache_policy( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.example.com/v1", + ), + "model", + ).mode + == "disabled" + ) + + +def test_prompt_cache_false_disables_inferred_provider_policy(): + policy = resolve_prompt_cache_policy( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + prompt_cache=False, + ), + "gpt-5.4-mini", + ) + + assert policy.mode == "disabled" + assert not policy.prefire_groups + + +def test_anthropic_request_policy_adds_top_level_cache_control(): + native_prompt, native_tools, sampling_args, extra_kwargs = ( + apply_prompt_cache_to_request( + config=ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com", + ), + model="claude-sonnet-4-5", + native_prompt=[{"role": "user", "content": "question"}], + native_tools=None, + sampling_args={"max_tokens": 16}, + extra_kwargs={}, + ) + ) + + assert native_prompt == [{"role": "user", "content": "question"}] + assert native_tools is None + assert sampling_args == {"max_tokens": 16} + assert extra_kwargs["cache_control"] == {"type": "ephemeral"} + + +def test_openrouter_anthropic_policy_uses_extra_body_cache_control(): + native_prompt, native_tools, sampling_args, extra_kwargs = ( + apply_prompt_cache_to_request( + config=ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://openrouter.ai/api/v1", + ), + model="anthropic/claude-sonnet-4.5", + native_prompt=[{"role": "user", "content": "question"}], + native_tools=[], + sampling_args={"max_tokens": 16, "extra_body": {"foo": "bar"}}, + extra_kwargs={}, + ) + ) + + assert native_prompt == [{"role": "user", "content": "question"}] + assert native_tools == [] + assert extra_kwargs == {} + assert sampling_args["extra_body"] == { + "foo": "bar", + "cache_control": {"type": "ephemeral"}, + } + + +def test_openai_policy_does_not_mutate_request(): + native_prompt, native_tools, sampling_args, extra_kwargs = ( + apply_prompt_cache_to_request( + config=ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + ), + model="gpt-5.4-mini", + native_prompt=[{"role": "user", "content": "question"}], + native_tools=None, + sampling_args={"max_tokens": 16}, + extra_kwargs={}, + ) + ) + + assert native_prompt == [{"role": "user", "content": "question"}] + assert native_tools is None + assert sampling_args == {"max_tokens": 16} + assert extra_kwargs == {} + + +def test_group_prefire_is_tied_to_cache_policy(): + assert should_prefire_prompt_cache_group( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + ), + "gpt-5.4-mini", + 2, + ) + assert not should_prefire_prompt_cache_group( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + prompt_cache=False, + ), + "gpt-5.4-mini", + 2, + ) + assert not should_prefire_prompt_cache_group( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.example.com/v1", + ), + "model", + 2, + ) + assert not should_prefire_prompt_cache_group( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + ), + "gpt-5.4-mini", + 1, + ) + + +@pytest.mark.asyncio +async def test_client_request_hook_applies_prompt_cache_policy(): + client = RecordingClient( + ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com", + ) + ) + + await client.get_response( + prompt=[], + model="claude-sonnet-4-5", + sampling_args={"max_tokens": 16}, + ) + + assert client.request["kwargs"]["cache_control"] == {"type": "ephemeral"} + + +@pytest.mark.asyncio +async def test_v1_group_prefire_serializes_first_rollout_for_cached_provider(): + client = GroupOrderClient( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.openai.com/v1", + ) + ) + env = vf.Env( + taskset=IndexedTaskset(source=[{"question": "q"}]), + harness=vf.Harness(max_turns=1), + ) + + await env._run_group_states( + [ + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + ], + client, + "gpt-5.4-mini", + {}, + ) + + assert client.started_during_first == [] + + +@pytest.mark.asyncio +async def test_v1_group_prefire_is_skipped_for_generic_provider(): + client = GroupOrderClient( + ClientConfig( + client_type="openai_chat_completions", + api_base_url="https://api.example.com/v1", + ) + ) + env = vf.Env( + taskset=IndexedTaskset(source=[{"question": "q"}]), + harness=vf.Harness(max_turns=1), + ) + + await env._run_group_states( + [ + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, + ], + client, + "model", + {}, + ) + + assert client.started_during_first + + +@pytest.mark.asyncio +async def test_openai_usage_splits_cached_input_tokens(): + client = OpenAIChatCompletionsClient(object()) + message = SimpleNamespace( + content="ok", + tool_calls=None, + model_dump=lambda: {}, + ) + native_response = SimpleNamespace( + id="resp", + created=0, + model="gpt-5.4-mini", + usage=SimpleNamespace( + prompt_tokens=100, + completion_tokens=5, + total_tokens=105, + prompt_tokens_details=SimpleNamespace( + cached_tokens=80, + cache_write_tokens=10, + ), + ), + choices=[ + SimpleNamespace( + message=message, + finish_reason="stop", + ) + ], + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 20 + assert response.usage.cached_input_tokens == 80 + assert response.usage.cache_write_input_tokens == 10 + assert response.usage.total_tokens == 25 + + +@pytest.mark.asyncio +async def test_anthropic_usage_splits_cache_read_and_write_tokens(): + client = AnthropicMessagesClient(object()) + native_response = SimpleNamespace( + id="resp", + model="claude-sonnet-4-5", + stop_reason="end_turn", + content=[SimpleNamespace(type="text", text="ok")], + usage=SimpleNamespace( + input_tokens=5, + output_tokens=7, + cache_read_input_tokens=80, + cache_creation_input_tokens=10, + ), + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 15 + assert response.usage.cached_input_tokens == 80 + assert response.usage.cache_write_input_tokens == 10 + assert response.usage.total_tokens == 22 + + +def test_native_anthropic_usage_counts_cache_writes_as_uncached_input(): + response = SimpleNamespace( + usage=SimpleNamespace( + input_tokens=5, + output_tokens=7, + cache_read_input_tokens=80, + cache_creation_input_tokens=10, + ) + ) + + assert extract_usage_token_details(response) == { + "input_tokens": 15, + "output_tokens": 7, + "cached_input_tokens": 80, + "cache_write_input_tokens": 10, + } + + +def test_serialized_response_usage_counts_cache_details(): + response = { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 7, + "prompt_tokens_details": { + "cached_tokens": 80, + "cache_write_tokens": 10, + }, + } + } + + assert extract_usage_token_details(response) == { + "input_tokens": 20, + "output_tokens": 7, + "cached_input_tokens": 80, + "cache_write_input_tokens": 10, + } + + +def test_state_output_fallback_reads_serialized_trajectory_usage(): + task = vf.Task( + { + "example_id": 0, + "prompt": [{"role": "user", "content": "q"}], + } + ).freeze() + state = vf.State.for_task(task) + state["trajectory"] = [ + { + "response": { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 7, + "prompt_tokens_details": {"cached_tokens": 80}, + } + } + } + ] + + output = state_to_output(state) + + assert output["token_usage"]["input_tokens"] == 20.0 + assert output["token_usage"]["cached_input_tokens"] == 80.0 + assert output["token_usage"]["output_tokens"] == 7.0 diff --git a/verifiers/clients/anthropic_messages_client.py b/verifiers/clients/anthropic_messages_client.py index 9e80b63b7..5e25e53a3 100644 --- a/verifiers/clients/anthropic_messages_client.py +++ b/verifiers/clients/anthropic_messages_client.py @@ -447,6 +447,16 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens + cached_input_tokens = getattr(response.usage, "cache_read_input_tokens", None) + cache_write_input_tokens = getattr( + response.usage, "cache_creation_input_tokens", None + ) + if isinstance(cache_write_input_tokens, int): + input_tokens += cache_write_input_tokens + else: + cache_write_input_tokens = None + if not isinstance(cached_input_tokens, int): + cached_input_tokens = None return Response( id=response.id, @@ -457,6 +467,8 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: completion_tokens=output_tokens, reasoning_tokens=0, total_tokens=input_tokens + output_tokens, + cached_input_tokens=cached_input_tokens, + cache_write_input_tokens=cache_write_input_tokens, ), message=ResponseMessage( content=content, diff --git a/verifiers/clients/client.py b/verifiers/clients/client.py index 7991670e8..5879b39e6 100644 --- a/verifiers/clients/client.py +++ b/verifiers/clients/client.py @@ -19,6 +19,7 @@ SamplingArgs, Tool, ) +from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_request if TYPE_CHECKING: pass @@ -50,6 +51,10 @@ def __init__(self, client_or_config: ClientT | ClientConfig) -> None: def client(self) -> ClientT: return self._client + @property + def config(self) -> ClientConfig | None: + return self._config + @abstractmethod def setup_client(self, config: ClientConfig) -> ClientT: ... @@ -126,6 +131,16 @@ async def get_response( native_prompt, extra_kwargs = await self.to_native_prompt(prompt) native_tools = await self.to_native_tools(tools) + native_prompt, native_tools, sampling_args, extra_kwargs = ( + apply_prompt_cache_to_request( + config=self._config, + model=model, + native_prompt=native_prompt, + native_tools=native_tools, + sampling_args=sampling_args, + extra_kwargs=extra_kwargs, + ) + ) native_response = await self.get_native_response( native_prompt, model, diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index c755d8dd4..ea08822a5 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -99,6 +99,24 @@ def get_usage_field(usage: Any, key: str) -> Any: return getattr(usage, key, None) +def get_usage_int_field(usage: Any, key: str) -> int | None: + value = get_usage_field(usage, key) + if isinstance(value, int) and not isinstance(value, bool): + return value + return None + + +def get_prompt_cache_token_fields(usage: Any) -> tuple[int | None, int | None]: + details = get_usage_field(usage, "prompt_tokens_details") + if details is None: + details = get_usage_field(usage, "input_tokens_details") + if details is None: + return None, None + cached_tokens = get_usage_int_field(details, "cached_tokens") + cache_write_tokens = get_usage_int_field(details, "cache_write_tokens") + return cached_tokens, cache_write_tokens + + def content_to_text(content: Any) -> str: """Get all text content from OAI message content.""" if isinstance(content, str): @@ -397,13 +415,20 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: completion_tokens, int ): return None + cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + if cached_tokens is not None: + prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens + elif cached_tokens is not None: + total_tokens = max(0, total_tokens - cached_tokens) return Usage( prompt_tokens=prompt_tokens, reasoning_tokens=0, completion_tokens=completion_tokens, total_tokens=total_tokens, + cached_input_tokens=cached_tokens, + cache_write_input_tokens=cache_write_tokens, ) def parse_is_truncated(response: OpenAIChatResponse) -> bool: diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index f7115322a..7e1aa3712 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -6,6 +6,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, + get_prompt_cache_token_fields, get_usage_field, handle_openai_overlong_prompt, ) @@ -125,13 +126,20 @@ def parse_usage(response: OpenAITextResponse) -> Usage | None: completion_tokens, int ): return None + cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + if cached_tokens is not None: + prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens + elif cached_tokens is not None: + total_tokens = max(0, total_tokens - cached_tokens) return Usage( prompt_tokens=prompt_tokens, reasoning_tokens=0, completion_tokens=completion_tokens, total_tokens=total_tokens, + cached_input_tokens=cached_tokens, + cache_write_input_tokens=cache_write_tokens, ) def parse_finish_reason(response: OpenAITextResponse) -> FinishReason: diff --git a/verifiers/clients/openai_responses_client.py b/verifiers/clients/openai_responses_client.py index b33f6b615..28672557d 100644 --- a/verifiers/clients/openai_responses_client.py +++ b/verifiers/clients/openai_responses_client.py @@ -10,6 +10,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, + get_prompt_cache_token_fields, get_usage_field, handle_openai_overlong_prompt, ) @@ -387,8 +388,13 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: completion_tokens, int ): return None + cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + if cached_tokens is not None: + prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens + elif cached_tokens is not None: + total_tokens = max(0, total_tokens - cached_tokens) if not isinstance(reasoning_tokens, int): reasoning_tokens = 0 return Usage( @@ -396,6 +402,8 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: reasoning_tokens=reasoning_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, + cached_input_tokens=cached_tokens, + cache_write_input_tokens=cache_write_tokens, ) def parse_is_truncated(response: OpenAIResponsesNativeResponse) -> bool: diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index 36d5c9743..e51ca4a17 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -481,10 +481,15 @@ def get_state_usage(self, state: State) -> TokenUsage | None: usage = state.get("usage") if isinstance(usage, Mapping): try: - return { + out: TokenUsage = { "input_tokens": float(usage.get("input_tokens", 0.0)), "output_tokens": float(usage.get("output_tokens", 0.0)), } + for key in ("cached_input_tokens", "cache_write_input_tokens"): + value = usage.get(key) + if value is not None: + out[key] = float(value) + return out except (TypeError, ValueError): return None return None diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index ceb9e71e4..4b70954fa 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -183,6 +183,15 @@ def build_extra_headers_from_state(raw: dict[str, Any]) -> dict[str, str]: return {**table, **from_list} +def build_prompt_cache_enabled(raw: dict[str, Any], default: bool = True) -> bool: + raw_prompt_cache = raw.get("prompt_cache") + if raw_prompt_cache is None: + return default + if not isinstance(raw_prompt_cache, bool): + raise ValueError("'prompt_cache' must be a boolean when provided.") + return raw_prompt_cache + + def get_env_eval_defaults(env_id: str) -> dict[str, Any]: """Get eval config defaults from the environment module's pyproject.toml. @@ -708,13 +717,18 @@ def build_eval_config(raw: dict) -> EvalConfig: } registry_headers_base: dict[str, str] = {} + registry_prompt_cache = True if endpoint_group is not None: registry_headers_base = dict(endpoint_group[0].get("extra_headers", {})) + registry_prompt_cache = bool(endpoint_group[0].get("prompt_cache", True)) merged_headers: dict[str, str] = { **registry_headers_base, **eval_headers_merged, } + prompt_cache_enabled = build_prompt_cache_enabled( + raw, default=registry_prompt_cache + ) primary_api_base_url = api_base_url if not isinstance(primary_api_base_url, str): @@ -739,6 +753,9 @@ def build_eval_config(raw: dict) -> EvalConfig: **dict(ep.get("extra_headers", {})), **eval_headers_merged, }, + prompt_cache=build_prompt_cache_enabled( + raw, default=bool(ep.get("prompt_cache", True)) + ), ) for ep in endpoint_group ] @@ -751,6 +768,7 @@ def build_eval_config(raw: dict) -> EvalConfig: endpoint_configs=endpoint_configs, extra_headers=merged_headers, extra_headers_from_state=eval_headers_from_state, + prompt_cache=prompt_cache_enabled, ) # Backward-compatible TOML field: resume_path diff --git a/verifiers/scripts/tui.py b/verifiers/scripts/tui.py index e0da9b706..fa53da9df 100755 --- a/verifiers/scripts/tui.py +++ b/verifiers/scripts/tui.py @@ -3257,8 +3257,21 @@ def _build_header_summary_text(self) -> Text: if isinstance(usage, dict): input_tok = usage.get("input_tokens") output_tok = usage.get("output_tokens") + cached_input_tok = usage.get("cached_input_tokens") + cache_write_input_tok = usage.get("cache_write_input_tokens") if input_tok is not None: usage_items.append(("Avg input tokens", format_numeric(input_tok))) + if cached_input_tok is not None: + usage_items.append( + ("Avg cached input tokens", format_numeric(cached_input_tok)) + ) + if cache_write_input_tok is not None: + usage_items.append( + ( + "Avg cache write input tokens", + format_numeric(cache_write_input_tok), + ) + ) if output_tok is not None: usage_items.append(("Avg output tokens", format_numeric(output_tok))) max_tokens = sampling_args.get("max_tokens") @@ -4728,10 +4741,21 @@ def _build_usage_text(self, record: Dict[str, Any]) -> Text: usage_lines = [] input_tok = token_usage.get("input_tokens") output_tok = token_usage.get("output_tokens") + cached_input_tok = token_usage.get("cached_input_tokens") + cache_write_input_tok = token_usage.get("cache_write_input_tokens") final_inp = token_usage.get("final_input_tokens") final_outp = token_usage.get("final_output_tokens") if input_tok is not None: usage_lines.append(f"input_tokens: {format_numeric(input_tok)}") + if cached_input_tok is not None: + usage_lines.append( + f"cached_input_tokens: {format_numeric(cached_input_tok)}" + ) + if cache_write_input_tok is not None: + usage_lines.append( + "cache_write_input_tokens: " + f"{format_numeric(cache_write_input_tok)}" + ) if output_tok is not None: usage_lines.append(f"output_tokens: {format_numeric(output_tok)}") if final_inp is not None: diff --git a/verifiers/types.py b/verifiers/types.py index f0dc4ac55..28e54a2eb 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -170,6 +170,8 @@ class Usage(CustomBaseModel): reasoning_tokens: int completion_tokens: int total_tokens: int + cached_input_tokens: int | None = None + cache_write_input_tokens: int | None = None class ResponseTokens(CustomBaseModel): @@ -221,6 +223,8 @@ class TrajectoryStepTokens(TypedDict): class TokenUsage(TypedDict): input_tokens: float output_tokens: float + cached_input_tokens: NotRequired[float] + cache_write_input_tokens: NotRequired[float] final_input_tokens: NotRequired[float] final_output_tokens: NotRequired[float] @@ -555,6 +559,7 @@ class RolloutScores(TypedDict): "model": str, "api_client_type": NotRequired[ClientType], "extra_headers": NotRequired[dict[str, str]], + "prompt_cache": NotRequired[bool], }, ) Endpoints = dict[str, list[Endpoint]] @@ -601,6 +606,7 @@ class ClientConfig(BaseModel): 'e.g. {"X-Session-ID": "example_id"} adds a X-Session-ID header ' "with the value of state['example_id'].", ) + prompt_cache: bool = True @field_validator("extra_headers", mode="before") @classmethod @@ -661,6 +667,7 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = Field(default_factory=dict) + prompt_cache: bool = True @field_validator("extra_headers", mode="before") @classmethod diff --git a/verifiers/utils/eval_display.py b/verifiers/utils/eval_display.py index 4cf469087..1ba4f8dd1 100644 --- a/verifiers/utils/eval_display.py +++ b/verifiers/utils/eval_display.py @@ -391,6 +391,12 @@ def _make_tokens_row(self, usage: TokenUsage) -> Table: "input": format_numeric(usage.get("input_tokens", 0.0)), "output": format_numeric(usage.get("output_tokens", 0.0)), } + cached = usage.get("cached_input_tokens") + cache_write = usage.get("cache_write_input_tokens") + if cached is not None: + kv["cached_input"] = format_numeric(cached) + if cache_write is not None: + kv["cache_write_input"] = format_numeric(cache_write) inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") if inp is not None: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index d6e9d854e..7a7bda443 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -126,6 +126,12 @@ def _coerce_endpoint(raw_endpoint: object, source: str) -> Endpoint: if coerced_headers: endpoint["extra_headers"] = coerced_headers + raw_prompt_cache = raw_endpoint_dict.get("prompt_cache") + if raw_prompt_cache is not None: + if not isinstance(raw_prompt_cache, bool): + raise ValueError(f"Field 'prompt_cache' must be a boolean in {source}") + endpoint["prompt_cache"] = raw_prompt_cache + return endpoint @@ -451,6 +457,7 @@ def load_toml_config( "headers", "header_from_state", "headers_from_state", + "prompt_cache", # sampling "sampling_args", "max_tokens", @@ -712,6 +719,10 @@ def print_usage(results: GenerateOutputs): usage_count = 0 input_total = 0.0 output_total = 0.0 + cached_input_total = 0.0 + cache_write_input_total = 0.0 + cached_input_count = 0 + cache_write_input_count = 0 final_input_total = 0.0 final_output_total = 0.0 context_count = 0 @@ -722,6 +733,14 @@ def print_usage(results: GenerateOutputs): usage_count += 1 input_total += float(token_usage.get("input_tokens", 0.0)) output_total += float(token_usage.get("output_tokens", 0.0)) + cached = token_usage.get("cached_input_tokens") + if cached is not None: + cached_input_total += float(cached) + cached_input_count += 1 + cache_write = token_usage.get("cache_write_input_tokens") + if cache_write is not None: + cache_write_input_total += float(cache_write) + cache_write_input_count += 1 inp = token_usage.get("final_input_tokens") out = token_usage.get("final_output_tokens") if inp is not None and out is not None: @@ -735,6 +754,12 @@ def print_usage(results: GenerateOutputs): input_tokens=input_total / usage_count, output_tokens=output_total / usage_count, ) + if cached_input_count > 0: + usage["cached_input_tokens"] = cached_input_total / cached_input_count + if cache_write_input_count > 0: + usage["cache_write_input_tokens"] = ( + cache_write_input_total / cache_write_input_count + ) if context_count > 0: usage["final_input_tokens"] = final_input_total / context_count usage["final_output_tokens"] = final_output_total / context_count @@ -746,6 +771,12 @@ def print_usage(results: GenerateOutputs): print("Usage:") print(f"input_tokens (avg): {float(usage.get('input_tokens', 0.0)):.3f}") + cached = usage.get("cached_input_tokens") + if cached is not None: + print(f"cached_input_tokens (avg): {float(cached):.3f}") + cache_write = usage.get("cache_write_input_tokens") + if cache_write is not None: + print(f"cache_write_input_tokens (avg): {float(cache_write):.3f}") print(f"output_tokens (avg): {float(usage.get('output_tokens', 0.0)):.3f}") inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") diff --git a/verifiers/utils/interception_utils.py b/verifiers/utils/interception_utils.py index fbb94a32f..53e592c35 100644 --- a/verifiers/utils/interception_utils.py +++ b/verifiers/utils/interception_utils.py @@ -856,10 +856,21 @@ def serialize_anthropic_message_response(response: Response) -> dict[str, Any]: content.append({"type": "text", "text": ""}) usage = {} if response.usage is not None: + input_tokens = response.usage.prompt_tokens + if response.usage.cache_write_input_tokens is not None: + input_tokens = max( + 0, input_tokens - response.usage.cache_write_input_tokens + ) usage = { - "input_tokens": response.usage.prompt_tokens, + "input_tokens": input_tokens, "output_tokens": response.usage.completion_tokens, } + if response.usage.cached_input_tokens is not None: + usage["cache_read_input_tokens"] = response.usage.cached_input_tokens + if response.usage.cache_write_input_tokens is not None: + usage["cache_creation_input_tokens"] = ( + response.usage.cache_write_input_tokens + ) return { "id": response.id, "type": "message", diff --git a/verifiers/utils/metric_utils.py b/verifiers/utils/metric_utils.py index 66f204028..a7f57afac 100644 --- a/verifiers/utils/metric_utils.py +++ b/verifiers/utils/metric_utils.py @@ -92,6 +92,18 @@ class OutputTokensMetric(_TokenUsageKeyMetric): _key = "output_tokens" +class CachedInputTokensMetric(_TokenUsageKeyMetric): + """Mean cached_input_tokens per output.""" + + _key = "cached_input_tokens" + + +class CacheWriteInputTokensMetric(_TokenUsageKeyMetric): + """Mean cache_write_input_tokens per output.""" + + _key = "cache_write_input_tokens" + + class FinalInputTokensMetric(_TokenUsageKeyMetric): """Mean final_input_tokens (non-completion context tokens) per output.""" diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py new file mode 100644 index 000000000..e61c3d3a3 --- /dev/null +++ b/verifiers/utils/prompt_cache_utils.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Literal +from urllib.parse import urlsplit + +from verifiers.types import ClientConfig, ClientType + +PromptCacheMode = Literal[ + "disabled", + "implicit", + "anthropic_top_level", + "openrouter_anthropic_top_level", +] + +OPENAI_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset( + { + "openai_chat_completions", + "openai_responses", + } +) +ANTHROPIC_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset( + {"anthropic_messages"} +) +OPENROUTER_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset( + { + "openai_chat_completions", + "openai_responses", + } +) + + +@dataclass(frozen=True) +class EndpointIdentity: + client_type: ClientType + origin: str + host: str + path: str + + @classmethod + def from_config(cls, config: ClientConfig) -> "EndpointIdentity | None": + return cls.from_url(config.api_base_url, config.client_type) + + @classmethod + def from_url( + cls, api_base_url: str, client_type: ClientType + ) -> "EndpointIdentity | None": + parsed = urlsplit(api_base_url) + if not parsed.scheme or not parsed.hostname: + return None + scheme = parsed.scheme.lower() + host = parsed.hostname.lower() + port = parsed.port + netloc = host + if port is not None and not ( + (scheme == "https" and port == 443) or (scheme == "http" and port == 80) + ): + netloc = f"{host}:{port}" + return cls( + client_type=client_type, + origin=f"{scheme}://{netloc}", + host=host, + path=parsed.path or "", + ) + + +@dataclass(frozen=True) +class PromptCachePolicy: + mode: PromptCacheMode = "disabled" + prefire_groups: bool = False + + @property + def enabled(self) -> bool: + return self.mode != "disabled" + + +class PromptCacheAdapter: + prefire_groups = True + + def policy_for( + self, identity: EndpointIdentity, model: str + ) -> PromptCachePolicy: + _ = identity, model + return PromptCachePolicy(mode="implicit", prefire_groups=self.prefire_groups) + + +class AnthropicPromptCacheAdapter(PromptCacheAdapter): + def policy_for( + self, identity: EndpointIdentity, model: str + ) -> PromptCachePolicy: + _ = identity, model + return PromptCachePolicy( + mode="anthropic_top_level", prefire_groups=self.prefire_groups + ) + + +class OpenRouterPromptCacheAdapter(PromptCacheAdapter): + anthropic_model_prefixes = ("anthropic/",) + + def policy_for( + self, identity: EndpointIdentity, model: str + ) -> PromptCachePolicy: + _ = identity + if model.startswith(self.anthropic_model_prefixes): + return PromptCachePolicy( + mode="openrouter_anthropic_top_level", + prefire_groups=self.prefire_groups, + ) + return PromptCachePolicy(mode="implicit", prefire_groups=self.prefire_groups) + + +@dataclass(frozen=True) +class ProviderSpec: + provider_id: str + origins: frozenset[str] + client_types: frozenset[ClientType] + prompt_cache: PromptCacheAdapter + + def recognizes(self, identity: EndpointIdentity) -> bool: + return ( + identity.origin in self.origins + and identity.client_type in self.client_types + ) + + +PROVIDER_SPECS: tuple[ProviderSpec, ...] = ( + ProviderSpec( + provider_id="openai", + origins=frozenset({"https://api.openai.com"}), + client_types=OPENAI_CACHE_CLIENT_TYPES, + prompt_cache=PromptCacheAdapter(), + ), + ProviderSpec( + provider_id="anthropic", + origins=frozenset({"https://api.anthropic.com"}), + client_types=ANTHROPIC_CACHE_CLIENT_TYPES, + prompt_cache=AnthropicPromptCacheAdapter(), + ), + ProviderSpec( + provider_id="openrouter", + origins=frozenset({"https://openrouter.ai"}), + client_types=OPENROUTER_CACHE_CLIENT_TYPES, + prompt_cache=OpenRouterPromptCacheAdapter(), + ), +) + +DISABLED_PROMPT_CACHE_POLICY = PromptCachePolicy() + + +def infer_provider_spec(config: ClientConfig) -> ProviderSpec | None: + identity = EndpointIdentity.from_config(config) + if identity is None: + return None + for spec in PROVIDER_SPECS: + if spec.recognizes(identity): + return spec + return None + + +def resolve_prompt_cache_policy( + config: ClientConfig | None, model: str +) -> PromptCachePolicy: + if config is None or not config.prompt_cache: + return DISABLED_PROMPT_CACHE_POLICY + identity = EndpointIdentity.from_config(config) + if identity is None: + return DISABLED_PROMPT_CACHE_POLICY + spec = infer_provider_spec(config) + if spec is None: + return DISABLED_PROMPT_CACHE_POLICY + return spec.prompt_cache.policy_for(identity, model) + + +def should_prefire_prompt_cache_group( + client_or_config: object, model: str, group_size: int +) -> bool: + if group_size <= 1: + return False + config = client_or_config if isinstance(client_or_config, ClientConfig) else None + if config is None: + config = getattr(client_or_config, "config", None) + if not isinstance(config, ClientConfig): + return False + return resolve_prompt_cache_policy(config, model).prefire_groups + + +def _cache_control_payload() -> dict[str, str]: + return {"type": "ephemeral"} + + +def apply_prompt_cache_to_request( + *, + config: ClientConfig | None, + model: str, + native_prompt: object, + native_tools: object, + sampling_args: Mapping[str, Any], + extra_kwargs: Mapping[str, Any], +) -> tuple[object, object, dict[str, Any], dict[str, Any]]: + policy = resolve_prompt_cache_policy(config, model) + updated_sampling_args = dict(sampling_args) + updated_extra_kwargs = dict(extra_kwargs) + updated_native_prompt = native_prompt + if policy.mode == "anthropic_top_level": + updated_extra_kwargs.setdefault("cache_control", _cache_control_payload()) + elif policy.mode == "openrouter_anthropic_top_level": + extra_body = updated_sampling_args.get("extra_body") + if isinstance(extra_body, Mapping): + extra_body = dict(extra_body) + else: + extra_body = {} + extra_body.setdefault("cache_control", _cache_control_payload()) + updated_sampling_args["extra_body"] = extra_body + return ( + updated_native_prompt, + native_tools, + updated_sampling_args, + updated_extra_kwargs, + ) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index d9aa889e7..14f96866d 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -27,6 +27,8 @@ serialize_messages_for_output, ) from verifiers.utils.metric_utils import ( + CacheWriteInputTokensMetric, + CachedInputTokensMetric, EnvMetrics, ErrorRateMetric, FinalInputTokensMetric, @@ -41,6 +43,8 @@ StateUsageTracker, ) from verifiers.utils.usage_utils import ( + cast_token_usage, + extract_usage_token_details, extract_usage_tokens as extract_usage_tokens_from_response, ) from verifiers.utils.version_utils import get_version_info @@ -110,10 +114,19 @@ def _coerce_token_usage(value: object) -> TokenUsage | None: output_tokens = float(0.0 if output_raw is None else output_raw) except (TypeError, ValueError): return None - return { + usage: dict[str, float] = { "input_tokens": input_tokens, "output_tokens": output_tokens, } + for key in ("cached_input_tokens", "cache_write_input_tokens"): + raw_value = mapping_value.get(key) + if raw_value is None: + continue + try: + usage[key] = float(raw_value) + except (TypeError, ValueError): + continue + return cast_token_usage(usage) def _extract_state_token_usage(state: State) -> TokenUsage | None: @@ -188,28 +201,25 @@ def state_to_output( if usage is None: # Legacy fallback for states that do not use state-level usage tracking. trajectory = state.get("trajectory", []) - input_tokens = 0 - output_tokens = 0 + usage_totals: dict[str, float] = { + "input_tokens": 0.0, + "output_tokens": 0.0, + } usage_seen = False for step in trajectory: response = step.get("response") if response is None: continue - if getattr(response, "usage", None) is not None: - usage_seen = True - step_input_tokens, step_output_tokens = extract_usage_tokens(response) - input_tokens += step_input_tokens - output_tokens += step_output_tokens + details = extract_usage_token_details(response) + if details is None: + continue + usage_seen = True + for key, value in details.items(): + usage_totals[key] = usage_totals.get(key, 0.0) + float(value) if usage_seen: - usage = { - "input_tokens": float(input_tokens), - "output_tokens": float(output_tokens), - } + usage = cast_token_usage(usage_totals) if usage is not None: - token_usage: dict[str, float] = { - "input_tokens": usage.get("input_tokens", 0.0), - "output_tokens": usage.get("output_tokens", 0.0), - } + token_usage: dict[str, float] = dict(usage) # Add context token metrics from trajectory trajectory = state.get("trajectory", []) if trajectory: @@ -328,6 +338,8 @@ def __init__( self.env_metrics = EnvMetrics() self.input_tokens = InputTokensMetric() self.output_tokens = OutputTokensMetric() + self.cached_input_tokens = CachedInputTokensMetric() + self.cache_write_input_tokens = CacheWriteInputTokensMetric() self.final_input_tokens = FinalInputTokensMetric() self.final_output_tokens = FinalOutputTokensMetric() self.pass_at_k = PassAtKMetric(rollouts_per_example, threshold=pass_threshold) @@ -378,6 +390,8 @@ def add_outputs(self, new_outputs: list[RolloutOutput]) -> None: self.env_metrics.add_outputs(new_outputs) self.input_tokens.add_outputs(new_outputs) self.output_tokens.add_outputs(new_outputs) + self.cached_input_tokens.add_outputs(new_outputs) + self.cache_write_input_tokens.add_outputs(new_outputs) self.final_input_tokens.add_outputs(new_outputs) self.final_output_tokens.add_outputs(new_outputs) self.pass_at_k.add_outputs(new_outputs) @@ -402,6 +416,12 @@ def build_metadata(self) -> GenerateMetadata: input_tokens=self.input_tokens.compute(), output_tokens=self.output_tokens.compute(), ) + if self.cached_input_tokens.count > 0: + usage["cached_input_tokens"] = self.cached_input_tokens.compute() + if self.cache_write_input_tokens.count > 0: + usage["cache_write_input_tokens"] = ( + self.cache_write_input_tokens.compute() + ) if self.final_input_tokens.count > 0: usage["final_input_tokens"] = self.final_input_tokens.compute() usage["final_output_tokens"] = self.final_output_tokens.compute() diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index ea3f57db6..fb6424f8e 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -2,6 +2,8 @@ import math from types import MappingProxyType +from typing import Any + from verifiers.types import TokenUsage @@ -11,6 +13,30 @@ def _get_usage_value(usage_obj: object, key: str) -> int | float: return getattr(usage_obj, key, 0) +def _get_optional_usage_value(usage_obj: object, key: str) -> object: + if isinstance(usage_obj, Mapping): + return usage_obj.get(key) + return getattr(usage_obj, key, None) + + +def _get_nested_usage_value(usage_obj: object, key: str) -> object: + value = _get_optional_usage_value(usage_obj, key) + if value is not None: + return value + details = _get_optional_usage_value(usage_obj, "prompt_tokens_details") + if isinstance(details, Mapping): + return details.get(key) + if details is not None: + return getattr(details, key, None) + return None + + +def _get_response_usage(response: object) -> object: + if isinstance(response, Mapping): + return response.get("usage") + return getattr(response, "usage", None) + + def _coerce_usage_int(value: object) -> int: """Best-effort usage coercion. Invalid values degrade to zero.""" if value is None: @@ -40,17 +66,59 @@ def _coerce_usage_int(value: object) -> int: return 0 -def extract_usage_tokens(response: object) -> tuple[int, int]: - usage = getattr(response, "usage", None) +def extract_usage_token_details(response: object) -> dict[str, int] | None: + usage = _get_response_usage(response) if usage is None: - return 0, 0 + return None prompt_tokens = _get_usage_value(usage, "prompt_tokens") completion_tokens = _get_usage_value(usage, "completion_tokens") if not prompt_tokens and not completion_tokens: prompt_tokens = _get_usage_value(usage, "input_tokens") completion_tokens = _get_usage_value(usage, "output_tokens") - return _coerce_usage_int(prompt_tokens), _coerce_usage_int(completion_tokens) + details = { + "input_tokens": _coerce_usage_int(prompt_tokens), + "output_tokens": _coerce_usage_int(completion_tokens), + } + + subtract_cached_from_input = False + cached_input_tokens = _get_optional_usage_value(usage, "cached_input_tokens") + if cached_input_tokens is None: + cached_input_tokens = _get_optional_usage_value(usage, "cache_read_input_tokens") + if cached_input_tokens is None: + cached_input_tokens = _get_nested_usage_value(usage, "cached_tokens") + subtract_cached_from_input = cached_input_tokens is not None + if cached_input_tokens is not None: + cached_int = _coerce_usage_int(cached_input_tokens) + details["cached_input_tokens"] = cached_int + if subtract_cached_from_input: + details["input_tokens"] = max(0, details["input_tokens"] - cached_int) + + cache_write_input_tokens = _get_optional_usage_value( + usage, "cache_write_input_tokens" + ) + add_cache_write_to_input = False + if cache_write_input_tokens is None: + cache_write_input_tokens = _get_optional_usage_value( + usage, "cache_creation_input_tokens" + ) + add_cache_write_to_input = cache_write_input_tokens is not None + if cache_write_input_tokens is None: + cache_write_input_tokens = _get_nested_usage_value(usage, "cache_write_tokens") + if cache_write_input_tokens is not None: + cache_write_int = _coerce_usage_int(cache_write_input_tokens) + details["cache_write_input_tokens"] = cache_write_int + if add_cache_write_to_input: + details["input_tokens"] += cache_write_int + + return details + + +def extract_usage_tokens(response: object) -> tuple[int, int]: + details = extract_usage_token_details(response) + if details is None: + return 0, 0 + return details["input_tokens"], details["output_tokens"] class StateUsageTracker: @@ -75,30 +143,57 @@ def increment( input_tokens: int | float = 0, output_tokens: int | float = 0, *, + cached_input_tokens: int | float | None = None, + cache_write_input_tokens: int | float | None = None, mark_seen: bool = True, ) -> None: - input_delta = float(input_tokens or 0.0) - output_delta = float(output_tokens or 0.0) - if input_delta < 0 or output_delta < 0: + deltas: dict[str, float] = { + "input_tokens": float(input_tokens or 0.0), + "output_tokens": float(output_tokens or 0.0), + } + if cached_input_tokens is not None: + deltas["cached_input_tokens"] = float(cached_input_tokens or 0.0) + if cache_write_input_tokens is not None: + deltas["cache_write_input_tokens"] = float( + cache_write_input_tokens or 0.0 + ) + if any(delta < 0 for delta in deltas.values()): raise ValueError("Token usage increments must be non-negative.") if mark_seen: self._usage_seen = True - self._usage_totals["input_tokens"] += input_delta - self._usage_totals["output_tokens"] += output_delta + for key, delta in deltas.items(): + self._usage_totals[key] = self._usage_totals.get(key, 0.0) + delta def increment_from_response(self, response: object) -> None: - if getattr(response, "usage", None) is None: + if _get_response_usage(response) is None: + return + details = extract_usage_token_details(response) + if details is None: return - input_tokens, output_tokens = extract_usage_tokens(response) - self.increment(input_tokens, output_tokens, mark_seen=True) + self.increment( + details["input_tokens"], + details["output_tokens"], + cached_input_tokens=details.get("cached_input_tokens"), + cache_write_input_tokens=details.get("cache_write_input_tokens"), + mark_seen=True, + ) def snapshot(self) -> TokenUsage | None: if not self._usage_seen: return None - return { - "input_tokens": self._usage_totals["input_tokens"], - "output_tokens": self._usage_totals["output_tokens"], - } + return cast_token_usage(self._usage_totals) + + +def cast_token_usage(usage: Mapping[str, Any]) -> TokenUsage: + out: TokenUsage = { + "input_tokens": float(usage.get("input_tokens", 0.0)), + "output_tokens": float(usage.get("output_tokens", 0.0)), + } + for key in ("cached_input_tokens", "cache_write_input_tokens"): + value = usage.get(key) + if value is not None: + out[key] = float(value) + return out def compute_context_token_metrics( @@ -128,9 +223,15 @@ def compute_context_token_metrics( found = False for step in reversed(trajectory): response = step.get("response") - if response is None or getattr(response, "usage", None) is None: + if response is None or _get_response_usage(response) is None: + continue + details = extract_usage_token_details(response) + if details is None: continue - prompt_tokens, completion_tokens = extract_usage_tokens(response) + prompt_tokens = details["input_tokens"] + details.get( + "cached_input_tokens", 0 + ) + completion_tokens = details["output_tokens"] last_step_total = prompt_tokens + completion_tokens found = True break @@ -142,10 +243,11 @@ def compute_context_token_metrics( total_completion = 0 for step in trajectory: response = step.get("response") - if response is None or getattr(response, "usage", None) is None: + if response is None or _get_response_usage(response) is None: continue - _, completion_tokens = extract_usage_tokens(response) - total_completion += completion_tokens + details = extract_usage_token_details(response) + if details is not None: + total_completion += details["output_tokens"] return { "final_output_tokens": total_completion, diff --git a/verifiers/v1/env.py b/verifiers/v1/env.py index f3753221d..2c40fce99 100644 --- a/verifiers/v1/env.py +++ b/verifiers/v1/env.py @@ -9,6 +9,7 @@ from verifiers.clients import Client from verifiers.types import ClientConfig from verifiers.types import RolloutInput, SamplingArgs +from verifiers.utils.prompt_cache_utils import should_prefire_prompt_cache_group from .harness import Harness from .state import State @@ -102,9 +103,24 @@ async def _run_group_states( "score_rollout": self.score_rollouts, }, ) - states = await asyncio.gather( - *[self.harness.run(task, state) for task, state in zip(tasks, states)] - ) + task_state_pairs = list(zip(tasks, states)) + if should_prefire_prompt_cache_group(client, model, len(task_state_pairs)): + first_task, first_state = task_state_pairs[0] + first_result = await self.harness.run(first_task, first_state) + remaining_results = await asyncio.gather( + *[ + self.harness.run(task, state) + for task, state in task_state_pairs[1:] + ] + ) + states = [first_result, *remaining_results] + else: + states = await asyncio.gather( + *[ + self.harness.run(task, state) + for task, state in task_state_pairs + ] + ) try: if self.score_rollouts: await self.harness.score_group(tasks, states) From 7350782069a97ebe9ebd0d13e392fa69f8a2ff57 Mon Sep 17 00:00:00 2001 From: William Brown Date: Sun, 17 May 2026 17:27:30 -0500 Subject: [PATCH 2/9] Drop cache write token exports --- docs/reference.md | 2 - tests/test_prompt_cache_utils.py | 115 +++--------------- .../clients/anthropic_messages_client.py | 9 +- .../clients/openai_chat_completions_client.py | 11 +- .../clients/openai_completions_client.py | 5 +- verifiers/clients/openai_responses_client.py | 5 +- verifiers/envs/environment.py | 2 +- verifiers/scripts/tui.py | 14 --- verifiers/types.py | 2 - verifiers/utils/eval_display.py | 3 - verifiers/utils/eval_utils.py | 13 -- verifiers/utils/interception_utils.py | 11 +- verifiers/utils/metric_utils.py | 6 - verifiers/utils/prompt_cache_utils.py | 29 +---- verifiers/utils/save_utils.py | 9 +- verifiers/utils/usage_utils.py | 27 +--- verifiers/v1/env.py | 25 +--- 17 files changed, 46 insertions(+), 242 deletions(-) diff --git a/docs/reference.md b/docs/reference.md index 7f0c36438..89a705273 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -244,7 +244,6 @@ class TokenUsage(TypedDict, total=False): input_tokens: float output_tokens: float cached_input_tokens: float - cache_write_input_tokens: float final_input_tokens: float final_output_tokens: float ``` @@ -254,7 +253,6 @@ class TokenUsage(TypedDict, total=False): | `input_tokens` | Sum of non-cache-hit prompt tokens across all turns. Shared uncached context is counted each time it appears in a prompt. | | `output_tokens` | Sum of completion tokens across all turns. | | `cached_input_tokens` | Sum of prompt tokens served from provider prompt cache, when reported by the provider. | -| `cache_write_input_tokens` | Sum of prompt tokens written to provider prompt cache, when reported by the provider. | | `final_input_tokens` | Non-completion tokens in the final turn's context (system prompts, user messages, tool results, etc.). | | `final_output_tokens` | Completion tokens in the final turn's context. Equals `output_tokens` for single-turn rollouts. | diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index 07ee505c7..d18ad9b94 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -9,12 +9,11 @@ from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient -from verifiers.types import ClientConfig, Response, ResponseMessage, Usage +from verifiers.types import ClientConfig, Response, ResponseMessage from verifiers.utils.prompt_cache_utils import ( EndpointIdentity, apply_prompt_cache_to_request, resolve_prompt_cache_policy, - should_prefire_prompt_cache_group, ) from verifiers.utils.save_utils import state_to_output from verifiers.utils.usage_utils import extract_usage_token_details @@ -65,35 +64,30 @@ async def close(self) -> None: pass -class GroupOrderClient(Client): +class ConcurrentStartClient(Client): def __init__(self, config: ClientConfig): super().__init__(config) self.first_active = False - self.started_during_first: list[int] = [] + self.started_any = False + self.overlapped_first = False def setup_client(self, config): return object() async def get_response(self, prompt, model, sampling_args, tools=None, **kwargs): - _ = prompt, model, sampling_args, tools - state = kwargs["state"] - rollout_index = int(state["rollout_index"]) - if rollout_index == 0: + _ = prompt, model, sampling_args, tools, kwargs + if self.first_active: + self.overlapped_first = True + elif not self.started_any: + self.started_any = True self.first_active = True await asyncio.sleep(0.01) self.first_active = False - elif self.first_active: - self.started_during_first.append(rollout_index) return Response( - id=f"resp-{rollout_index}", + id="resp", created=0, model="model", - usage=Usage( - prompt_tokens=1, - reasoning_tokens=0, - completion_tokens=1, - total_tokens=2, - ), + usage=None, message=ResponseMessage( content="ok", finish_reason="stop", @@ -120,14 +114,6 @@ async def close(self) -> None: pass -class IndexedTaskset(vf.Taskset): - async def init_group(self, task, num_rollouts): - tasks, states = await super().init_group(task, num_rollouts) - for index, state in enumerate(states): - state["rollout_index"] = index - return tasks, states - - def test_endpoint_identity_normalizes_official_origins(): identity = EndpointIdentity.from_url( "https://api.openai.com/v1", "openai_chat_completions" @@ -193,7 +179,7 @@ def test_prompt_cache_false_disables_inferred_provider_policy(): ) assert policy.mode == "disabled" - assert not policy.prefire_groups + assert not policy.enabled def test_anthropic_request_policy_adds_top_level_cache_control(): @@ -262,42 +248,6 @@ def test_openai_policy_does_not_mutate_request(): assert extra_kwargs == {} -def test_group_prefire_is_tied_to_cache_policy(): - assert should_prefire_prompt_cache_group( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - ), - "gpt-5.4-mini", - 2, - ) - assert not should_prefire_prompt_cache_group( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - prompt_cache=False, - ), - "gpt-5.4-mini", - 2, - ) - assert not should_prefire_prompt_cache_group( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.example.com/v1", - ), - "model", - 2, - ) - assert not should_prefire_prompt_cache_group( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - ), - "gpt-5.4-mini", - 1, - ) - - @pytest.mark.asyncio async def test_client_request_hook_applies_prompt_cache_policy(): client = RecordingClient( @@ -317,15 +267,15 @@ async def test_client_request_hook_applies_prompt_cache_policy(): @pytest.mark.asyncio -async def test_v1_group_prefire_serializes_first_rollout_for_cached_provider(): - client = GroupOrderClient( +async def test_v1_group_rollouts_start_concurrently_for_cached_provider(): + client = ConcurrentStartClient( ClientConfig( client_type="openai_chat_completions", api_base_url="https://api.openai.com/v1", ) ) env = vf.Env( - taskset=IndexedTaskset(source=[{"question": "q"}]), + taskset=vf.Taskset(source=[{"question": "q"}]), harness=vf.Harness(max_turns=1), ) @@ -340,34 +290,7 @@ async def test_v1_group_prefire_serializes_first_rollout_for_cached_provider(): {}, ) - assert client.started_during_first == [] - - -@pytest.mark.asyncio -async def test_v1_group_prefire_is_skipped_for_generic_provider(): - client = GroupOrderClient( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.example.com/v1", - ) - ) - env = vf.Env( - taskset=IndexedTaskset(source=[{"question": "q"}]), - harness=vf.Harness(max_turns=1), - ) - - await env._run_group_states( - [ - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - ], - client, - "model", - {}, - ) - - assert client.started_during_first + assert client.overlapped_first @pytest.mark.asyncio @@ -404,7 +327,6 @@ async def test_openai_usage_splits_cached_input_tokens(): assert response.usage is not None assert response.usage.prompt_tokens == 20 assert response.usage.cached_input_tokens == 80 - assert response.usage.cache_write_input_tokens == 10 assert response.usage.total_tokens == 25 @@ -429,11 +351,10 @@ async def test_anthropic_usage_splits_cache_read_and_write_tokens(): assert response.usage is not None assert response.usage.prompt_tokens == 15 assert response.usage.cached_input_tokens == 80 - assert response.usage.cache_write_input_tokens == 10 assert response.usage.total_tokens == 22 -def test_native_anthropic_usage_counts_cache_writes_as_uncached_input(): +def test_native_anthropic_cache_creation_counts_as_uncached_input(): response = SimpleNamespace( usage=SimpleNamespace( input_tokens=5, @@ -447,7 +368,6 @@ def test_native_anthropic_usage_counts_cache_writes_as_uncached_input(): "input_tokens": 15, "output_tokens": 7, "cached_input_tokens": 80, - "cache_write_input_tokens": 10, } @@ -467,7 +387,6 @@ def test_serialized_response_usage_counts_cache_details(): "input_tokens": 20, "output_tokens": 7, "cached_input_tokens": 80, - "cache_write_input_tokens": 10, } diff --git a/verifiers/clients/anthropic_messages_client.py b/verifiers/clients/anthropic_messages_client.py index 5e25e53a3..709b0d9b3 100644 --- a/verifiers/clients/anthropic_messages_client.py +++ b/verifiers/clients/anthropic_messages_client.py @@ -448,13 +448,11 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: input_tokens = response.usage.input_tokens output_tokens = response.usage.output_tokens cached_input_tokens = getattr(response.usage, "cache_read_input_tokens", None) - cache_write_input_tokens = getattr( + cache_creation_input_tokens = getattr( response.usage, "cache_creation_input_tokens", None ) - if isinstance(cache_write_input_tokens, int): - input_tokens += cache_write_input_tokens - else: - cache_write_input_tokens = None + if isinstance(cache_creation_input_tokens, int): + input_tokens += cache_creation_input_tokens if not isinstance(cached_input_tokens, int): cached_input_tokens = None @@ -468,7 +466,6 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: reasoning_tokens=0, total_tokens=input_tokens + output_tokens, cached_input_tokens=cached_input_tokens, - cache_write_input_tokens=cache_write_input_tokens, ), message=ResponseMessage( content=content, diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index ea08822a5..494613e07 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -106,15 +106,13 @@ def get_usage_int_field(usage: Any, key: str) -> int | None: return None -def get_prompt_cache_token_fields(usage: Any) -> tuple[int | None, int | None]: +def get_cached_prompt_tokens(usage: Any) -> int | None: details = get_usage_field(usage, "prompt_tokens_details") if details is None: details = get_usage_field(usage, "input_tokens_details") if details is None: - return None, None - cached_tokens = get_usage_int_field(details, "cached_tokens") - cache_write_tokens = get_usage_int_field(details, "cache_write_tokens") - return cached_tokens, cache_write_tokens + return None + return get_usage_int_field(details, "cached_tokens") def content_to_text(content: Any) -> str: @@ -415,7 +413,7 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: completion_tokens, int ): return None - cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): @@ -428,7 +426,6 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: completion_tokens=completion_tokens, total_tokens=total_tokens, cached_input_tokens=cached_tokens, - cache_write_input_tokens=cache_write_tokens, ) def parse_is_truncated(response: OpenAIChatResponse) -> bool: diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index 7e1aa3712..f32e1d7c3 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -6,7 +6,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, - get_prompt_cache_token_fields, + get_cached_prompt_tokens, get_usage_field, handle_openai_overlong_prompt, ) @@ -126,7 +126,7 @@ def parse_usage(response: OpenAITextResponse) -> Usage | None: completion_tokens, int ): return None - cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): @@ -139,7 +139,6 @@ def parse_usage(response: OpenAITextResponse) -> Usage | None: completion_tokens=completion_tokens, total_tokens=total_tokens, cached_input_tokens=cached_tokens, - cache_write_input_tokens=cache_write_tokens, ) def parse_finish_reason(response: OpenAITextResponse) -> FinishReason: diff --git a/verifiers/clients/openai_responses_client.py b/verifiers/clients/openai_responses_client.py index 28672557d..7c10927d6 100644 --- a/verifiers/clients/openai_responses_client.py +++ b/verifiers/clients/openai_responses_client.py @@ -10,7 +10,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, - get_prompt_cache_token_fields, + get_cached_prompt_tokens, get_usage_field, handle_openai_overlong_prompt, ) @@ -388,7 +388,7 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: completion_tokens, int ): return None - cached_tokens, cache_write_tokens = get_prompt_cache_token_fields(usage) + cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) if not isinstance(total_tokens, int): @@ -403,7 +403,6 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: completion_tokens=completion_tokens, total_tokens=total_tokens, cached_input_tokens=cached_tokens, - cache_write_input_tokens=cache_write_tokens, ) def parse_is_truncated(response: OpenAIResponsesNativeResponse) -> bool: diff --git a/verifiers/envs/environment.py b/verifiers/envs/environment.py index e51ca4a17..86263cb6e 100644 --- a/verifiers/envs/environment.py +++ b/verifiers/envs/environment.py @@ -485,7 +485,7 @@ def get_state_usage(self, state: State) -> TokenUsage | None: "input_tokens": float(usage.get("input_tokens", 0.0)), "output_tokens": float(usage.get("output_tokens", 0.0)), } - for key in ("cached_input_tokens", "cache_write_input_tokens"): + for key in ("cached_input_tokens",): value = usage.get(key) if value is not None: out[key] = float(value) diff --git a/verifiers/scripts/tui.py b/verifiers/scripts/tui.py index fa53da9df..ac9eb7a81 100755 --- a/verifiers/scripts/tui.py +++ b/verifiers/scripts/tui.py @@ -3258,20 +3258,12 @@ def _build_header_summary_text(self) -> Text: input_tok = usage.get("input_tokens") output_tok = usage.get("output_tokens") cached_input_tok = usage.get("cached_input_tokens") - cache_write_input_tok = usage.get("cache_write_input_tokens") if input_tok is not None: usage_items.append(("Avg input tokens", format_numeric(input_tok))) if cached_input_tok is not None: usage_items.append( ("Avg cached input tokens", format_numeric(cached_input_tok)) ) - if cache_write_input_tok is not None: - usage_items.append( - ( - "Avg cache write input tokens", - format_numeric(cache_write_input_tok), - ) - ) if output_tok is not None: usage_items.append(("Avg output tokens", format_numeric(output_tok))) max_tokens = sampling_args.get("max_tokens") @@ -4742,7 +4734,6 @@ def _build_usage_text(self, record: Dict[str, Any]) -> Text: input_tok = token_usage.get("input_tokens") output_tok = token_usage.get("output_tokens") cached_input_tok = token_usage.get("cached_input_tokens") - cache_write_input_tok = token_usage.get("cache_write_input_tokens") final_inp = token_usage.get("final_input_tokens") final_outp = token_usage.get("final_output_tokens") if input_tok is not None: @@ -4751,11 +4742,6 @@ def _build_usage_text(self, record: Dict[str, Any]) -> Text: usage_lines.append( f"cached_input_tokens: {format_numeric(cached_input_tok)}" ) - if cache_write_input_tok is not None: - usage_lines.append( - "cache_write_input_tokens: " - f"{format_numeric(cache_write_input_tok)}" - ) if output_tok is not None: usage_lines.append(f"output_tokens: {format_numeric(output_tok)}") if final_inp is not None: diff --git a/verifiers/types.py b/verifiers/types.py index 28e54a2eb..a2cbece39 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -171,7 +171,6 @@ class Usage(CustomBaseModel): completion_tokens: int total_tokens: int cached_input_tokens: int | None = None - cache_write_input_tokens: int | None = None class ResponseTokens(CustomBaseModel): @@ -224,7 +223,6 @@ class TokenUsage(TypedDict): input_tokens: float output_tokens: float cached_input_tokens: NotRequired[float] - cache_write_input_tokens: NotRequired[float] final_input_tokens: NotRequired[float] final_output_tokens: NotRequired[float] diff --git a/verifiers/utils/eval_display.py b/verifiers/utils/eval_display.py index 1ba4f8dd1..748fb5969 100644 --- a/verifiers/utils/eval_display.py +++ b/verifiers/utils/eval_display.py @@ -392,11 +392,8 @@ def _make_tokens_row(self, usage: TokenUsage) -> Table: "output": format_numeric(usage.get("output_tokens", 0.0)), } cached = usage.get("cached_input_tokens") - cache_write = usage.get("cache_write_input_tokens") if cached is not None: kv["cached_input"] = format_numeric(cached) - if cache_write is not None: - kv["cache_write_input"] = format_numeric(cache_write) inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") if inp is not None: diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index 7a7bda443..837baec56 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -720,9 +720,7 @@ def print_usage(results: GenerateOutputs): input_total = 0.0 output_total = 0.0 cached_input_total = 0.0 - cache_write_input_total = 0.0 cached_input_count = 0 - cache_write_input_count = 0 final_input_total = 0.0 final_output_total = 0.0 context_count = 0 @@ -737,10 +735,6 @@ def print_usage(results: GenerateOutputs): if cached is not None: cached_input_total += float(cached) cached_input_count += 1 - cache_write = token_usage.get("cache_write_input_tokens") - if cache_write is not None: - cache_write_input_total += float(cache_write) - cache_write_input_count += 1 inp = token_usage.get("final_input_tokens") out = token_usage.get("final_output_tokens") if inp is not None and out is not None: @@ -756,10 +750,6 @@ def print_usage(results: GenerateOutputs): ) if cached_input_count > 0: usage["cached_input_tokens"] = cached_input_total / cached_input_count - if cache_write_input_count > 0: - usage["cache_write_input_tokens"] = ( - cache_write_input_total / cache_write_input_count - ) if context_count > 0: usage["final_input_tokens"] = final_input_total / context_count usage["final_output_tokens"] = final_output_total / context_count @@ -774,9 +764,6 @@ def print_usage(results: GenerateOutputs): cached = usage.get("cached_input_tokens") if cached is not None: print(f"cached_input_tokens (avg): {float(cached):.3f}") - cache_write = usage.get("cache_write_input_tokens") - if cache_write is not None: - print(f"cache_write_input_tokens (avg): {float(cache_write):.3f}") print(f"output_tokens (avg): {float(usage.get('output_tokens', 0.0)):.3f}") inp = usage.get("final_input_tokens") out = usage.get("final_output_tokens") diff --git a/verifiers/utils/interception_utils.py b/verifiers/utils/interception_utils.py index 53e592c35..23abe95fa 100644 --- a/verifiers/utils/interception_utils.py +++ b/verifiers/utils/interception_utils.py @@ -856,21 +856,12 @@ def serialize_anthropic_message_response(response: Response) -> dict[str, Any]: content.append({"type": "text", "text": ""}) usage = {} if response.usage is not None: - input_tokens = response.usage.prompt_tokens - if response.usage.cache_write_input_tokens is not None: - input_tokens = max( - 0, input_tokens - response.usage.cache_write_input_tokens - ) usage = { - "input_tokens": input_tokens, + "input_tokens": response.usage.prompt_tokens, "output_tokens": response.usage.completion_tokens, } if response.usage.cached_input_tokens is not None: usage["cache_read_input_tokens"] = response.usage.cached_input_tokens - if response.usage.cache_write_input_tokens is not None: - usage["cache_creation_input_tokens"] = ( - response.usage.cache_write_input_tokens - ) return { "id": response.id, "type": "message", diff --git a/verifiers/utils/metric_utils.py b/verifiers/utils/metric_utils.py index a7f57afac..4030865bd 100644 --- a/verifiers/utils/metric_utils.py +++ b/verifiers/utils/metric_utils.py @@ -98,12 +98,6 @@ class CachedInputTokensMetric(_TokenUsageKeyMetric): _key = "cached_input_tokens" -class CacheWriteInputTokensMetric(_TokenUsageKeyMetric): - """Mean cache_write_input_tokens per output.""" - - _key = "cache_write_input_tokens" - - class FinalInputTokensMetric(_TokenUsageKeyMetric): """Mean final_input_tokens (non-completion context tokens) per output.""" diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py index e61c3d3a3..bba26d360 100644 --- a/verifiers/utils/prompt_cache_utils.py +++ b/verifiers/utils/prompt_cache_utils.py @@ -68,7 +68,6 @@ def from_url( @dataclass(frozen=True) class PromptCachePolicy: mode: PromptCacheMode = "disabled" - prefire_groups: bool = False @property def enabled(self) -> bool: @@ -76,13 +75,11 @@ def enabled(self) -> bool: class PromptCacheAdapter: - prefire_groups = True - def policy_for( self, identity: EndpointIdentity, model: str ) -> PromptCachePolicy: _ = identity, model - return PromptCachePolicy(mode="implicit", prefire_groups=self.prefire_groups) + return PromptCachePolicy(mode="implicit") class AnthropicPromptCacheAdapter(PromptCacheAdapter): @@ -90,9 +87,7 @@ def policy_for( self, identity: EndpointIdentity, model: str ) -> PromptCachePolicy: _ = identity, model - return PromptCachePolicy( - mode="anthropic_top_level", prefire_groups=self.prefire_groups - ) + return PromptCachePolicy(mode="anthropic_top_level") class OpenRouterPromptCacheAdapter(PromptCacheAdapter): @@ -103,11 +98,8 @@ def policy_for( ) -> PromptCachePolicy: _ = identity if model.startswith(self.anthropic_model_prefixes): - return PromptCachePolicy( - mode="openrouter_anthropic_top_level", - prefire_groups=self.prefire_groups, - ) - return PromptCachePolicy(mode="implicit", prefire_groups=self.prefire_groups) + return PromptCachePolicy(mode="openrouter_anthropic_top_level") + return PromptCachePolicy(mode="implicit") @dataclass(frozen=True) @@ -172,19 +164,6 @@ def resolve_prompt_cache_policy( return spec.prompt_cache.policy_for(identity, model) -def should_prefire_prompt_cache_group( - client_or_config: object, model: str, group_size: int -) -> bool: - if group_size <= 1: - return False - config = client_or_config if isinstance(client_or_config, ClientConfig) else None - if config is None: - config = getattr(client_or_config, "config", None) - if not isinstance(config, ClientConfig): - return False - return resolve_prompt_cache_policy(config, model).prefire_groups - - def _cache_control_payload() -> dict[str, str]: return {"type": "ephemeral"} diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 14f96866d..5e78c8fe8 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -27,7 +27,6 @@ serialize_messages_for_output, ) from verifiers.utils.metric_utils import ( - CacheWriteInputTokensMetric, CachedInputTokensMetric, EnvMetrics, ErrorRateMetric, @@ -118,7 +117,7 @@ def _coerce_token_usage(value: object) -> TokenUsage | None: "input_tokens": input_tokens, "output_tokens": output_tokens, } - for key in ("cached_input_tokens", "cache_write_input_tokens"): + for key in ("cached_input_tokens",): raw_value = mapping_value.get(key) if raw_value is None: continue @@ -339,7 +338,6 @@ def __init__( self.input_tokens = InputTokensMetric() self.output_tokens = OutputTokensMetric() self.cached_input_tokens = CachedInputTokensMetric() - self.cache_write_input_tokens = CacheWriteInputTokensMetric() self.final_input_tokens = FinalInputTokensMetric() self.final_output_tokens = FinalOutputTokensMetric() self.pass_at_k = PassAtKMetric(rollouts_per_example, threshold=pass_threshold) @@ -391,7 +389,6 @@ def add_outputs(self, new_outputs: list[RolloutOutput]) -> None: self.input_tokens.add_outputs(new_outputs) self.output_tokens.add_outputs(new_outputs) self.cached_input_tokens.add_outputs(new_outputs) - self.cache_write_input_tokens.add_outputs(new_outputs) self.final_input_tokens.add_outputs(new_outputs) self.final_output_tokens.add_outputs(new_outputs) self.pass_at_k.add_outputs(new_outputs) @@ -418,10 +415,6 @@ def build_metadata(self) -> GenerateMetadata: ) if self.cached_input_tokens.count > 0: usage["cached_input_tokens"] = self.cached_input_tokens.compute() - if self.cache_write_input_tokens.count > 0: - usage["cache_write_input_tokens"] = ( - self.cache_write_input_tokens.compute() - ) if self.final_input_tokens.count > 0: usage["final_input_tokens"] = self.final_input_tokens.compute() usage["final_output_tokens"] = self.final_output_tokens.compute() diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index fb6424f8e..a514b1953 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -94,22 +94,11 @@ def extract_usage_token_details(response: object) -> dict[str, int] | None: if subtract_cached_from_input: details["input_tokens"] = max(0, details["input_tokens"] - cached_int) - cache_write_input_tokens = _get_optional_usage_value( - usage, "cache_write_input_tokens" + cache_creation_input_tokens = _get_optional_usage_value( + usage, "cache_creation_input_tokens" ) - add_cache_write_to_input = False - if cache_write_input_tokens is None: - cache_write_input_tokens = _get_optional_usage_value( - usage, "cache_creation_input_tokens" - ) - add_cache_write_to_input = cache_write_input_tokens is not None - if cache_write_input_tokens is None: - cache_write_input_tokens = _get_nested_usage_value(usage, "cache_write_tokens") - if cache_write_input_tokens is not None: - cache_write_int = _coerce_usage_int(cache_write_input_tokens) - details["cache_write_input_tokens"] = cache_write_int - if add_cache_write_to_input: - details["input_tokens"] += cache_write_int + if cache_creation_input_tokens is not None: + details["input_tokens"] += _coerce_usage_int(cache_creation_input_tokens) return details @@ -144,7 +133,6 @@ def increment( output_tokens: int | float = 0, *, cached_input_tokens: int | float | None = None, - cache_write_input_tokens: int | float | None = None, mark_seen: bool = True, ) -> None: deltas: dict[str, float] = { @@ -153,10 +141,6 @@ def increment( } if cached_input_tokens is not None: deltas["cached_input_tokens"] = float(cached_input_tokens or 0.0) - if cache_write_input_tokens is not None: - deltas["cache_write_input_tokens"] = float( - cache_write_input_tokens or 0.0 - ) if any(delta < 0 for delta in deltas.values()): raise ValueError("Token usage increments must be non-negative.") if mark_seen: @@ -174,7 +158,6 @@ def increment_from_response(self, response: object) -> None: details["input_tokens"], details["output_tokens"], cached_input_tokens=details.get("cached_input_tokens"), - cache_write_input_tokens=details.get("cache_write_input_tokens"), mark_seen=True, ) @@ -189,7 +172,7 @@ def cast_token_usage(usage: Mapping[str, Any]) -> TokenUsage: "input_tokens": float(usage.get("input_tokens", 0.0)), "output_tokens": float(usage.get("output_tokens", 0.0)), } - for key in ("cached_input_tokens", "cache_write_input_tokens"): + for key in ("cached_input_tokens",): value = usage.get(key) if value is not None: out[key] = float(value) diff --git a/verifiers/v1/env.py b/verifiers/v1/env.py index 2c40fce99..925fc8407 100644 --- a/verifiers/v1/env.py +++ b/verifiers/v1/env.py @@ -9,7 +9,6 @@ from verifiers.clients import Client from verifiers.types import ClientConfig from verifiers.types import RolloutInput, SamplingArgs -from verifiers.utils.prompt_cache_utils import should_prefire_prompt_cache_group from .harness import Harness from .state import State @@ -103,24 +102,12 @@ async def _run_group_states( "score_rollout": self.score_rollouts, }, ) - task_state_pairs = list(zip(tasks, states)) - if should_prefire_prompt_cache_group(client, model, len(task_state_pairs)): - first_task, first_state = task_state_pairs[0] - first_result = await self.harness.run(first_task, first_state) - remaining_results = await asyncio.gather( - *[ - self.harness.run(task, state) - for task, state in task_state_pairs[1:] - ] - ) - states = [first_result, *remaining_results] - else: - states = await asyncio.gather( - *[ - self.harness.run(task, state) - for task, state in task_state_pairs - ] - ) + states = await asyncio.gather( + *[ + self.harness.run(task, state) + for task, state in zip(tasks, states) + ] + ) try: if self.score_rollouts: await self.harness.score_group(tasks, states) From 10e00303c64249a71b4e9df592868d592b53766b Mon Sep 17 00:00:00 2001 From: William Brown Date: Mon, 18 May 2026 16:34:42 -0700 Subject: [PATCH 3/9] Fix prompt cache type checks after main merge --- verifiers/utils/prompt_cache_utils.py | 11 +++++++---- verifiers/utils/save_utils.py | 9 ++++++++- verifiers/utils/usage_utils.py | 14 +++++++++----- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py index af0b8cc64..cf2f4db62 100644 --- a/verifiers/utils/prompt_cache_utils.py +++ b/verifiers/utils/prompt_cache_utils.py @@ -1,10 +1,13 @@ from collections.abc import Mapping from dataclasses import dataclass -from typing import Any, Literal +from typing import Any, Literal, TypeVar from urllib.parse import urlsplit from verifiers.types import ClientConfig, ClientType +NativePromptT = TypeVar("NativePromptT") +NativeToolsT = TypeVar("NativeToolsT") + PromptCacheMode = Literal[ "disabled", "implicit", @@ -162,11 +165,11 @@ def apply_prompt_cache_to_request( *, config: ClientConfig | None, model: str, - native_prompt: object, - native_tools: object, + native_prompt: NativePromptT, + native_tools: NativeToolsT, sampling_args: Mapping[str, Any], extra_kwargs: Mapping[str, Any], -) -> tuple[object, object, dict[str, Any], dict[str, Any]]: +) -> tuple[NativePromptT, NativeToolsT, dict[str, Any], dict[str, Any]]: policy = resolve_prompt_cache_policy(config, model) updated_sampling_args = dict(sampling_args) updated_extra_kwargs = dict(extra_kwargs) diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 3b1295014..45240f489 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -232,7 +232,14 @@ def state_to_output( ) usage = _extract_state_token_usage(state) if usage is not None: - token_usage: dict[str, float] = dict(usage) + token_usage: dict[str, float] = { + "input_tokens": usage["input_tokens"], + "output_tokens": usage["output_tokens"], + } + for key in ("cached_input_tokens", "final_input_tokens", "final_output_tokens"): + value = usage.get(key) + if value is not None: + token_usage[key] = value # Add context token metrics from trajectory trajectory = state.get("trajectory", []) if isinstance(trajectory, list): diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index d62a5a336..ea3d7c99f 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -1,20 +1,22 @@ import math from collections.abc import Mapping, Sequence from types import MappingProxyType -from typing import Any +from typing import Any, cast from verifiers.types import Response, TokenUsage, Usage def _get_usage_value(usage_obj: object, key: str) -> object: if isinstance(usage_obj, Mapping): - return usage_obj.get(key, 0) + usage_mapping = cast(Mapping[str, object], usage_obj) + return usage_mapping.get(key, 0) return getattr(usage_obj, key, 0) def _get_optional_usage_value(usage_obj: object, key: str) -> object: if isinstance(usage_obj, Mapping): - return usage_obj.get(key) + usage_mapping = cast(Mapping[str, object], usage_obj) + return usage_mapping.get(key) return getattr(usage_obj, key, None) @@ -24,7 +26,8 @@ def _get_nested_usage_value(usage_obj: object, key: str) -> object: return value details = _get_optional_usage_value(usage_obj, "prompt_tokens_details") if isinstance(details, Mapping): - return details.get(key) + details_mapping = cast(Mapping[str, object], details) + return details_mapping.get(key) if details is not None: return getattr(details, key, None) return None @@ -32,7 +35,8 @@ def _get_nested_usage_value(usage_obj: object, key: str) -> object: def _get_response_usage(response: object) -> object: if isinstance(response, Mapping): - return response.get("usage") + response_mapping = cast(Mapping[str, object], response) + return response_mapping.get("usage") return getattr(response, "usage", None) From 66ef2ce9d51ab21a4ec8bd7acfba4a25d81de85d Mon Sep 17 00:00:00 2001 From: William Brown Date: Wed, 20 May 2026 01:25:02 -0700 Subject: [PATCH 4/9] Address prompt cache PR feedback --- skills/evaluate-environments/SKILL.md | 1 + tests/test_prompt_cache_utils.py | 32 +++++++++++++++++++++++---- verifiers/utils/usage_utils.py | 17 +++++++++----- 3 files changed, 40 insertions(+), 10 deletions(-) diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index b34fce774..e9a1c0517 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -73,6 +73,7 @@ url = "https://api.openai.com/v1" key = "OPENAI_API_KEY" api_client_type = "openai_responses" ``` +9. Prompt caching is automatic for supported official providers inferred from the endpoint URL and API client type: OpenAI, Anthropic, and OpenRouter. Do not ask users to configure prompt caching for normal evals. Use `prompt_cache = false` on an endpoint row or `[[eval]]` entry only when a specific run needs to bypass provider prompt caching. ## Publish Gate Before Large Runs 1. After smoke tests pass and results look stable, proactively suggest pushing the environment to Hub before large eval sweeps or RL work. diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index d18ad9b94..83ceca4f4 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -33,7 +33,9 @@ async def to_native_tool(self, tool): async def to_native_prompt(self, messages): return messages, {} - async def get_native_response(self, prompt, model, sampling_args, tools=None, **kwargs): + async def get_native_response( + self, prompt, model, sampling_args, tools=None, **kwargs + ): self.request = { "prompt": prompt, "model": model, @@ -101,7 +103,9 @@ async def to_native_tool(self, tool): async def to_native_prompt(self, messages): return messages, {} - async def get_native_response(self, prompt, model, sampling_args, tools=None, **kwargs): + async def get_native_response( + self, prompt, model, sampling_args, tools=None, **kwargs + ): raise AssertionError("get_response is implemented directly") async def raise_from_native_response(self, response): @@ -275,8 +279,10 @@ async def test_v1_group_rollouts_start_concurrently_for_cached_provider(): ) ) env = vf.Env( - taskset=vf.Taskset(source=[{"question": "q"}]), - harness=vf.Harness(max_turns=1), + taskset=vf.Taskset( + config=vf.TasksetConfig.model_validate({"source": [{"question": "q"}]}) + ), + harness=vf.Harness(config=vf.HarnessConfig.model_validate({"max_turns": 1})), ) await env._run_group_states( @@ -390,6 +396,24 @@ def test_serialized_response_usage_counts_cache_details(): } +def test_serialized_responses_usage_counts_input_token_cache_details(): + response = { + "usage": { + "input_tokens": 100, + "output_tokens": 7, + "input_tokens_details": { + "cached_tokens": 80, + }, + } + } + + assert extract_usage_token_details(response) == { + "input_tokens": 20, + "output_tokens": 7, + "cached_input_tokens": 80, + } + + def test_state_output_fallback_reads_serialized_trajectory_usage(): task = vf.Task( { diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index ea3d7c99f..59feb062d 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -24,12 +24,17 @@ def _get_nested_usage_value(usage_obj: object, key: str) -> object: value = _get_optional_usage_value(usage_obj, key) if value is not None: return value - details = _get_optional_usage_value(usage_obj, "prompt_tokens_details") - if isinstance(details, Mapping): - details_mapping = cast(Mapping[str, object], details) - return details_mapping.get(key) - if details is not None: - return getattr(details, key, None) + for details_key in ("prompt_tokens_details", "input_tokens_details"): + details = _get_optional_usage_value(usage_obj, details_key) + if isinstance(details, Mapping): + details_mapping = cast(Mapping[str, object], details) + nested_value = details_mapping.get(key) + if nested_value is not None: + return nested_value + elif details is not None: + nested_value = getattr(details, key, None) + if nested_value is not None: + return nested_value return None From b713c88324f339d4cfa754a69c8e10defe5f8396 Mon Sep 17 00:00:00 2001 From: William Brown Date: Wed, 20 May 2026 21:47:46 -0700 Subject: [PATCH 5/9] Shrink prompt cache integration --- docs/evaluation.md | 2 +- skills/evaluate-environments/SKILL.md | 2 +- tests/test_endpoint_registry.py | 32 ---- tests/test_eval_cli.py | 46 ------ tests/test_prompt_cache_utils.py | 214 ++++++-------------------- verifiers/clients/client.py | 4 - verifiers/scripts/eval.py | 18 --- verifiers/types.py | 3 - verifiers/utils/eval_utils.py | 7 - verifiers/utils/prompt_cache_utils.py | 191 ++++------------------- verifiers/utils/usage_utils.py | 195 +++++++++-------------- 11 files changed, 146 insertions(+), 568 deletions(-) diff --git a/docs/evaluation.md b/docs/evaluation.md index 83c1052b2..6eff1467d 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -167,7 +167,7 @@ In `[[eval]]` TOML configs you can set extra headers as `headers = { ... }` and/ For per-request headers that need to vary per rollout (e.g. sticky DP-aware routing keyed off `example_id` or `trajectory_id`), use `headers_from_state = { "X-Name" = "state_key" }` and/or `header_from_state = ["X-Name: state_key", ...]` (same form as repeated `--header-from-state`). The value for each request is resolved at send time as `state[state_key]`. If unset, `X-Session-ID` defaults to `example_id`. -Prompt caching is automatic for supported official providers inferred from `url` and `api_client_type`: OpenAI (`https://api.openai.com`), Anthropic (`https://api.anthropic.com`), and OpenRouter (`https://openrouter.ai`). Unsupported providers run unchanged. Set `prompt_cache = false` on an endpoint row or `[[eval]]` only when you need to disable this behavior for a specific run. +Provider prompt caches are managed by the upstream API. Verifiers reports provider cache hits as `cached_input_tokens` when they appear in usage data, and automatically sends Anthropic's prompt-cache hint for official Anthropic Messages endpoints. To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index e9a1c0517..5ef865fa8 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -73,7 +73,7 @@ url = "https://api.openai.com/v1" key = "OPENAI_API_KEY" api_client_type = "openai_responses" ``` -9. Prompt caching is automatic for supported official providers inferred from the endpoint URL and API client type: OpenAI, Anthropic, and OpenRouter. Do not ask users to configure prompt caching for normal evals. Use `prompt_cache = false` on an endpoint row or `[[eval]]` entry only when a specific run needs to bypass provider prompt caching. +9. Do not ask users to configure prompt caching for normal evals. Verifiers reports provider cache hits when usage data includes them, and official Anthropic Messages endpoints receive Anthropic's prompt-cache hint automatically. ## Publish Gate Before Large Runs 1. After smoke tests pass and results look stable, proactively suggest pushing the environment to Hub before large eval sweeps or RL work. diff --git a/tests/test_endpoint_registry.py b/tests/test_endpoint_registry.py index 2ee61b0f3..6a97eed5e 100644 --- a/tests/test_endpoint_registry.py +++ b/tests/test_endpoint_registry.py @@ -243,38 +243,6 @@ def test_load_endpoints_toml_accepts_extra_headers_alias(tmp_path: Path): assert endpoints["proxy"][0]["extra_headers"] == {"X-A": "a"} -def test_load_endpoints_toml_accepts_prompt_cache_opt_out(tmp_path: Path): - registry_path = tmp_path / "endpoints.toml" - registry_path.write_text( - "[[endpoint]]\n" - 'endpoint_id = "openai"\n' - 'model = "m"\n' - 'url = "https://api.openai.com/v1"\n' - 'key = "OPENAI_API_KEY"\n' - "prompt_cache = false\n", - encoding="utf-8", - ) - - endpoints = load_endpoints(str(registry_path)) - - assert endpoints["openai"][0]["prompt_cache"] is False - - -def test_load_endpoints_toml_rejects_non_bool_prompt_cache(tmp_path: Path): - registry_path = tmp_path / "endpoints.toml" - registry_path.write_text( - "[[endpoint]]\n" - 'endpoint_id = "openai"\n' - 'model = "m"\n' - 'url = "https://api.openai.com/v1"\n' - 'key = "OPENAI_API_KEY"\n' - 'prompt_cache = "yes"\n', - encoding="utf-8", - ) - - assert load_endpoints(str(registry_path)) == {} - - def test_load_endpoints_toml_rejects_headers_and_extra_headers_together( tmp_path: Path, ): diff --git a/tests/test_eval_cli.py b/tests/test_eval_cli.py index 16adb5693..5ac7e9cf0 100644 --- a/tests/test_eval_cli.py +++ b/tests/test_eval_cli.py @@ -320,52 +320,6 @@ def test_cli_registry_headers_merged_with_eval_toml(tmp_path, monkeypatch, run_c } -def test_cli_registry_prompt_cache_opt_out_flows_to_client_config(monkeypatch, run_cli): - captured = run_cli( - monkeypatch, - { - "model": "openai", - "api_base_url": None, - "api_key_var": None, - }, - endpoints={ - "openai": [ - { - "model": "gpt-5.4-mini", - "key": "OPENAI_API_KEY", - "url": "https://api.openai.com/v1", - "prompt_cache": False, - } - ] - }, - ) - - assert captured["configs"][0].client_config.prompt_cache is False - - -def test_cli_toml_prompt_cache_opt_out_overrides_registry(monkeypatch, run_cli): - captured = run_cli( - monkeypatch, - { - "model": "openai", - "api_base_url": None, - "api_key_var": None, - "prompt_cache": False, - }, - endpoints={ - "openai": [ - { - "model": "gpt-5.4-mini", - "key": "OPENAI_API_KEY", - "url": "https://api.openai.com/v1", - } - ] - }, - ) - - assert captured["configs"][0].client_config.prompt_cache is False - - def test_cli_multi_variant_preserves_per_row_registry_headers(monkeypatch, run_cli): captured = run_cli( monkeypatch, diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index 83ceca4f4..d274ce497 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -1,6 +1,3 @@ -from __future__ import annotations - -import asyncio from types import SimpleNamespace import pytest @@ -11,9 +8,9 @@ from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient from verifiers.types import ClientConfig, Response, ResponseMessage from verifiers.utils.prompt_cache_utils import ( - EndpointIdentity, apply_prompt_cache_to_request, - resolve_prompt_cache_policy, + endpoint_origin, + uses_official_anthropic_messages, ) from verifiers.utils.save_utils import state_to_output from verifiers.utils.usage_utils import extract_usage_token_details @@ -66,127 +63,38 @@ async def close(self) -> None: pass -class ConcurrentStartClient(Client): - def __init__(self, config: ClientConfig): - super().__init__(config) - self.first_active = False - self.started_any = False - self.overlapped_first = False - - def setup_client(self, config): - return object() - - async def get_response(self, prompt, model, sampling_args, tools=None, **kwargs): - _ = prompt, model, sampling_args, tools, kwargs - if self.first_active: - self.overlapped_first = True - elif not self.started_any: - self.started_any = True - self.first_active = True - await asyncio.sleep(0.01) - self.first_active = False - return Response( - id="resp", - created=0, - model="model", - usage=None, - message=ResponseMessage( - content="ok", - finish_reason="stop", - is_truncated=False, - ), - ) - - async def to_native_tool(self, tool): - return tool - - async def to_native_prompt(self, messages): - return messages, {} - - async def get_native_response( - self, prompt, model, sampling_args, tools=None, **kwargs - ): - raise AssertionError("get_response is implemented directly") - - async def raise_from_native_response(self, response): - _ = response - - async def from_native_response(self, response): - _ = response - - async def close(self) -> None: - pass - - -def test_endpoint_identity_normalizes_official_origins(): - identity = EndpointIdentity.from_url( - "https://api.openai.com/v1", "openai_chat_completions" - ) - - assert identity is not None - assert identity.origin == "https://api.openai.com" - assert identity.host == "api.openai.com" - assert identity.path == "/v1" - - -def test_prompt_cache_policy_is_inferred_from_url_and_type(): +def test_endpoint_origin_normalizes_urls(): assert ( - resolve_prompt_cache_policy( - ClientConfig( - client_type="openai_responses", - api_base_url="https://api.openai.com/v1", - ), - "gpt-5.4-mini", - ).mode - == "implicit" - ) - assert ( - resolve_prompt_cache_policy( - ClientConfig( - client_type="anthropic_messages", - api_base_url="https://api.anthropic.com", - ), - "claude-sonnet-4-5", - ).mode - == "anthropic_top_level" - ) - assert ( - resolve_prompt_cache_policy( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://openrouter.ai/api/v1", - ), - "anthropic/claude-sonnet-4.5", - ).mode - == "openrouter_anthropic_top_level" + endpoint_origin("https://api.anthropic.com/v1") == "https://api.anthropic.com" ) - assert ( - resolve_prompt_cache_policy( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.example.com/v1", - ), - "model", - ).mode - == "disabled" + assert endpoint_origin("https://api.anthropic.com:443/v1") == ( + "https://api.anthropic.com" ) + assert endpoint_origin("http://localhost:8080/v1") == "http://localhost:8080" -def test_prompt_cache_false_disables_inferred_provider_policy(): - policy = resolve_prompt_cache_policy( +def test_official_anthropic_messages_endpoint_is_cache_control_target(): + assert uses_official_anthropic_messages( + ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com", + ) + ) + assert not uses_official_anthropic_messages( ClientConfig( client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - prompt_cache=False, - ), - "gpt-5.4-mini", + api_base_url="https://api.anthropic.com", + ) + ) + assert not uses_official_anthropic_messages( + ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.pinference.ai/api/v1", + ) ) - - assert policy.mode == "disabled" - assert not policy.enabled -def test_anthropic_request_policy_adds_top_level_cache_control(): +def test_anthropic_request_adds_top_level_cache_control(): native_prompt, native_tools, sampling_args, extra_kwargs = ( apply_prompt_cache_to_request( config=ClientConfig( @@ -207,31 +115,7 @@ def test_anthropic_request_policy_adds_top_level_cache_control(): assert extra_kwargs["cache_control"] == {"type": "ephemeral"} -def test_openrouter_anthropic_policy_uses_extra_body_cache_control(): - native_prompt, native_tools, sampling_args, extra_kwargs = ( - apply_prompt_cache_to_request( - config=ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://openrouter.ai/api/v1", - ), - model="anthropic/claude-sonnet-4.5", - native_prompt=[{"role": "user", "content": "question"}], - native_tools=[], - sampling_args={"max_tokens": 16, "extra_body": {"foo": "bar"}}, - extra_kwargs={}, - ) - ) - - assert native_prompt == [{"role": "user", "content": "question"}] - assert native_tools == [] - assert extra_kwargs == {} - assert sampling_args["extra_body"] == { - "foo": "bar", - "cache_control": {"type": "ephemeral"}, - } - - -def test_openai_policy_does_not_mutate_request(): +def test_openai_request_does_not_mutate_request(): native_prompt, native_tools, sampling_args, extra_kwargs = ( apply_prompt_cache_to_request( config=ClientConfig( @@ -252,8 +136,25 @@ def test_openai_policy_does_not_mutate_request(): assert extra_kwargs == {} +def test_non_official_anthropic_endpoint_does_not_add_cache_control(): + _, _, sampling_args, extra_kwargs = apply_prompt_cache_to_request( + config=ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.pinference.ai/api/v1", + ), + model="claude-sonnet-4-5", + native_prompt=[], + native_tools=None, + sampling_args={"max_tokens": 16}, + extra_kwargs={}, + ) + + assert sampling_args == {"max_tokens": 16} + assert extra_kwargs == {} + + @pytest.mark.asyncio -async def test_client_request_hook_applies_prompt_cache_policy(): +async def test_client_request_hook_applies_anthropic_cache_control(): client = RecordingClient( ClientConfig( client_type="anthropic_messages", @@ -270,35 +171,6 @@ async def test_client_request_hook_applies_prompt_cache_policy(): assert client.request["kwargs"]["cache_control"] == {"type": "ephemeral"} -@pytest.mark.asyncio -async def test_v1_group_rollouts_start_concurrently_for_cached_provider(): - client = ConcurrentStartClient( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - ) - ) - env = vf.Env( - taskset=vf.Taskset( - config=vf.TasksetConfig.model_validate({"source": [{"question": "q"}]}) - ), - harness=vf.Harness(config=vf.HarnessConfig.model_validate({"max_turns": 1})), - ) - - await env._run_group_states( - [ - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - {"prompt": [{"role": "user", "content": "q"}], "example_id": 0}, - ], - client, - "gpt-5.4-mini", - {}, - ) - - assert client.overlapped_first - - @pytest.mark.asyncio async def test_openai_usage_splits_cached_input_tokens(): client = OpenAIChatCompletionsClient(object()) diff --git a/verifiers/clients/client.py b/verifiers/clients/client.py index 5879b39e6..5da565c12 100644 --- a/verifiers/clients/client.py +++ b/verifiers/clients/client.py @@ -51,10 +51,6 @@ def __init__(self, client_or_config: ClientT | ClientConfig) -> None: def client(self) -> ClientT: return self._client - @property - def config(self) -> ClientConfig | None: - return self._config - @abstractmethod def setup_client(self, config: ClientConfig) -> ClientT: ... diff --git a/verifiers/scripts/eval.py b/verifiers/scripts/eval.py index 7dc600a84..0c3c07828 100644 --- a/verifiers/scripts/eval.py +++ b/verifiers/scripts/eval.py @@ -183,15 +183,6 @@ def build_extra_headers_from_state(raw: dict[str, Any]) -> dict[str, str]: return {**table, **from_list} -def build_prompt_cache_enabled(raw: dict[str, Any], default: bool = True) -> bool: - raw_prompt_cache = raw.get("prompt_cache") - if raw_prompt_cache is None: - return default - if not isinstance(raw_prompt_cache, bool): - raise ValueError("'prompt_cache' must be a boolean when provided.") - return raw_prompt_cache - - def get_env_eval_defaults(env_id: str) -> dict[str, Any]: """Get eval config defaults from the environment module's pyproject.toml. @@ -720,18 +711,13 @@ def build_eval_config(raw: dict) -> EvalConfig: } registry_headers_base: dict[str, str] = {} - registry_prompt_cache = True if endpoint_group is not None: registry_headers_base = dict(endpoint_group[0].get("extra_headers", {})) - registry_prompt_cache = bool(endpoint_group[0].get("prompt_cache", True)) merged_headers: dict[str, str] = { **registry_headers_base, **eval_headers_merged, } - prompt_cache_enabled = build_prompt_cache_enabled( - raw, default=registry_prompt_cache - ) primary_api_base_url = api_base_url if not isinstance(primary_api_base_url, str): @@ -756,9 +742,6 @@ def build_eval_config(raw: dict) -> EvalConfig: **dict(ep.get("extra_headers", {})), **eval_headers_merged, }, - prompt_cache=build_prompt_cache_enabled( - raw, default=bool(ep.get("prompt_cache", True)) - ), ) for ep in endpoint_group ] @@ -771,7 +754,6 @@ def build_eval_config(raw: dict) -> EvalConfig: endpoint_configs=endpoint_configs, extra_headers=merged_headers, extra_headers_from_state=eval_headers_from_state, - prompt_cache=prompt_cache_enabled, ) # Backward-compatible TOML field: resume_path diff --git a/verifiers/types.py b/verifiers/types.py index bb7157970..91ec44ccb 100644 --- a/verifiers/types.py +++ b/verifiers/types.py @@ -997,7 +997,6 @@ class RolloutScores(TypedDict): "model": str, "api_client_type": NotRequired[ClientType], "extra_headers": NotRequired[dict[str, str]], - "prompt_cache": NotRequired[bool], }, ) Endpoints = dict[str, list[Endpoint]] @@ -1044,7 +1043,6 @@ class ClientConfig(BaseModel): 'e.g. {"X-Session-ID": "example_id"} adds a X-Session-ID header ' "with the value of state['example_id'].", ) - prompt_cache: bool = True @field_validator("extra_headers", mode="before") @classmethod @@ -1105,7 +1103,6 @@ class EndpointClientConfig(BaseModel): max_keepalive_connections: int = 28000 max_retries: int = 10 extra_headers: dict[str, str] = Field(default_factory=dict) - prompt_cache: bool = True @field_validator("extra_headers", mode="before") @classmethod diff --git a/verifiers/utils/eval_utils.py b/verifiers/utils/eval_utils.py index f9048d106..7795327bf 100644 --- a/verifiers/utils/eval_utils.py +++ b/verifiers/utils/eval_utils.py @@ -229,12 +229,6 @@ def _coerce_endpoint(raw_endpoint: object, source: str) -> Endpoint: if coerced_headers: endpoint["extra_headers"] = coerced_headers - raw_prompt_cache = raw_endpoint_dict.get("prompt_cache") - if raw_prompt_cache is not None: - if not isinstance(raw_prompt_cache, bool): - raise ValueError(f"Field 'prompt_cache' must be a boolean in {source}") - endpoint["prompt_cache"] = raw_prompt_cache - return endpoint @@ -561,7 +555,6 @@ def load_toml_config( "headers", "header_from_state", "headers_from_state", - "prompt_cache", # sampling "sampling_args", "max_tokens", diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py index cf2f4db62..8d5966f42 100644 --- a/verifiers/utils/prompt_cache_utils.py +++ b/verifiers/utils/prompt_cache_utils.py @@ -1,160 +1,36 @@ from collections.abc import Mapping -from dataclasses import dataclass -from typing import Any, Literal, TypeVar +from typing import Any, TypeVar from urllib.parse import urlsplit -from verifiers.types import ClientConfig, ClientType +from verifiers.types import ClientConfig NativePromptT = TypeVar("NativePromptT") NativeToolsT = TypeVar("NativeToolsT") -PromptCacheMode = Literal[ - "disabled", - "implicit", - "anthropic_top_level", - "openrouter_anthropic_top_level", -] +ANTHROPIC_ORIGINS = frozenset({"https://api.anthropic.com"}) -OPENAI_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset( - { - "openai_chat_completions", - "openai_responses", - } -) -ANTHROPIC_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset({"anthropic_messages"}) -OPENROUTER_CACHE_CLIENT_TYPES: frozenset[ClientType] = frozenset( - { - "openai_chat_completions", - "openai_responses", - } -) - -@dataclass(frozen=True) -class EndpointIdentity: - client_type: ClientType - origin: str - host: str - path: str - - @classmethod - def from_config(cls, config: ClientConfig) -> "EndpointIdentity | None": - return cls.from_url(config.api_base_url, config.client_type) - - @classmethod - def from_url( - cls, api_base_url: str, client_type: ClientType - ) -> "EndpointIdentity | None": - parsed = urlsplit(api_base_url) - if not parsed.scheme or not parsed.hostname: - return None - scheme = parsed.scheme.lower() - host = parsed.hostname.lower() - port = parsed.port - netloc = host - if port is not None and not ( - (scheme == "https" and port == 443) or (scheme == "http" and port == 80) - ): - netloc = f"{host}:{port}" - return cls( - client_type=client_type, - origin=f"{scheme}://{netloc}", - host=host, - path=parsed.path or "", - ) - - -@dataclass(frozen=True) -class PromptCachePolicy: - mode: PromptCacheMode = "disabled" - - @property - def enabled(self) -> bool: - return self.mode != "disabled" - - -class PromptCacheAdapter: - def policy_for(self, identity: EndpointIdentity, model: str) -> PromptCachePolicy: - _ = identity, model - return PromptCachePolicy(mode="implicit") - - -class AnthropicPromptCacheAdapter(PromptCacheAdapter): - def policy_for(self, identity: EndpointIdentity, model: str) -> PromptCachePolicy: - _ = identity, model - return PromptCachePolicy(mode="anthropic_top_level") - - -class OpenRouterPromptCacheAdapter(PromptCacheAdapter): - anthropic_model_prefixes = ("anthropic/",) - - def policy_for(self, identity: EndpointIdentity, model: str) -> PromptCachePolicy: - _ = identity - if model.startswith(self.anthropic_model_prefixes): - return PromptCachePolicy(mode="openrouter_anthropic_top_level") - return PromptCachePolicy(mode="implicit") - - -@dataclass(frozen=True) -class ProviderSpec: - provider_id: str - origins: frozenset[str] - client_types: frozenset[ClientType] - prompt_cache: PromptCacheAdapter - - def recognizes(self, identity: EndpointIdentity) -> bool: - return ( - identity.origin in self.origins - and identity.client_type in self.client_types - ) - - -PROVIDER_SPECS: tuple[ProviderSpec, ...] = ( - ProviderSpec( - provider_id="openai", - origins=frozenset({"https://api.openai.com"}), - client_types=OPENAI_CACHE_CLIENT_TYPES, - prompt_cache=PromptCacheAdapter(), - ), - ProviderSpec( - provider_id="anthropic", - origins=frozenset({"https://api.anthropic.com"}), - client_types=ANTHROPIC_CACHE_CLIENT_TYPES, - prompt_cache=AnthropicPromptCacheAdapter(), - ), - ProviderSpec( - provider_id="openrouter", - origins=frozenset({"https://openrouter.ai"}), - client_types=OPENROUTER_CACHE_CLIENT_TYPES, - prompt_cache=OpenRouterPromptCacheAdapter(), - ), -) - -DISABLED_PROMPT_CACHE_POLICY = PromptCachePolicy() - - -def infer_provider_spec(config: ClientConfig) -> ProviderSpec | None: - identity = EndpointIdentity.from_config(config) - if identity is None: +def endpoint_origin(api_base_url: str) -> str | None: + parsed = urlsplit(api_base_url) + if not parsed.scheme or not parsed.hostname: return None - for spec in PROVIDER_SPECS: - if spec.recognizes(identity): - return spec - return None - - -def resolve_prompt_cache_policy( - config: ClientConfig | None, model: str -) -> PromptCachePolicy: - if config is None or not config.prompt_cache: - return DISABLED_PROMPT_CACHE_POLICY - identity = EndpointIdentity.from_config(config) - if identity is None: - return DISABLED_PROMPT_CACHE_POLICY - spec = infer_provider_spec(config) - if spec is None: - return DISABLED_PROMPT_CACHE_POLICY - return spec.prompt_cache.policy_for(identity, model) + scheme = parsed.scheme.lower() + host = parsed.hostname.lower() + port = parsed.port + netloc = host + if port is not None and not ( + (scheme == "https" and port == 443) or (scheme == "http" and port == 80) + ): + netloc = f"{host}:{port}" + return f"{scheme}://{netloc}" + + +def uses_official_anthropic_messages(config: ClientConfig | None) -> bool: + return ( + config is not None + and config.client_type == "anthropic_messages" + and endpoint_origin(config.api_base_url) in ANTHROPIC_ORIGINS + ) def _cache_control_payload() -> dict[str, str]: @@ -170,23 +46,8 @@ def apply_prompt_cache_to_request( sampling_args: Mapping[str, Any], extra_kwargs: Mapping[str, Any], ) -> tuple[NativePromptT, NativeToolsT, dict[str, Any], dict[str, Any]]: - policy = resolve_prompt_cache_policy(config, model) - updated_sampling_args = dict(sampling_args) + _ = model updated_extra_kwargs = dict(extra_kwargs) - updated_native_prompt = native_prompt - if policy.mode == "anthropic_top_level": + if uses_official_anthropic_messages(config): updated_extra_kwargs.setdefault("cache_control", _cache_control_payload()) - elif policy.mode == "openrouter_anthropic_top_level": - extra_body = updated_sampling_args.get("extra_body") - if isinstance(extra_body, Mapping): - extra_body = dict(extra_body) - else: - extra_body = {} - extra_body.setdefault("cache_control", _cache_control_payload()) - updated_sampling_args["extra_body"] = extra_body - return ( - updated_native_prompt, - native_tools, - updated_sampling_args, - updated_extra_kwargs, - ) + return native_prompt, native_tools, dict(sampling_args), updated_extra_kwargs diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index 59feb062d..3c478a5d5 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -1,4 +1,3 @@ -import math from collections.abc import Mapping, Sequence from types import MappingProxyType from typing import Any, cast @@ -6,110 +5,76 @@ from verifiers.types import Response, TokenUsage, Usage -def _get_usage_value(usage_obj: object, key: str) -> object: - if isinstance(usage_obj, Mapping): - usage_mapping = cast(Mapping[str, object], usage_obj) - return usage_mapping.get(key, 0) - return getattr(usage_obj, key, 0) +def _get_field(obj: object, key: str) -> object: + if isinstance(obj, Mapping): + return cast(Mapping[str, object], obj).get(key) + return getattr(obj, key, None) -def _get_optional_usage_value(usage_obj: object, key: str) -> object: - if isinstance(usage_obj, Mapping): - usage_mapping = cast(Mapping[str, object], usage_obj) - return usage_mapping.get(key) - return getattr(usage_obj, key, None) +def _as_token_count(value: object) -> int | None: + if isinstance(value, bool): + return None + if isinstance(value, int): + return max(0, value) + if isinstance(value, float) and value.is_integer(): + return max(0, int(value)) + return None + + +def _response_usage(response: object) -> object | None: + return _get_field(response, "usage") -def _get_nested_usage_value(usage_obj: object, key: str) -> object: - value = _get_optional_usage_value(usage_obj, key) - if value is not None: - return value +def _nested_cached_tokens(usage: object) -> int | None: for details_key in ("prompt_tokens_details", "input_tokens_details"): - details = _get_optional_usage_value(usage_obj, details_key) - if isinstance(details, Mapping): - details_mapping = cast(Mapping[str, object], details) - nested_value = details_mapping.get(key) - if nested_value is not None: - return nested_value - elif details is not None: - nested_value = getattr(details, key, None) - if nested_value is not None: - return nested_value + details = _get_field(usage, details_key) + if details is None: + continue + cached = _as_token_count(_get_field(details, "cached_tokens")) + if cached is not None: + return cached return None -def _get_response_usage(response: object) -> object: - if isinstance(response, Mapping): - response_mapping = cast(Mapping[str, object], response) - return response_mapping.get("usage") - return getattr(response, "usage", None) +def _cache_creation_tokens(usage: object) -> int: + return _as_token_count(_get_field(usage, "cache_creation_input_tokens")) or 0 -def _coerce_usage_int(value: object) -> int: - """Best-effort usage coercion. Invalid values degrade to zero.""" - if value is None: - return 0 - if isinstance(value, bool): - return int(value) - if isinstance(value, int): - return max(0, value) - if isinstance(value, float): - if math.isnan(value) or math.isinf(value): - return 0 - return max(0, int(value)) - if isinstance(value, str): - stripped = value.strip() - if not stripped: - return 0 - try: - return max(0, int(stripped)) - except (TypeError, ValueError): - try: - parsed = float(stripped) - if math.isnan(parsed) or math.isinf(parsed): - return 0 - return max(0, int(parsed)) - except (TypeError, ValueError): - return 0 - return 0 +def _direct_cached_tokens(usage: object) -> int | None: + cached = _as_token_count(_get_field(usage, "cached_input_tokens")) + if cached is not None: + return cached + return _as_token_count(_get_field(usage, "cache_read_input_tokens")) def extract_usage_token_details(response: object) -> dict[str, int] | None: - usage = _get_response_usage(response) + usage = _response_usage(response) if usage is None: return None - prompt_tokens = _get_usage_value(usage, "prompt_tokens") - completion_tokens = _get_usage_value(usage, "completion_tokens") - if not prompt_tokens and not completion_tokens: - prompt_tokens = _get_usage_value(usage, "input_tokens") - completion_tokens = _get_usage_value(usage, "output_tokens") + input_tokens = _as_token_count(_get_field(usage, "prompt_tokens")) + output_tokens = _as_token_count(_get_field(usage, "completion_tokens")) + if input_tokens is None and output_tokens is None: + input_tokens = _as_token_count(_get_field(usage, "input_tokens")) + output_tokens = _as_token_count(_get_field(usage, "output_tokens")) + if input_tokens is None or output_tokens is None: + return None + + input_tokens += _cache_creation_tokens(usage) details = { - "input_tokens": _coerce_usage_int(prompt_tokens), - "output_tokens": _coerce_usage_int(completion_tokens), + "input_tokens": input_tokens, + "output_tokens": output_tokens, } - subtract_cached_from_input = False - cached_input_tokens = _get_optional_usage_value(usage, "cached_input_tokens") - if cached_input_tokens is None: - cached_input_tokens = _get_optional_usage_value( - usage, "cache_read_input_tokens" - ) - if cached_input_tokens is None: - cached_input_tokens = _get_nested_usage_value(usage, "cached_tokens") - subtract_cached_from_input = cached_input_tokens is not None - if cached_input_tokens is not None: - cached_int = _coerce_usage_int(cached_input_tokens) - details["cached_input_tokens"] = cached_int - if subtract_cached_from_input: - details["input_tokens"] = max(0, details["input_tokens"] - cached_int) - - cache_creation_input_tokens = _get_optional_usage_value( - usage, "cache_creation_input_tokens" - ) - if cache_creation_input_tokens is not None: - details["input_tokens"] += _coerce_usage_int(cache_creation_input_tokens) + cached_tokens = _direct_cached_tokens(usage) + if cached_tokens is not None: + details["cached_input_tokens"] = cached_tokens + return details + cached_tokens = _nested_cached_tokens(usage) + if cached_tokens is not None: + details["input_tokens"] = max(0, input_tokens - cached_tokens) + details["cached_input_tokens"] = cached_tokens return details @@ -133,6 +98,17 @@ def response_usage_tokens(response: Response) -> tuple[int, int]: return usage_tokens(usage) +def cast_token_usage(usage: Mapping[str, Any]) -> TokenUsage: + out: TokenUsage = { + "input_tokens": float(usage.get("input_tokens", 0.0)), + "output_tokens": float(usage.get("output_tokens", 0.0)), + } + cached = usage.get("cached_input_tokens") + if cached is not None: + out["cached_input_tokens"] = float(cached) + return out + + class StateUsageTracker: """Accumulates token usage and exposes a read-only live usage mapping.""" @@ -172,8 +148,6 @@ def increment( self._usage_totals[key] = self._usage_totals.get(key, 0.0) + delta def increment_from_response(self, response: object) -> None: - if _get_response_usage(response) is None: - return details = extract_usage_token_details(response) if details is None: return @@ -190,18 +164,6 @@ def snapshot(self) -> TokenUsage | None: return cast_token_usage(self._usage_totals) -def cast_token_usage(usage: Mapping[str, Any]) -> TokenUsage: - out: TokenUsage = { - "input_tokens": float(usage.get("input_tokens", 0.0)), - "output_tokens": float(usage.get("output_tokens", 0.0)), - } - for key in ("cached_input_tokens",): - value = usage.get(key) - if value is not None: - out[key] = float(value) - return out - - def compute_context_token_metrics( trajectory: Sequence[Mapping[str, object]], ) -> dict[str, float]: @@ -214,44 +176,37 @@ def compute_context_token_metrics( Returns a dict with: final_output_tokens: Model-generated tokens (sum of completion_tokens across all steps). - final_input_tokens: Non-model tokens in context (last step's total - context minus final_output_tokens). + final_input_tokens: Non-model tokens in context (system prompts, user + messages, tool results, etc.). """ - _zero: dict[str, float] = { - "final_output_tokens": 0, - "final_input_tokens": 0, - } + zero = {"final_output_tokens": 0.0, "final_input_tokens": 0.0} if not trajectory: - return _zero + return zero last_step_total = 0 found = False for step in reversed(trajectory): - response = step.get("response") - if response is None or _get_response_usage(response) is None: - continue - details = extract_usage_token_details(response) + details = extract_usage_token_details(step.get("response")) if details is None: continue - prompt_tokens = details["input_tokens"] + details.get("cached_input_tokens", 0) - completion_tokens = details["output_tokens"] - last_step_total = prompt_tokens + completion_tokens + last_step_total = ( + details["input_tokens"] + + details.get("cached_input_tokens", 0) + + details["output_tokens"] + ) found = True break if not found: - return _zero + return zero total_completion = 0 for step in trajectory: - response = step.get("response") - if response is None or _get_response_usage(response) is None: - continue - details = extract_usage_token_details(response) + details = extract_usage_token_details(step.get("response")) if details is not None: total_completion += details["output_tokens"] return { - "final_output_tokens": total_completion, - "final_input_tokens": max(0, last_step_total - total_completion), + "final_output_tokens": float(total_completion), + "final_input_tokens": float(max(0, last_step_total - total_completion)), } From 36370772411cbe8b6bf131ae6b16fac37017c80b Mon Sep 17 00:00:00 2001 From: William Brown Date: Wed, 20 May 2026 22:01:16 -0700 Subject: [PATCH 6/9] Address prompt cache review comments --- tests/test_prompt_cache_utils.py | 24 ++++++++++++++++++++++++ verifiers/utils/prompt_cache_utils.py | 12 +++++++++--- verifiers/utils/usage_utils.py | 7 ------- 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index d274ce497..6037862a6 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -71,6 +71,7 @@ def test_endpoint_origin_normalizes_urls(): "https://api.anthropic.com" ) assert endpoint_origin("http://localhost:8080/v1") == "http://localhost:8080" + assert endpoint_origin("http://[::1]:8080/v1") == "http://[::1]:8080" def test_official_anthropic_messages_endpoint_is_cache_control_target(): @@ -115,6 +116,29 @@ def test_anthropic_request_adds_top_level_cache_control(): assert extra_kwargs["cache_control"] == {"type": "ephemeral"} +def test_anthropic_request_preserves_sampling_args_cache_control(): + _, _, sampling_args, extra_kwargs = apply_prompt_cache_to_request( + config=ClientConfig( + client_type="anthropic_messages", + api_base_url="https://api.anthropic.com", + ), + model="claude-sonnet-4-5", + native_prompt=[], + native_tools=None, + sampling_args={ + "max_tokens": 16, + "cache_control": {"type": "custom"}, + }, + extra_kwargs={}, + ) + + assert sampling_args == { + "max_tokens": 16, + "cache_control": {"type": "custom"}, + } + assert extra_kwargs == {} + + def test_openai_request_does_not_mutate_request(): native_prompt, native_tools, sampling_args, extra_kwargs = ( apply_prompt_cache_to_request( diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py index 8d5966f42..9e3676bba 100644 --- a/verifiers/utils/prompt_cache_utils.py +++ b/verifiers/utils/prompt_cache_utils.py @@ -18,10 +18,12 @@ def endpoint_origin(api_base_url: str) -> str | None: host = parsed.hostname.lower() port = parsed.port netloc = host + if ":" in host: + netloc = f"[{host}]" if port is not None and not ( (scheme == "https" and port == 443) or (scheme == "http" and port == 80) ): - netloc = f"{host}:{port}" + netloc = f"{netloc}:{port}" return f"{scheme}://{netloc}" @@ -47,7 +49,11 @@ def apply_prompt_cache_to_request( extra_kwargs: Mapping[str, Any], ) -> tuple[NativePromptT, NativeToolsT, dict[str, Any], dict[str, Any]]: _ = model + updated_sampling_args = dict(sampling_args) updated_extra_kwargs = dict(extra_kwargs) - if uses_official_anthropic_messages(config): + if ( + uses_official_anthropic_messages(config) + and "cache_control" not in updated_sampling_args + ): updated_extra_kwargs.setdefault("cache_control", _cache_control_payload()) - return native_prompt, native_tools, dict(sampling_args), updated_extra_kwargs + return native_prompt, native_tools, updated_sampling_args, updated_extra_kwargs diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index 3c478a5d5..969baff97 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -78,13 +78,6 @@ def extract_usage_token_details(response: object) -> dict[str, int] | None: return details -def extract_usage_tokens(response: object) -> tuple[int, int]: - details = extract_usage_token_details(response) - if details is None: - return 0, 0 - return details["input_tokens"], details["output_tokens"] - - def usage_tokens(usage: Usage) -> tuple[int, int]: if usage.prompt_tokens < 0 or usage.completion_tokens < 0: raise ValueError("Response usage tokens must be non-negative.") From b0f02de0359287366da16e6afd5c39c35a1deb75 Mon Sep 17 00:00:00 2001 From: William Brown Date: Wed, 20 May 2026 22:51:42 -0700 Subject: [PATCH 7/9] Address cached usage review comments --- ...st_openai_chat_completions_token_client.py | 35 ++++++++++++ tests/test_prompt_cache_utils.py | 53 +++++++++++++++++++ .../clients/openai_chat_completions_client.py | 29 +++++----- .../clients/openai_completions_client.py | 24 ++++----- verifiers/utils/usage_utils.py | 5 +- 5 files changed, 119 insertions(+), 27 deletions(-) diff --git a/tests/test_openai_chat_completions_token_client.py b/tests/test_openai_chat_completions_token_client.py index 923ff118e..c9503f2bc 100644 --- a/tests/test_openai_chat_completions_token_client.py +++ b/tests/test_openai_chat_completions_token_client.py @@ -1,3 +1,4 @@ +from types import SimpleNamespace from typing import Any, cast import httpx @@ -293,3 +294,37 @@ async def fake_get_prompt_ids( # noqa: ANN001 assert len(recording_client.calls) == 1 assert recording_client.calls[0]["path"] == "/chat/completions/tokens" assert recording_client.calls[0]["body"]["tokens"] == [10, 20] + + +@pytest.mark.asyncio +async def test_from_native_response_splits_cached_input_tokens(): + client = OpenAIChatCompletionsTokenClient(_NoopClient()) + message = SimpleNamespace( + content="ok", + tool_calls=None, + model_dump=lambda: {}, + ) + native_response = SimpleNamespace( + id="resp", + created=0, + model="test-model", + usage=SimpleNamespace( + prompt_tokens=100, + completion_tokens=5, + total_tokens=105, + prompt_tokens_details=SimpleNamespace(cached_tokens=80), + ), + choices=[ + SimpleNamespace( + message=message, + finish_reason="stop", + ) + ], + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 20 + assert response.usage.cached_input_tokens == 80 + assert response.usage.total_tokens == 25 diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index 6037862a6..04cbc7f4b 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -232,6 +232,41 @@ async def test_openai_usage_splits_cached_input_tokens(): assert response.usage.total_tokens == 25 +@pytest.mark.asyncio +async def test_openai_usage_handles_mixed_token_field_names(): + client = OpenAIChatCompletionsClient(object()) + message = SimpleNamespace( + content="ok", + tool_calls=None, + model_dump=lambda: {}, + ) + native_response = SimpleNamespace( + id="resp", + created=0, + model="gpt-5.4-mini", + usage=SimpleNamespace( + prompt_tokens=100, + output_tokens=5, + total_tokens=105, + prompt_tokens_details=SimpleNamespace(cached_tokens=80), + ), + choices=[ + SimpleNamespace( + message=message, + finish_reason="stop", + ) + ], + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 20 + assert response.usage.completion_tokens == 5 + assert response.usage.cached_input_tokens == 80 + assert response.usage.total_tokens == 25 + + @pytest.mark.asyncio async def test_anthropic_usage_splits_cache_read_and_write_tokens(): client = AnthropicMessagesClient(object()) @@ -310,6 +345,24 @@ def test_serialized_responses_usage_counts_input_token_cache_details(): } +def test_serialized_usage_handles_mixed_token_field_names(): + response = { + "usage": { + "prompt_tokens": 100, + "output_tokens": 7, + "prompt_tokens_details": { + "cached_tokens": 80, + }, + } + } + + assert extract_usage_token_details(response) == { + "input_tokens": 20, + "output_tokens": 7, + "cached_input_tokens": 80, + } + + def test_state_output_fallback_reads_serialized_trajectory_usage(): task = vf.Task( { diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index 8c0184a1e..f55188ec3 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -106,6 +106,14 @@ def get_usage_int_field(usage: Any, key: str) -> int | None: return None +def get_first_usage_int_field(usage: Any, *keys: str) -> int | None: + for key in keys: + value = get_usage_int_field(usage, key) + if value is not None: + return value + return None + + def get_cached_prompt_tokens(usage: Any) -> int | None: details = get_usage_field(usage, "prompt_tokens_details") if details is None: @@ -427,22 +435,19 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_usage_field(usage, "prompt_tokens") - completion_tokens = get_usage_field(usage, "completion_tokens") - if not isinstance(prompt_tokens, int) or not isinstance( - completion_tokens, int - ): - prompt_tokens = get_usage_field(usage, "input_tokens") - completion_tokens = get_usage_field(usage, "output_tokens") - total_tokens = get_usage_field(usage, "total_tokens") - if not isinstance(prompt_tokens, int) or not isinstance( - completion_tokens, int - ): + prompt_tokens = get_first_usage_int_field( + usage, "prompt_tokens", "input_tokens" + ) + completion_tokens = get_first_usage_int_field( + usage, "completion_tokens", "output_tokens" + ) + if prompt_tokens is None or completion_tokens is None: return None + total_tokens = get_usage_int_field(usage, "total_tokens") cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) - if not isinstance(total_tokens, int): + if total_tokens is None: total_tokens = prompt_tokens + completion_tokens elif cached_tokens is not None: total_tokens = max(0, total_tokens - cached_tokens) diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index 6c4019412..c03d5549b 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -7,7 +7,8 @@ from verifiers.clients.openai_chat_completions_client import ( content_to_text, get_cached_prompt_tokens, - get_usage_field, + get_first_usage_int_field, + get_usage_int_field, handle_openai_overlong_prompt, ) from verifiers.errors import ( @@ -113,22 +114,19 @@ def parse_usage(response: OpenAITextResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_usage_field(usage, "prompt_tokens") - completion_tokens = get_usage_field(usage, "completion_tokens") - if not isinstance(prompt_tokens, int) or not isinstance( - completion_tokens, int - ): - prompt_tokens = get_usage_field(usage, "input_tokens") - completion_tokens = get_usage_field(usage, "output_tokens") - total_tokens = get_usage_field(usage, "total_tokens") - if not isinstance(prompt_tokens, int) or not isinstance( - completion_tokens, int - ): + prompt_tokens = get_first_usage_int_field( + usage, "prompt_tokens", "input_tokens" + ) + completion_tokens = get_first_usage_int_field( + usage, "completion_tokens", "output_tokens" + ) + if prompt_tokens is None or completion_tokens is None: return None + total_tokens = get_usage_int_field(usage, "total_tokens") cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) - if not isinstance(total_tokens, int): + if total_tokens is None: total_tokens = prompt_tokens + completion_tokens elif cached_tokens is not None: total_tokens = max(0, total_tokens - cached_tokens) diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index 969baff97..c7acb22e0 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -53,9 +53,10 @@ def extract_usage_token_details(response: object) -> dict[str, int] | None: return None input_tokens = _as_token_count(_get_field(usage, "prompt_tokens")) - output_tokens = _as_token_count(_get_field(usage, "completion_tokens")) - if input_tokens is None and output_tokens is None: + if input_tokens is None: input_tokens = _as_token_count(_get_field(usage, "input_tokens")) + output_tokens = _as_token_count(_get_field(usage, "completion_tokens")) + if output_tokens is None: output_tokens = _as_token_count(_get_field(usage, "output_tokens")) if input_tokens is None or output_tokens is None: return None From 7829691a3513e4523897e3875b32de3d780e88a7 Mon Sep 17 00:00:00 2001 From: William Brown Date: Wed, 20 May 2026 23:20:02 -0700 Subject: [PATCH 8/9] Harden responses usage parsing --- tests/test_openai_responses_client.py | 30 ++++++++++++++++++++ verifiers/clients/openai_responses_client.py | 17 ++++++----- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/tests/test_openai_responses_client.py b/tests/test_openai_responses_client.py index bb2d68270..df6d22c14 100644 --- a/tests/test_openai_responses_client.py +++ b/tests/test_openai_responses_client.py @@ -236,6 +236,36 @@ async def test_from_native_response_parses_text_tool_usage_and_raw_output(): ) +@pytest.mark.asyncio +async def test_from_native_response_rejects_bool_usage_counts(): + native_response = SimpleNamespace( + id="resp_1", + created_at=123.0, + model="gpt-5.2", + status="completed", + incomplete_details=None, + usage={ + "input_tokens": True, + "output_tokens": 7, + "total_tokens": 8, + }, + output=[ + { + "type": "message", + "id": "msg_1", + "role": "assistant", + "status": "completed", + "content": [{"type": "output_text", "text": "hello"}], + }, + ], + ) + client = OpenAIResponsesClient(object()) + + response = await client.from_native_response(native_response) + + assert response.usage is None + + @pytest.mark.asyncio async def test_from_native_response_uses_none_content_for_tool_call_only_response(): native_response = SimpleNamespace( diff --git a/verifiers/clients/openai_responses_client.py b/verifiers/clients/openai_responses_client.py index e50ee808b..8e0f499ef 100644 --- a/verifiers/clients/openai_responses_client.py +++ b/verifiers/clients/openai_responses_client.py @@ -10,6 +10,7 @@ content_to_text, get_cached_prompt_tokens, get_usage_field, + get_usage_int_field, handle_openai_overlong_prompt, ) from verifiers.errors import EmptyModelResponseError, InvalidModelResponseError @@ -373,27 +374,25 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_usage_field(usage, "input_tokens") - completion_tokens = get_usage_field(usage, "output_tokens") - total_tokens = get_usage_field(usage, "total_tokens") + prompt_tokens = get_usage_int_field(usage, "input_tokens") + completion_tokens = get_usage_int_field(usage, "output_tokens") + total_tokens = get_usage_int_field(usage, "total_tokens") output_details = get_usage_field(usage, "output_tokens_details") reasoning_tokens = ( - get_usage_field(output_details, "reasoning_tokens") + get_usage_int_field(output_details, "reasoning_tokens") if output_details is not None else 0 ) - if not isinstance(prompt_tokens, int) or not isinstance( - completion_tokens, int - ): + if prompt_tokens is None or completion_tokens is None: return None cached_tokens = get_cached_prompt_tokens(usage) if cached_tokens is not None: prompt_tokens = max(0, prompt_tokens - cached_tokens) - if not isinstance(total_tokens, int): + if total_tokens is None: total_tokens = prompt_tokens + completion_tokens elif cached_tokens is not None: total_tokens = max(0, total_tokens - cached_tokens) - if not isinstance(reasoning_tokens, int): + if reasoning_tokens is None: reasoning_tokens = 0 return Usage( prompt_tokens=prompt_tokens, From badc2c598f82e8b71508c8b56e099169030e10fa Mon Sep 17 00:00:00 2001 From: William Brown Date: Thu, 21 May 2026 00:17:35 -0700 Subject: [PATCH 9/9] Shrink prompt cache integration --- docs/evaluation.md | 2 - docs/reference.md | 4 +- skills/evaluate-environments/SKILL.md | 1 - tests/test_client_multimodal_types.py | 27 ++ ...st_openai_chat_completions_token_client.py | 35 -- tests/test_openai_responses_client.py | 30 -- tests/test_prompt_cache_utils.py | 384 +----------------- .../clients/anthropic_messages_client.py | 8 +- verifiers/clients/client.py | 15 +- .../clients/openai_chat_completions_client.py | 61 ++- .../clients/openai_completions_client.py | 31 +- verifiers/clients/openai_responses_client.py | 32 +- verifiers/utils/prompt_cache_utils.py | 18 +- verifiers/utils/save_utils.py | 53 +-- verifiers/utils/usage_utils.py | 172 +++----- 15 files changed, 192 insertions(+), 681 deletions(-) diff --git a/docs/evaluation.md b/docs/evaluation.md index 6eff1467d..1d6cfb4ed 100644 --- a/docs/evaluation.md +++ b/docs/evaluation.md @@ -167,8 +167,6 @@ In `[[eval]]` TOML configs you can set extra headers as `headers = { ... }` and/ For per-request headers that need to vary per rollout (e.g. sticky DP-aware routing keyed off `example_id` or `trajectory_id`), use `headers_from_state = { "X-Name" = "state_key" }` and/or `header_from_state = ["X-Name: state_key", ...]` (same form as repeated `--header-from-state`). The value for each request is resolved at send time as `state[state_key]`. If unset, `X-Session-ID` defaults to `example_id`. -Provider prompt caches are managed by the upstream API. Verifiers reports provider cache hits as `cached_input_tokens` when they appear in usage data, and automatically sends Anthropic's prompt-cache hint for official Anthropic Messages endpoints. - To define equivalent replicas, add multiple `[[endpoint]]` entries with the same `endpoint_id`. Then use the alias directly: diff --git a/docs/reference.md b/docs/reference.md index 8cba2b518..1d90f4215 100644 --- a/docs/reference.md +++ b/docs/reference.md @@ -252,16 +252,14 @@ Derivations: class TokenUsage(TypedDict, total=False): input_tokens: float output_tokens: float - cached_input_tokens: float final_input_tokens: float final_output_tokens: float ``` | Field | Description | |-------|-------------| -| `input_tokens` | Sum of non-cache-hit prompt tokens across all turns. Shared uncached context is counted each time it appears in a prompt. | +| `input_tokens` | Sum of prompt tokens across all turns. Shared context is counted each time it appears in a prompt. | | `output_tokens` | Sum of completion tokens across all turns. | -| `cached_input_tokens` | Sum of prompt tokens served from provider prompt cache, when reported by the provider. | | `final_input_tokens` | Non-completion tokens in the final turn's context (system prompts, user messages, tool results, etc.). | | `final_output_tokens` | Completion tokens in the final turn's context. Equals `output_tokens` for single-turn rollouts. | diff --git a/skills/evaluate-environments/SKILL.md b/skills/evaluate-environments/SKILL.md index 5ef865fa8..b34fce774 100644 --- a/skills/evaluate-environments/SKILL.md +++ b/skills/evaluate-environments/SKILL.md @@ -73,7 +73,6 @@ url = "https://api.openai.com/v1" key = "OPENAI_API_KEY" api_client_type = "openai_responses" ``` -9. Do not ask users to configure prompt caching for normal evals. Verifiers reports provider cache hits when usage data includes them, and official Anthropic Messages endpoints receive Anthropic's prompt-cache hint automatically. ## Publish Gate Before Large Runs 1. After smoke tests pass and results look stable, proactively suggest pushing the environment to Hub before large eval sweeps or RL work. diff --git a/tests/test_client_multimodal_types.py b/tests/test_client_multimodal_types.py index d51c38262..22a6e8c6f 100644 --- a/tests/test_client_multimodal_types.py +++ b/tests/test_client_multimodal_types.py @@ -181,6 +181,33 @@ async def test_anthropic_from_native_response_extracts_usage(): assert response.usage.reasoning_tokens == 0 +@pytest.mark.asyncio +async def test_anthropic_from_native_response_extracts_cache_usage(): + from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient + + client = AnthropicMessagesClient(object()) + native_response = SimpleNamespace( + id="msg_cache", + model="claude-haiku-4-5", + stop_reason="end_turn", + content=[SimpleNamespace(type="text", text="Hello!")], + usage=SimpleNamespace( + input_tokens=42, + output_tokens=17, + cache_creation_input_tokens=8, + cache_read_input_tokens=100, + ), + ) + + response = await client.from_native_response(native_response) + + assert response.usage is not None + assert response.usage.prompt_tokens == 50 + assert response.usage.completion_tokens == 17 + assert response.usage.cached_input_tokens == 100 + assert response.usage.total_tokens == 67 + + @pytest.mark.asyncio async def test_anthropic_from_native_response_always_parses_reasoning(): pytest.importorskip("anthropic") diff --git a/tests/test_openai_chat_completions_token_client.py b/tests/test_openai_chat_completions_token_client.py index c9503f2bc..923ff118e 100644 --- a/tests/test_openai_chat_completions_token_client.py +++ b/tests/test_openai_chat_completions_token_client.py @@ -1,4 +1,3 @@ -from types import SimpleNamespace from typing import Any, cast import httpx @@ -294,37 +293,3 @@ async def fake_get_prompt_ids( # noqa: ANN001 assert len(recording_client.calls) == 1 assert recording_client.calls[0]["path"] == "/chat/completions/tokens" assert recording_client.calls[0]["body"]["tokens"] == [10, 20] - - -@pytest.mark.asyncio -async def test_from_native_response_splits_cached_input_tokens(): - client = OpenAIChatCompletionsTokenClient(_NoopClient()) - message = SimpleNamespace( - content="ok", - tool_calls=None, - model_dump=lambda: {}, - ) - native_response = SimpleNamespace( - id="resp", - created=0, - model="test-model", - usage=SimpleNamespace( - prompt_tokens=100, - completion_tokens=5, - total_tokens=105, - prompt_tokens_details=SimpleNamespace(cached_tokens=80), - ), - choices=[ - SimpleNamespace( - message=message, - finish_reason="stop", - ) - ], - ) - - response = await client.from_native_response(native_response) - - assert response.usage is not None - assert response.usage.prompt_tokens == 20 - assert response.usage.cached_input_tokens == 80 - assert response.usage.total_tokens == 25 diff --git a/tests/test_openai_responses_client.py b/tests/test_openai_responses_client.py index df6d22c14..bb2d68270 100644 --- a/tests/test_openai_responses_client.py +++ b/tests/test_openai_responses_client.py @@ -236,36 +236,6 @@ async def test_from_native_response_parses_text_tool_usage_and_raw_output(): ) -@pytest.mark.asyncio -async def test_from_native_response_rejects_bool_usage_counts(): - native_response = SimpleNamespace( - id="resp_1", - created_at=123.0, - model="gpt-5.2", - status="completed", - incomplete_details=None, - usage={ - "input_tokens": True, - "output_tokens": 7, - "total_tokens": 8, - }, - output=[ - { - "type": "message", - "id": "msg_1", - "role": "assistant", - "status": "completed", - "content": [{"type": "output_text", "text": "hello"}], - }, - ], - ) - client = OpenAIResponsesClient(object()) - - response = await client.from_native_response(native_response) - - assert response.usage is None - - @pytest.mark.asyncio async def test_from_native_response_uses_none_content_for_tool_call_only_response(): native_response = SimpleNamespace( diff --git a/tests/test_prompt_cache_utils.py b/tests/test_prompt_cache_utils.py index 04cbc7f4b..7f9560aa9 100644 --- a/tests/test_prompt_cache_utils.py +++ b/tests/test_prompt_cache_utils.py @@ -1,390 +1,26 @@ -from types import SimpleNamespace +from verifiers.types import ClientConfig +from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_kwargs -import pytest -import verifiers.v1 as vf -from verifiers.clients.anthropic_messages_client import AnthropicMessagesClient -from verifiers.clients.client import Client -from verifiers.clients.openai_chat_completions_client import OpenAIChatCompletionsClient -from verifiers.types import ClientConfig, Response, ResponseMessage -from verifiers.utils.prompt_cache_utils import ( - apply_prompt_cache_to_request, - endpoint_origin, - uses_official_anthropic_messages, -) -from verifiers.utils.save_utils import state_to_output -from verifiers.utils.usage_utils import extract_usage_token_details - - -class RecordingClient(Client): - def __init__(self, config: ClientConfig): - super().__init__(config) - self.request = {} - - def setup_client(self, config): - return object() - - async def to_native_tool(self, tool): - return tool - - async def to_native_prompt(self, messages): - return messages, {} - - async def get_native_response( - self, prompt, model, sampling_args, tools=None, **kwargs - ): - self.request = { - "prompt": prompt, - "model": model, - "sampling_args": sampling_args, - "tools": tools, - "kwargs": kwargs, - } - return object() - - async def raise_from_native_response(self, response): - _ = response - - async def from_native_response(self, response): - _ = response - return Response( - id="resp", - created=0, - model="model", - usage=None, - message=ResponseMessage( - content="ok", - finish_reason="stop", - is_truncated=False, - ), - ) - - async def close(self) -> None: - pass - - -def test_endpoint_origin_normalizes_urls(): - assert ( - endpoint_origin("https://api.anthropic.com/v1") == "https://api.anthropic.com" - ) - assert endpoint_origin("https://api.anthropic.com:443/v1") == ( - "https://api.anthropic.com" - ) - assert endpoint_origin("http://localhost:8080/v1") == "http://localhost:8080" - assert endpoint_origin("http://[::1]:8080/v1") == "http://[::1]:8080" - - -def test_official_anthropic_messages_endpoint_is_cache_control_target(): - assert uses_official_anthropic_messages( - ClientConfig( - client_type="anthropic_messages", - api_base_url="https://api.anthropic.com", - ) - ) - assert not uses_official_anthropic_messages( - ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.anthropic.com", - ) - ) - assert not uses_official_anthropic_messages( - ClientConfig( - client_type="anthropic_messages", - api_base_url="https://api.pinference.ai/api/v1", - ) - ) - - -def test_anthropic_request_adds_top_level_cache_control(): - native_prompt, native_tools, sampling_args, extra_kwargs = ( - apply_prompt_cache_to_request( - config=ClientConfig( - client_type="anthropic_messages", - api_base_url="https://api.anthropic.com", - ), - model="claude-sonnet-4-5", - native_prompt=[{"role": "user", "content": "question"}], - native_tools=None, - sampling_args={"max_tokens": 16}, - extra_kwargs={}, - ) - ) - - assert native_prompt == [{"role": "user", "content": "question"}] - assert native_tools is None - assert sampling_args == {"max_tokens": 16} - assert extra_kwargs["cache_control"] == {"type": "ephemeral"} - - -def test_anthropic_request_preserves_sampling_args_cache_control(): - _, _, sampling_args, extra_kwargs = apply_prompt_cache_to_request( +def test_anthropic_cache_control_hint_is_default_only(): + extra_kwargs = apply_prompt_cache_to_kwargs( config=ClientConfig( client_type="anthropic_messages", - api_base_url="https://api.anthropic.com", + api_base_url="https://api.anthropic.com/v1", ), - model="claude-sonnet-4-5", - native_prompt=[], - native_tools=None, - sampling_args={ - "max_tokens": 16, - "cache_control": {"type": "custom"}, - }, + sampling_args={"max_tokens": 16}, extra_kwargs={}, ) - assert sampling_args == { - "max_tokens": 16, - "cache_control": {"type": "custom"}, - } - assert extra_kwargs == {} - - -def test_openai_request_does_not_mutate_request(): - native_prompt, native_tools, sampling_args, extra_kwargs = ( - apply_prompt_cache_to_request( - config=ClientConfig( - client_type="openai_chat_completions", - api_base_url="https://api.openai.com/v1", - ), - model="gpt-5.4-mini", - native_prompt=[{"role": "user", "content": "question"}], - native_tools=None, - sampling_args={"max_tokens": 16}, - extra_kwargs={}, - ) - ) - - assert native_prompt == [{"role": "user", "content": "question"}] - assert native_tools is None - assert sampling_args == {"max_tokens": 16} - assert extra_kwargs == {} + assert extra_kwargs == {"cache_control": {"type": "ephemeral"}} - -def test_non_official_anthropic_endpoint_does_not_add_cache_control(): - _, _, sampling_args, extra_kwargs = apply_prompt_cache_to_request( + extra_kwargs = apply_prompt_cache_to_kwargs( config=ClientConfig( client_type="anthropic_messages", - api_base_url="https://api.pinference.ai/api/v1", + api_base_url="https://api.anthropic.com/v1", ), - model="claude-sonnet-4-5", - native_prompt=[], - native_tools=None, - sampling_args={"max_tokens": 16}, + sampling_args={"cache_control": {"type": "custom"}}, extra_kwargs={}, ) - assert sampling_args == {"max_tokens": 16} assert extra_kwargs == {} - - -@pytest.mark.asyncio -async def test_client_request_hook_applies_anthropic_cache_control(): - client = RecordingClient( - ClientConfig( - client_type="anthropic_messages", - api_base_url="https://api.anthropic.com", - ) - ) - - await client.get_response( - prompt=[], - model="claude-sonnet-4-5", - sampling_args={"max_tokens": 16}, - ) - - assert client.request["kwargs"]["cache_control"] == {"type": "ephemeral"} - - -@pytest.mark.asyncio -async def test_openai_usage_splits_cached_input_tokens(): - client = OpenAIChatCompletionsClient(object()) - message = SimpleNamespace( - content="ok", - tool_calls=None, - model_dump=lambda: {}, - ) - native_response = SimpleNamespace( - id="resp", - created=0, - model="gpt-5.4-mini", - usage=SimpleNamespace( - prompt_tokens=100, - completion_tokens=5, - total_tokens=105, - prompt_tokens_details=SimpleNamespace( - cached_tokens=80, - cache_write_tokens=10, - ), - ), - choices=[ - SimpleNamespace( - message=message, - finish_reason="stop", - ) - ], - ) - - response = await client.from_native_response(native_response) - - assert response.usage is not None - assert response.usage.prompt_tokens == 20 - assert response.usage.cached_input_tokens == 80 - assert response.usage.total_tokens == 25 - - -@pytest.mark.asyncio -async def test_openai_usage_handles_mixed_token_field_names(): - client = OpenAIChatCompletionsClient(object()) - message = SimpleNamespace( - content="ok", - tool_calls=None, - model_dump=lambda: {}, - ) - native_response = SimpleNamespace( - id="resp", - created=0, - model="gpt-5.4-mini", - usage=SimpleNamespace( - prompt_tokens=100, - output_tokens=5, - total_tokens=105, - prompt_tokens_details=SimpleNamespace(cached_tokens=80), - ), - choices=[ - SimpleNamespace( - message=message, - finish_reason="stop", - ) - ], - ) - - response = await client.from_native_response(native_response) - - assert response.usage is not None - assert response.usage.prompt_tokens == 20 - assert response.usage.completion_tokens == 5 - assert response.usage.cached_input_tokens == 80 - assert response.usage.total_tokens == 25 - - -@pytest.mark.asyncio -async def test_anthropic_usage_splits_cache_read_and_write_tokens(): - client = AnthropicMessagesClient(object()) - native_response = SimpleNamespace( - id="resp", - model="claude-sonnet-4-5", - stop_reason="end_turn", - content=[SimpleNamespace(type="text", text="ok")], - usage=SimpleNamespace( - input_tokens=5, - output_tokens=7, - cache_read_input_tokens=80, - cache_creation_input_tokens=10, - ), - ) - - response = await client.from_native_response(native_response) - - assert response.usage is not None - assert response.usage.prompt_tokens == 15 - assert response.usage.cached_input_tokens == 80 - assert response.usage.total_tokens == 22 - - -def test_native_anthropic_cache_creation_counts_as_uncached_input(): - response = SimpleNamespace( - usage=SimpleNamespace( - input_tokens=5, - output_tokens=7, - cache_read_input_tokens=80, - cache_creation_input_tokens=10, - ) - ) - - assert extract_usage_token_details(response) == { - "input_tokens": 15, - "output_tokens": 7, - "cached_input_tokens": 80, - } - - -def test_serialized_response_usage_counts_cache_details(): - response = { - "usage": { - "prompt_tokens": 100, - "completion_tokens": 7, - "prompt_tokens_details": { - "cached_tokens": 80, - "cache_write_tokens": 10, - }, - } - } - - assert extract_usage_token_details(response) == { - "input_tokens": 20, - "output_tokens": 7, - "cached_input_tokens": 80, - } - - -def test_serialized_responses_usage_counts_input_token_cache_details(): - response = { - "usage": { - "input_tokens": 100, - "output_tokens": 7, - "input_tokens_details": { - "cached_tokens": 80, - }, - } - } - - assert extract_usage_token_details(response) == { - "input_tokens": 20, - "output_tokens": 7, - "cached_input_tokens": 80, - } - - -def test_serialized_usage_handles_mixed_token_field_names(): - response = { - "usage": { - "prompt_tokens": 100, - "output_tokens": 7, - "prompt_tokens_details": { - "cached_tokens": 80, - }, - } - } - - assert extract_usage_token_details(response) == { - "input_tokens": 20, - "output_tokens": 7, - "cached_input_tokens": 80, - } - - -def test_state_output_fallback_reads_serialized_trajectory_usage(): - task = vf.Task( - { - "example_id": 0, - "prompt": [{"role": "user", "content": "q"}], - } - ).freeze() - state = vf.State.for_task(task) - state["trajectory"] = [ - { - "response": { - "usage": { - "prompt_tokens": 100, - "completion_tokens": 7, - "prompt_tokens_details": {"cached_tokens": 80}, - } - } - } - ] - - output = state_to_output(state) - - assert output["token_usage"]["input_tokens"] == 20.0 - assert output["token_usage"]["cached_input_tokens"] == 80.0 - assert output["token_usage"]["output_tokens"] == 7.0 diff --git a/verifiers/clients/anthropic_messages_client.py b/verifiers/clients/anthropic_messages_client.py index 07f7e9ac3..71258822b 100644 --- a/verifiers/clients/anthropic_messages_client.py +++ b/verifiers/clients/anthropic_messages_client.py @@ -472,9 +472,13 @@ def parse_finish_reason(response: AnthropicMessage) -> FinishReason: cache_creation_input_tokens = getattr( response.usage, "cache_creation_input_tokens", None ) - if isinstance(cache_creation_input_tokens, int): + if isinstance(cache_creation_input_tokens, int) and not isinstance( + cache_creation_input_tokens, bool + ): input_tokens += cache_creation_input_tokens - if not isinstance(cached_input_tokens, int): + if not isinstance(cached_input_tokens, int) or isinstance( + cached_input_tokens, bool + ): cached_input_tokens = None return Response( diff --git a/verifiers/clients/client.py b/verifiers/clients/client.py index 5da565c12..b98806faf 100644 --- a/verifiers/clients/client.py +++ b/verifiers/clients/client.py @@ -19,7 +19,7 @@ SamplingArgs, Tool, ) -from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_request +from verifiers.utils.prompt_cache_utils import apply_prompt_cache_to_kwargs if TYPE_CHECKING: pass @@ -127,15 +127,10 @@ async def get_response( native_prompt, extra_kwargs = await self.to_native_prompt(prompt) native_tools = await self.to_native_tools(tools) - native_prompt, native_tools, sampling_args, extra_kwargs = ( - apply_prompt_cache_to_request( - config=self._config, - model=model, - native_prompt=native_prompt, - native_tools=native_tools, - sampling_args=sampling_args, - extra_kwargs=extra_kwargs, - ) + extra_kwargs = apply_prompt_cache_to_kwargs( + config=self._config, + sampling_args=sampling_args, + extra_kwargs=extra_kwargs, ) native_response = await self.get_native_response( native_prompt, diff --git a/verifiers/clients/openai_chat_completions_client.py b/verifiers/clients/openai_chat_completions_client.py index f55188ec3..a8915285d 100644 --- a/verifiers/clients/openai_chat_completions_client.py +++ b/verifiers/clients/openai_chat_completions_client.py @@ -99,30 +99,6 @@ def get_usage_field(usage: Any, key: str) -> Any: return getattr(usage, key, None) -def get_usage_int_field(usage: Any, key: str) -> int | None: - value = get_usage_field(usage, key) - if isinstance(value, int) and not isinstance(value, bool): - return value - return None - - -def get_first_usage_int_field(usage: Any, *keys: str) -> int | None: - for key in keys: - value = get_usage_int_field(usage, key) - if value is not None: - return value - return None - - -def get_cached_prompt_tokens(usage: Any) -> int | None: - details = get_usage_field(usage, "prompt_tokens_details") - if details is None: - details = get_usage_field(usage, "input_tokens_details") - if details is None: - return None - return get_usage_int_field(details, "cached_tokens") - - def content_to_text(content: Any) -> str: """Get all text content from OAI message content.""" if isinstance(content, str): @@ -435,19 +411,32 @@ def parse_usage(response: OpenAIChatResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_first_usage_int_field( - usage, "prompt_tokens", "input_tokens" - ) - completion_tokens = get_first_usage_int_field( - usage, "completion_tokens", "output_tokens" - ) - if prompt_tokens is None or completion_tokens is None: + prompt_tokens = get_usage_field(usage, "prompt_tokens") + completion_tokens = get_usage_field(usage, "completion_tokens") + if not isinstance(prompt_tokens, int) or not isinstance( + completion_tokens, int + ): + prompt_tokens = get_usage_field(usage, "input_tokens") + completion_tokens = get_usage_field(usage, "output_tokens") + total_tokens = get_usage_field(usage, "total_tokens") + if not isinstance(prompt_tokens, int) or not isinstance( + completion_tokens, int + ): return None - total_tokens = get_usage_int_field(usage, "total_tokens") - cached_tokens = get_cached_prompt_tokens(usage) - if cached_tokens is not None: - prompt_tokens = max(0, prompt_tokens - cached_tokens) - if total_tokens is None: + prompt_details = get_usage_field(usage, "prompt_tokens_details") + if prompt_details is None: + prompt_details = get_usage_field(usage, "input_tokens_details") + cached_tokens = None + if prompt_details is not None: + reported_cached_tokens = get_usage_field( + prompt_details, "cached_tokens" + ) + if isinstance(reported_cached_tokens, int) and not isinstance( + reported_cached_tokens, bool + ): + cached_tokens = reported_cached_tokens + prompt_tokens = max(0, prompt_tokens - cached_tokens) + if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens elif cached_tokens is not None: total_tokens = max(0, total_tokens - cached_tokens) diff --git a/verifiers/clients/openai_completions_client.py b/verifiers/clients/openai_completions_client.py index c03d5549b..137e7a49a 100644 --- a/verifiers/clients/openai_completions_client.py +++ b/verifiers/clients/openai_completions_client.py @@ -6,9 +6,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, - get_cached_prompt_tokens, - get_first_usage_int_field, - get_usage_int_field, + get_usage_field, handle_openai_overlong_prompt, ) from verifiers.errors import ( @@ -114,28 +112,25 @@ def parse_usage(response: OpenAITextResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_first_usage_int_field( - usage, "prompt_tokens", "input_tokens" - ) - completion_tokens = get_first_usage_int_field( - usage, "completion_tokens", "output_tokens" - ) - if prompt_tokens is None or completion_tokens is None: + prompt_tokens = get_usage_field(usage, "prompt_tokens") + completion_tokens = get_usage_field(usage, "completion_tokens") + if not isinstance(prompt_tokens, int) or not isinstance( + completion_tokens, int + ): + prompt_tokens = get_usage_field(usage, "input_tokens") + completion_tokens = get_usage_field(usage, "output_tokens") + total_tokens = get_usage_field(usage, "total_tokens") + if not isinstance(prompt_tokens, int) or not isinstance( + completion_tokens, int + ): return None - total_tokens = get_usage_int_field(usage, "total_tokens") - cached_tokens = get_cached_prompt_tokens(usage) - if cached_tokens is not None: - prompt_tokens = max(0, prompt_tokens - cached_tokens) - if total_tokens is None: + if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens - elif cached_tokens is not None: - total_tokens = max(0, total_tokens - cached_tokens) return Usage( prompt_tokens=prompt_tokens, reasoning_tokens=0, completion_tokens=completion_tokens, total_tokens=total_tokens, - cached_input_tokens=cached_tokens, ) def parse_finish_reason(response: OpenAITextResponse) -> FinishReason: diff --git a/verifiers/clients/openai_responses_client.py b/verifiers/clients/openai_responses_client.py index 8e0f499ef..8555a09dd 100644 --- a/verifiers/clients/openai_responses_client.py +++ b/verifiers/clients/openai_responses_client.py @@ -8,9 +8,7 @@ from verifiers.clients.client import Client from verifiers.clients.openai_chat_completions_client import ( content_to_text, - get_cached_prompt_tokens, get_usage_field, - get_usage_int_field, handle_openai_overlong_prompt, ) from verifiers.errors import EmptyModelResponseError, InvalidModelResponseError @@ -374,25 +372,35 @@ def parse_usage(response: OpenAIResponsesNativeResponse) -> Usage | None: usage = getattr(response, "usage", None) if usage is None: return None - prompt_tokens = get_usage_int_field(usage, "input_tokens") - completion_tokens = get_usage_int_field(usage, "output_tokens") - total_tokens = get_usage_int_field(usage, "total_tokens") + prompt_tokens = get_usage_field(usage, "input_tokens") + completion_tokens = get_usage_field(usage, "output_tokens") + total_tokens = get_usage_field(usage, "total_tokens") output_details = get_usage_field(usage, "output_tokens_details") reasoning_tokens = ( - get_usage_int_field(output_details, "reasoning_tokens") + get_usage_field(output_details, "reasoning_tokens") if output_details is not None else 0 ) - if prompt_tokens is None or completion_tokens is None: + if not isinstance(prompt_tokens, int) or not isinstance( + completion_tokens, int + ): return None - cached_tokens = get_cached_prompt_tokens(usage) - if cached_tokens is not None: - prompt_tokens = max(0, prompt_tokens - cached_tokens) - if total_tokens is None: + input_details = get_usage_field(usage, "input_tokens_details") + if input_details is None: + input_details = get_usage_field(usage, "prompt_tokens_details") + cached_tokens = None + if input_details is not None: + reported_cached_tokens = get_usage_field(input_details, "cached_tokens") + if isinstance(reported_cached_tokens, int) and not isinstance( + reported_cached_tokens, bool + ): + cached_tokens = reported_cached_tokens + prompt_tokens = max(0, prompt_tokens - cached_tokens) + if not isinstance(total_tokens, int): total_tokens = prompt_tokens + completion_tokens elif cached_tokens is not None: total_tokens = max(0, total_tokens - cached_tokens) - if reasoning_tokens is None: + if not isinstance(reasoning_tokens, int): reasoning_tokens = 0 return Usage( prompt_tokens=prompt_tokens, diff --git a/verifiers/utils/prompt_cache_utils.py b/verifiers/utils/prompt_cache_utils.py index 9e3676bba..822216a88 100644 --- a/verifiers/utils/prompt_cache_utils.py +++ b/verifiers/utils/prompt_cache_utils.py @@ -1,12 +1,9 @@ from collections.abc import Mapping -from typing import Any, TypeVar +from typing import Any from urllib.parse import urlsplit from verifiers.types import ClientConfig -NativePromptT = TypeVar("NativePromptT") -NativeToolsT = TypeVar("NativeToolsT") - ANTHROPIC_ORIGINS = frozenset({"https://api.anthropic.com"}) @@ -39,21 +36,16 @@ def _cache_control_payload() -> dict[str, str]: return {"type": "ephemeral"} -def apply_prompt_cache_to_request( +def apply_prompt_cache_to_kwargs( *, config: ClientConfig | None, - model: str, - native_prompt: NativePromptT, - native_tools: NativeToolsT, sampling_args: Mapping[str, Any], extra_kwargs: Mapping[str, Any], -) -> tuple[NativePromptT, NativeToolsT, dict[str, Any], dict[str, Any]]: - _ = model - updated_sampling_args = dict(sampling_args) +) -> dict[str, Any]: updated_extra_kwargs = dict(extra_kwargs) if ( uses_official_anthropic_messages(config) - and "cache_control" not in updated_sampling_args + and "cache_control" not in sampling_args ): updated_extra_kwargs.setdefault("cache_control", _cache_control_payload()) - return native_prompt, native_tools, updated_sampling_args, updated_extra_kwargs + return updated_extra_kwargs diff --git a/verifiers/utils/save_utils.py b/verifiers/utils/save_utils.py index 45240f489..a70e86611 100644 --- a/verifiers/utils/save_utils.py +++ b/verifiers/utils/save_utils.py @@ -16,6 +16,7 @@ ErrorInfo, GenerateMetadata, GenerateOutputs, + Response, RolloutOutput, SamplingArgs, State, @@ -41,8 +42,7 @@ from verifiers.utils.path_utils import get_results_path from verifiers.utils.usage_utils import ( StateUsageTracker, - cast_token_usage, - extract_usage_token_details, + response_usage_tokens, ) from verifiers.utils.version_utils import get_version_info @@ -129,39 +129,46 @@ def _token_usage_from_mapping(value: object, context: str) -> TokenUsage | None: for key in ("final_input_tokens", "final_output_tokens"): if key in mapping_value and mapping_value[key] is not None: usage[key] = _token_count(mapping_value[key], f"{context}.{key}") - if "cached_input_tokens" in mapping_value: - cached_input = mapping_value["cached_input_tokens"] - if cached_input is not None: - usage["cached_input_tokens"] = _token_count( - cached_input, f"{context}.cached_input_tokens" - ) + if ( + "cached_input_tokens" in mapping_value + and mapping_value["cached_input_tokens"] is not None + ): + usage["cached_input_tokens"] = _token_count( + mapping_value["cached_input_tokens"], f"{context}.cached_input_tokens" + ) return usage def _token_usage_from_trajectory(trajectory: object) -> TokenUsage | None: if not isinstance(trajectory, list): return None - usage_totals: dict[str, float] = { - "input_tokens": 0.0, - "output_tokens": 0.0, - } + input_tokens = 0 + output_tokens = 0 + cached_input_tokens = 0 usage_seen = False for index, step in enumerate(trajectory): if not isinstance(step, Mapping): raise TypeError(f"state.trajectory[{index}] must be a mapping.") step_mapping = cast(Mapping[str, object], step) response = step_mapping.get("response") - if response is None: + if response is None or not isinstance(response, Response): continue - details = extract_usage_token_details(response) - if details is None: + if response.usage is None: continue usage_seen = True - for key, value in details.items(): - usage_totals[key] = usage_totals.get(key, 0.0) + float(value) + step_input_tokens, step_output_tokens = response_usage_tokens(response) + input_tokens += step_input_tokens + output_tokens += step_output_tokens + cached_input_tokens += response.usage.cached_input_tokens or 0 if not usage_seen: return None - return cast_token_usage(usage_totals) + usage = TokenUsage( + input_tokens=float(input_tokens), + output_tokens=float(output_tokens), + ) + if cached_input_tokens > 0: + usage["cached_input_tokens"] = float(cached_input_tokens) + return usage def _extract_state_token_usage(state: State) -> TokenUsage | None: @@ -233,13 +240,11 @@ def state_to_output( usage = _extract_state_token_usage(state) if usage is not None: token_usage: dict[str, float] = { - "input_tokens": usage["input_tokens"], - "output_tokens": usage["output_tokens"], + "input_tokens": usage.get("input_tokens", 0.0), + "output_tokens": usage.get("output_tokens", 0.0), } - for key in ("cached_input_tokens", "final_input_tokens", "final_output_tokens"): - value = usage.get(key) - if value is not None: - token_usage[key] = value + if usage.get("cached_input_tokens") is not None: + token_usage["cached_input_tokens"] = usage["cached_input_tokens"] # Add context token metrics from trajectory trajectory = state.get("trajectory", []) if isinstance(trajectory, list): diff --git a/verifiers/utils/usage_utils.py b/verifiers/utils/usage_utils.py index c7acb22e0..d6ffe5f3d 100644 --- a/verifiers/utils/usage_utils.py +++ b/verifiers/utils/usage_utils.py @@ -1,90 +1,9 @@ from collections.abc import Mapping, Sequence from types import MappingProxyType -from typing import Any, cast from verifiers.types import Response, TokenUsage, Usage -def _get_field(obj: object, key: str) -> object: - if isinstance(obj, Mapping): - return cast(Mapping[str, object], obj).get(key) - return getattr(obj, key, None) - - -def _as_token_count(value: object) -> int | None: - if isinstance(value, bool): - return None - if isinstance(value, int): - return max(0, value) - if isinstance(value, float) and value.is_integer(): - return max(0, int(value)) - return None - - -def _response_usage(response: object) -> object | None: - return _get_field(response, "usage") - - -def _nested_cached_tokens(usage: object) -> int | None: - for details_key in ("prompt_tokens_details", "input_tokens_details"): - details = _get_field(usage, details_key) - if details is None: - continue - cached = _as_token_count(_get_field(details, "cached_tokens")) - if cached is not None: - return cached - return None - - -def _cache_creation_tokens(usage: object) -> int: - return _as_token_count(_get_field(usage, "cache_creation_input_tokens")) or 0 - - -def _direct_cached_tokens(usage: object) -> int | None: - cached = _as_token_count(_get_field(usage, "cached_input_tokens")) - if cached is not None: - return cached - return _as_token_count(_get_field(usage, "cache_read_input_tokens")) - - -def extract_usage_token_details(response: object) -> dict[str, int] | None: - usage = _response_usage(response) - if usage is None: - return None - - input_tokens = _as_token_count(_get_field(usage, "prompt_tokens")) - if input_tokens is None: - input_tokens = _as_token_count(_get_field(usage, "input_tokens")) - output_tokens = _as_token_count(_get_field(usage, "completion_tokens")) - if output_tokens is None: - output_tokens = _as_token_count(_get_field(usage, "output_tokens")) - if input_tokens is None or output_tokens is None: - return None - - input_tokens += _cache_creation_tokens(usage) - details = { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - } - - cached_tokens = _direct_cached_tokens(usage) - if cached_tokens is not None: - details["cached_input_tokens"] = cached_tokens - return details - - cached_tokens = _nested_cached_tokens(usage) - if cached_tokens is not None: - details["input_tokens"] = max(0, input_tokens - cached_tokens) - details["cached_input_tokens"] = cached_tokens - return details - - -def usage_tokens(usage: Usage) -> tuple[int, int]: - if usage.prompt_tokens < 0 or usage.completion_tokens < 0: - raise ValueError("Response usage tokens must be non-negative.") - return usage.prompt_tokens, usage.completion_tokens - - def response_usage_tokens(response: Response) -> tuple[int, int]: usage = response.usage if usage is None: @@ -92,15 +11,10 @@ def response_usage_tokens(response: Response) -> tuple[int, int]: return usage_tokens(usage) -def cast_token_usage(usage: Mapping[str, Any]) -> TokenUsage: - out: TokenUsage = { - "input_tokens": float(usage.get("input_tokens", 0.0)), - "output_tokens": float(usage.get("output_tokens", 0.0)), - } - cached = usage.get("cached_input_tokens") - if cached is not None: - out["cached_input_tokens"] = float(cached) - return out +def usage_tokens(usage: Usage) -> tuple[int, int]: + if usage.prompt_tokens < 0 or usage.completion_tokens < 0: + raise ValueError("Response usage tokens must be non-negative.") + return usage.prompt_tokens, usage.completion_tokens class StateUsageTracker: @@ -128,34 +42,42 @@ def increment( cached_input_tokens: int | float | None = None, mark_seen: bool = True, ) -> None: - deltas: dict[str, float] = { - "input_tokens": float(input_tokens or 0.0), - "output_tokens": float(output_tokens or 0.0), - } - if cached_input_tokens is not None: - deltas["cached_input_tokens"] = float(cached_input_tokens or 0.0) - if any(delta < 0 for delta in deltas.values()): + input_delta = float(input_tokens or 0.0) + output_delta = float(output_tokens or 0.0) + cached_input_delta = float(cached_input_tokens or 0.0) + if input_delta < 0 or output_delta < 0 or cached_input_delta < 0: raise ValueError("Token usage increments must be non-negative.") if mark_seen: self._usage_seen = True - for key, delta in deltas.items(): - self._usage_totals[key] = self._usage_totals.get(key, 0.0) + delta + self._usage_totals["input_tokens"] += input_delta + self._usage_totals["output_tokens"] += output_delta + if cached_input_tokens is not None: + self._usage_totals["cached_input_tokens"] = ( + self._usage_totals.get("cached_input_tokens", 0.0) + cached_input_delta + ) - def increment_from_response(self, response: object) -> None: - details = extract_usage_token_details(response) - if details is None: + def increment_from_response(self, response: Response) -> None: + if response.usage is None: return + input_tokens, output_tokens = response_usage_tokens(response) self.increment( - details["input_tokens"], - details["output_tokens"], - cached_input_tokens=details.get("cached_input_tokens"), + input_tokens, + output_tokens, + cached_input_tokens=response.usage.cached_input_tokens, mark_seen=True, ) def snapshot(self) -> TokenUsage | None: if not self._usage_seen: return None - return cast_token_usage(self._usage_totals) + usage: TokenUsage = { + "input_tokens": self._usage_totals["input_tokens"], + "output_tokens": self._usage_totals["output_tokens"], + } + cached_input_tokens = self._usage_totals.get("cached_input_tokens", 0.0) + if cached_input_tokens > 0: + usage["cached_input_tokens"] = cached_input_tokens + return usage def compute_context_token_metrics( @@ -170,37 +92,45 @@ def compute_context_token_metrics( Returns a dict with: final_output_tokens: Model-generated tokens (sum of completion_tokens across all steps). - final_input_tokens: Non-model tokens in context (system prompts, user - messages, tool results, etc.). + final_input_tokens: Non-model tokens in context (last step's total + context minus final_output_tokens). """ - zero = {"final_output_tokens": 0.0, "final_input_tokens": 0.0} + _zero: dict[str, float] = { + "final_output_tokens": 0, + "final_input_tokens": 0, + } if not trajectory: - return zero + return _zero + # Find the last step with usage data. last_step_total = 0 found = False for step in reversed(trajectory): - details = extract_usage_token_details(step.get("response")) - if details is None: + response = step.get("response") + if not isinstance(response, Response) or response.usage is None: continue + prompt_tokens, completion_tokens = response_usage_tokens(response) last_step_total = ( - details["input_tokens"] - + details.get("cached_input_tokens", 0) - + details["output_tokens"] + prompt_tokens + + (response.usage.cached_input_tokens or 0) + + completion_tokens ) found = True break if not found: - return zero + return _zero + # Sum completion tokens across all steps with usage data. total_completion = 0 for step in trajectory: - details = extract_usage_token_details(step.get("response")) - if details is not None: - total_completion += details["output_tokens"] + response = step.get("response") + if not isinstance(response, Response) or response.usage is None: + continue + _, completion_tokens = response_usage_tokens(response) + total_completion += completion_tokens return { - "final_output_tokens": float(total_completion), - "final_input_tokens": float(max(0, last_step_total - total_completion)), + "final_output_tokens": total_completion, + "final_input_tokens": max(0, last_step_total - total_completion), }