introduce cached tokens to usage #1133

Closed

9 changes: 7 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -441,12 +441,17 @@ def _map_usage(message: AnthropicMessage | RawMessageStreamEvent) -> usage.Usage:
         return usage.Usage()
 
     request_tokens = getattr(response_usage, 'input_tokens', None)
+    cache_creation_input_tokens = getattr(response_usage, 'cache_creation_input_tokens', None)
+    cache_read_input_tokens = getattr(response_usage, 'cache_read_input_tokens', None)
+
+    total_request_tokens = (request_tokens or 0) + (cache_creation_input_tokens or 0) + (cache_read_input_tokens or 0)
 
     return usage.Usage(
         # Usage coming from the RawMessageDeltaEvent doesn't have input token data, hence this getattr
-        request_tokens=request_tokens,
+        request_tokens=total_request_tokens,
         response_tokens=response_usage.output_tokens,
-        total_tokens=(request_tokens or 0) + response_usage.output_tokens,
+        cached_tokens=cache_read_input_tokens,
+        total_tokens=total_request_tokens + response_usage.output_tokens,
     )


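As a sanity check on the new arithmetic, here is a minimal sketch in plain Python of what the Anthropic mapping now reports when a request hits the prompt cache. The token counts are invented for illustration; Anthropic reports cache writes and reads separately from `input_tokens`, which is why the three counters are summed.

# Invented example: 100 uncached input tokens, 1800 tokens read from the cache, 50 output tokens.
request_tokens = 100
cache_creation_input_tokens = 0
cache_read_input_tokens = 1800
output_tokens = 50

total_request_tokens = (request_tokens or 0) + (cache_creation_input_tokens or 0) + (cache_read_input_tokens or 0)

assert total_request_tokens == 1900                  # becomes Usage.request_tokens
assert cache_read_input_tokens == 1800               # becomes Usage.cached_tokens
assert total_request_tokens + output_tokens == 1950  # becomes Usage.total_tokens
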
1 change: 1 addition & 0 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -769,6 +769,7 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
         total_tokens=metadata.get('total_token_count', 0),
+        cached_tokens=metadata.get('cached_content_token_count', 0),
         details=details,
     )

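The Gemini side is a one-line change because `usage_metadata` already carries a dedicated counter. A tiny sketch of the mapping against this branch, with an invented metadata payload:

from pydantic_ai import usage

metadata = {
    'prompt_token_count': 210,
    'candidates_token_count': 12,
    'total_token_count': 222,
    'cached_content_token_count': 200,
}

# Mirrors _metadata_as_usage: the cached counter feeds the new field directly.
u = usage.Usage(
    request_tokens=metadata.get('prompt_token_count', 0),
    response_tokens=metadata.get('candidates_token_count', 0),
    total_tokens=metadata.get('total_token_count', 0),
    cached_tokens=metadata.get('cached_content_token_count', 0),
)
assert u.cached_tokens == 200
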
7 changes: 7 additions & 0 deletions pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -508,9 +508,16 @@ def _map_usage(response: chat.ChatCompletion | ChatCompletionChunk) -> usage.Usage:
             details.update(response_usage.completion_tokens_details.model_dump(exclude_none=True))
         if response_usage.prompt_tokens_details is not None:
             details.update(response_usage.prompt_tokens_details.model_dump(exclude_none=True))
+        cached_tokens = 0
+        if (
+            response_usage.prompt_tokens_details is not None
+            and response_usage.prompt_tokens_details.cached_tokens is not None
+        ):
+            cached_tokens = response_usage.prompt_tokens_details.cached_tokens
         return usage.Usage(
             request_tokens=response_usage.prompt_tokens,
             response_tokens=response_usage.completion_tokens,
+            cached_tokens=cached_tokens,
             total_tokens=response_usage.total_tokens,
             details=details,
         )
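
For OpenAI the cache hit count lives in `prompt_tokens_details`. Below is a small sketch of the extraction, assuming the openai SDK's usage models in `openai.types.completion_usage` and invented numbers; no summation is needed here because `prompt_tokens` already includes the cached portion.

from openai.types.completion_usage import CompletionUsage, PromptTokensDetails

response_usage = CompletionUsage(
    prompt_tokens=1000,
    completion_tokens=20,
    total_tokens=1020,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=900),
)

cached_tokens = 0
if response_usage.prompt_tokens_details is not None and response_usage.prompt_tokens_details.cached_tokens is not None:
    cached_tokens = response_usage.prompt_tokens_details.cached_tokens

assert cached_tokens == 900  # becomes Usage.cached_tokens; prompt_tokens (1000) stays Usage.request_tokens
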
5 changes: 4 additions & 1 deletion pydantic_ai_slim/pydantic_ai/usage.py
@@ -23,6 +23,8 @@ class Usage:
     """Tokens used in processing requests."""
     response_tokens: int | None = None
     """Tokens used in generating responses."""
+    cached_tokens: int | None = None
+    """Number of input tokens that were a cache hit."""
     total_tokens: int | None = None
     """Total tokens used in the whole run, should generally be equal to `request_tokens + response_tokens`."""
     details: dict[str, int] | None = None
@@ -36,7 +38,7 @@ def incr(self, incr_usage: Usage, *, requests: int = 0) -> None:
             requests: The number of requests to increment by in addition to `incr_usage.requests`.
         """
         self.requests += requests
-        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens':
+        for f in 'requests', 'request_tokens', 'response_tokens', 'total_tokens', 'cached_tokens':
             self_value = getattr(self, f)
             other_value = getattr(incr_usage, f)
             if self_value is not None or other_value is not None:
@@ -61,6 +63,7 @@ def opentelemetry_attributes(self) -> dict[str, int]:
         result = {
             'gen_ai.usage.input_tokens': self.request_tokens,
             'gen_ai.usage.output_tokens': self.response_tokens,
+            'gen_ai.usage.cached_tokens': self.cached_tokens,
         }
         for key, value in (self.details or {}).items():
             result[f'gen_ai.usage.details.{key}'] = value
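
Putting the `usage.py` changes together, here is a rough end-to-end sketch of how the new field accumulates across requests when run against this branch (numbers invented):

from pydantic_ai.usage import Usage

total = Usage()
# First request misses the cache, the second reads 900 tokens from it.
total.incr(Usage(request_tokens=1000, response_tokens=20, cached_tokens=0, total_tokens=1020), requests=1)
total.incr(Usage(request_tokens=1000, response_tokens=25, cached_tokens=900, total_tokens=1025), requests=1)

assert total.requests == 2
assert total.request_tokens == 2000
assert total.cached_tokens == 900  # accumulated like the other counters
assert total.opentelemetry_attributes()['gen_ai.usage.cached_tokens'] == 900
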
2 changes: 1 addition & 1 deletion tests/graph/test_graph.py
@@ -393,7 +393,7 @@ async def run(self, ctx: GraphRunContext) -> End[None]:

assert isinstance(n, BaseNode)
n = await run.next()
assert n == snapshot(End(None))
assert n == snapshot(End(data=None))

with pytest.raises(TypeError, match=r'`next` must be called with a `BaseNode` instance, got End\(data=None\).'):
await run.next()
2 changes: 1 addition & 1 deletion tests/graph/test_persistence.py
@@ -287,7 +287,7 @@ async def run(self, ctx: GraphRunContext) -> End[int]:
node = Foo()
async with graph.iter(node, persistence=sp) as run:
end = await run.next()
assert end == snapshot(End(123))
assert end == snapshot(End(data=123))

msg = "Incorrect snapshot status 'success', must be 'created' or 'pending'."
with pytest.raises(GraphNodeStatusError, match=msg):
28 changes: 21 additions & 7 deletions tests/models/test_gemini.py
@@ -467,7 +467,9 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
),
]
)
assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=1, response_tokens=2, cached_tokens=0, total_tokens=3)
)

result = await agent.run('Hello', message_history=result.new_messages())
assert result.data == 'Hello world'
@@ -613,7 +615,9 @@ async def get_location(loc_name: str) -> str:
),
]
)
assert result.usage() == snapshot(Usage(requests=3, request_tokens=3, response_tokens=6, total_tokens=9))
assert result.usage() == snapshot(
Usage(requests=3, request_tokens=3, response_tokens=6, cached_tokens=0, total_tokens=9)
)


async def test_unexpected_response(client_with_handler: ClientWithHandler, env: TestEnv, allow_model_requests: None):
@@ -654,12 +658,16 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
'Hello world',
]
)
assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
)

async with agent.run_stream('Hello') as result:
chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
assert chunks == snapshot(['Hello ', 'world'])
assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
)


async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
@@ -691,7 +699,9 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
async with agent.run_stream('Hello') as result:
chunks = [chunk async for chunk in result.stream(debounce_by=None)]
assert chunks == snapshot(['abc', 'abc€def', 'abc€def'])
assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=2, response_tokens=4, cached_tokens=0, total_tokens=6)
)


async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
@@ -721,7 +731,9 @@ async def test_stream_structured(get_gemini_client: GetGeminiClient):
async with agent.run_stream('Hello') as result:
chunks = [chunk async for chunk in result.stream(debounce_by=None)]
assert chunks == snapshot([(1, 2), (1, 2)])
assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=1, response_tokens=2, cached_tokens=0, total_tokens=3)
)


async def test_stream_structured_tool_calls(get_gemini_client: GetGeminiClient):
@@ -762,7 +774,9 @@ async def bar(y: str) -> str:
async with agent.run_stream('Hello') as result:
response = await result.get_data()
assert response == snapshot((1, 2))
assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
assert result.usage() == snapshot(
Usage(requests=2, request_tokens=3, response_tokens=6, cached_tokens=0, total_tokens=9)
)
assert result.all_messages() == snapshot(
[
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
17 changes: 13 additions & 4 deletions tests/models/test_openai.py
@@ -212,7 +212,9 @@ async def test_request_simple_usage(allow_model_requests: None):

result = await agent.run('Hello')
assert result.data == 'world'
assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=2, response_tokens=1, cached_tokens=0, total_tokens=3)
)


async def test_request_structured_response(allow_model_requests: None):
@@ -380,6 +382,7 @@ async def get_location(loc_name: str) -> str:
requests=3,
request_tokens=5,
response_tokens=3,
cached_tokens=3,
total_tokens=9,
details={'cached_tokens': 3},
)
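
One thing this snapshot makes visible: for OpenAI, `cached_tokens` now appears both as the new top-level field and inside `details`, because `_map_usage` also dumps `prompt_tokens_details` into `details`. A quick check against this branch:

from pydantic_ai.usage import Usage

u = Usage(requests=3, request_tokens=5, response_tokens=3, cached_tokens=3, total_tokens=9, details={'cached_tokens': 3})
assert u.cached_tokens == u.details['cached_tokens']  # both come from prompt_tokens_details.cached_tokens
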
@@ -416,7 +419,9 @@ async def test_stream_text(allow_model_requests: None):
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
assert result.is_complete
assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=6, response_tokens=3, cached_tokens=0, total_tokens=9)
)


async def test_stream_text_finish_reason(allow_model_requests: None):
Expand Down Expand Up @@ -487,7 +492,9 @@ async def test_stream_structured(allow_model_requests: None):
]
)
assert result.is_complete
assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=20, response_tokens=10, cached_tokens=0, total_tokens=30)
)
# double check usage matches stream count
assert result.usage().response_tokens == len(stream)

@@ -543,7 +550,9 @@ async def test_no_delta(allow_model_requests: None):
assert not result.is_complete
assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world'])
assert result.is_complete
assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9))
assert result.usage() == snapshot(
Usage(requests=1, request_tokens=6, response_tokens=3, cached_tokens=0, total_tokens=9)
)


@pytest.mark.parametrize('system_prompt_role', ['system', 'developer', 'user', None])