diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index f19afea8d..25ff1d49b 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -596,6 +596,13 @@ class ModelResponse:
     For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
     """
 
+    finish_reason: Literal['stop', 'length', 'content_filter', 'tool_calls', 'error'] | str | None = None
+    """The reason why the model finished generating the response.
+    This can be one of the standard reasons like 'stop', 'length', 'content_filter', 'tool_calls', or 'error',
+    or a custom reason provided by the model.
+    If the model does not provide a finish reason, this will be `None`.
+    """
+
     vendor_id: str | None = None
     """Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
@@ -605,6 +612,8 @@ def otel_events(self) -> list[Event]:
 
         def new_event_body():
             new_body: dict[str, Any] = {'role': 'assistant'}
+            if self.finish_reason is not None:
+                new_body['finish_reason'] = self.finish_reason
             ev = Event('gen_ai.assistant.message', body=new_body)
             result.append(ev)
             return new_body
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 11650107a..bdd87f916 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -68,6 +68,22 @@
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
 
+_FINISH_REASONS = {
+    'STOP': 'stop',
+    'MAX_TOKENS': 'length',
+    'SAFETY': 'content_filter',
+    'RECITATION': 'content_filter',
+    'LANGUAGE': 'content_filter',
+    'BLOCKLIST': 'content_filter',
+    'PROHIBITED_CONTENT': 'content_filter',
+    'SPII': 'content_filter',
+    'MALFORMED_FUNCTION_CALL': 'error',  # or 'tool_calls' if you prefer
+    'OTHER': 'error',
+    'FINISH_REASON_UNSPECIFIED': 'error',  # unspecified is still a model stop reason
+    'IMAGE_SAFETY': 'content_filter',
+    None: None,
+}
+
 
 class GeminiModelSettings(ModelSettings, total=False):
     """Settings used for a Gemini model request.
@@ -251,30 +267,22 @@ async def _make_request(
                 yield r
 
     def _process_response(self, response: _GeminiResponse) -> ModelResponse:
-        vendor_details: dict[str, Any] | None = None
-
         if len(response['candidates']) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
+        finish_reason_key = response['candidates'][0].get('finish_reason', None)
+        finish_reason = _FINISH_REASONS.get(finish_reason_key, finish_reason_key)
         if 'content' not in response['candidates'][0]:
-            if response['candidates'][0].get('finish_reason') == 'SAFETY':
+            if finish_reason_key == 'SAFETY':
                 raise UnexpectedModelBehavior('Safety settings triggered', str(response))
             else:
                 raise UnexpectedModelBehavior(  # pragma: no cover
                     'Content field missing from Gemini response', str(response)
                 )
         parts = response['candidates'][0]['content']['parts']
-        vendor_id = response.get('vendor_id', None)
-        finish_reason = response['candidates'][0].get('finish_reason')
-        if finish_reason:
-            vendor_details = {'finish_reason': finish_reason}
         usage = _metadata_as_usage(response)
         usage.requests = 1
         return _process_response_from_parts(
-            parts,
-            response.get('model_version', self._model_name),
-            usage,
-            vendor_id=vendor_id,
-            vendor_details=vendor_details,
+            parts, response.get('model_version', self._model_name), usage, finish_reason
         )
 
     async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
@@ -623,8 +631,7 @@ def _process_response_from_parts(
     parts: Sequence[_GeminiPartUnion],
     model_name: GeminiModelName,
     usage: usage.Usage,
-    vendor_id: str | None,
-    vendor_details: dict[str, Any] | None = None,
+    finish_reason: str | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -636,9 +643,7 @@
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
-    return ModelResponse(
-        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
-    )
+    return ModelResponse(parts=items, usage=usage, model_name=model_name, finish_reason=finish_reason)
 
 
 class _GeminiFunctionCall(TypedDict):
@@ -760,7 +765,25 @@ class _GeminiCandidates(TypedDict):
     """See ."""
 
     content: NotRequired[_GeminiContent]
-    finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]]
+    finish_reason: NotRequired[
+        Annotated[
+            Literal[
+                'STOP',
+                'MAX_TOKENS',
+                'SAFETY',
+                'MALFORMED_FUNCTION_CALL',
+                'FINISH_REASON_UNSPECIFIED',
+                'RECITATION',
+                'LANGUAGE',
+                'BLOCKLIST',
+                'PROHIBITED_CONTENT',
+                'SPII',
+                'OTHER',
+                'IMAGE_SAFETY',
+            ],
+            pydantic.Field(alias='finishReason'),
+        ]
+    ]
     """
     See , lots of other values are possible, but let's wait until we see them and know what they mean to add them here.
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py
index 99f9950f0..8a1964fb0 100644
--- a/pydantic_ai_slim/pydantic_ai/models/google.py
+++ b/pydantic_ai_slim/pydantic_ai/models/google.py
@@ -5,7 +5,7 @@
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from datetime import datetime
-from typing import Any, Literal, Union, cast, overload
+from typing import Literal, Union, cast, overload
 from uuid import uuid4
 
 from typing_extensions import assert_never
@@ -91,6 +91,21 @@
 allow any name in the type hints.
 
 See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list.
 """
+
+_FINISH_REASONS = {
+    'STOP': 'stop',
+    'MAX_TOKENS': 'length',
+    'SAFETY': 'content_filter',
+    'RECITATION': 'content_filter',
+    'LANGUAGE': 'content_filter',
+    'BLOCKLIST': 'content_filter',
+    'PROHIBITED_CONTENT': 'content_filter',
+    'SPII': 'content_filter',
+    'MALFORMED_FUNCTION_CALL': 'error',  # or 'tool_calls' if you prefer
+    'OTHER': 'error',
+    'FINISH_REASON_UNSPECIFIED': 'error',
+    'IMAGE_SAFETY': 'content_filter',
+    None: None,
+}
 
 
 class GoogleModelSettings(ModelSettings, total=False):
@@ -275,6 +290,9 @@ async def _generate_content(
     def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         if not response.candidates or len(response.candidates) != 1:
             raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')  # pragma: no cover
+        finish_reason_key = response.candidates[0].finish_reason or None
+        finish_reason = _FINISH_REASONS.get(finish_reason_key, finish_reason_key)
+
         if response.candidates[0].content is None or response.candidates[0].content.parts is None:
             if response.candidates[0].finish_reason == 'SAFETY':
                 raise UnexpectedModelBehavior('Safety settings triggered', str(response))
@@ -284,14 +302,14 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
                 )  # pragma: no cover
         parts = response.candidates[0].content.parts or []
         vendor_id = response.response_id or None
-        vendor_details: dict[str, Any] | None = None
-        finish_reason = response.candidates[0].finish_reason
-        if finish_reason:  # pragma: no branch
-            vendor_details = {'finish_reason': finish_reason.value}
         usage = _metadata_as_usage(response)
         usage.requests = 1
         return _process_response_from_parts(
-            parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
+            parts,
+            response.model_version or self._model_name,
+            usage,
+            vendor_id=vendor_id,
+            finish_reason=finish_reason,
         )
 
     async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
@@ -443,7 +461,7 @@ def _process_response_from_parts(
     model_name: GoogleModelName,
     usage: usage.Usage,
     vendor_id: str | None,
-    vendor_details: dict[str, Any] | None = None,
+    finish_reason: str | None = None,
 ) -> ModelResponse:
     items: list[ModelResponsePart] = []
     for part in parts:
@@ -460,7 +478,7 @@
             raise UnexpectedModelBehavior(
                 f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
             )
     return ModelResponse(
-        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
+        parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, finish_reason=finish_reason
     )
diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py
index 00ada3eaf..b06bbb706 100644
--- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py
+++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -243,6 +243,10 @@ def finish(response: ModelResponse):
                 )
             )
             new_attributes: dict[str, AttributeValue] = response.usage.opentelemetry_attributes()  # pyright: ignore[reportAssignmentType]
+            if response.vendor_id is not None:
+                new_attributes['gen_ai.response.id'] = response.vendor_id
+            if response.finish_reason is not None:
+                new_attributes['gen_ai.response.finish_reasons'] = [response.finish_reason]
             attributes.update(getattr(span, 'attributes', {}))
             request_model = attributes[GEN_AI_REQUEST_MODEL_ATTRIBUTE]
             new_attributes['gen_ai.response.model'] = response.model_name or request_model
diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py
index 1734104b0..542fd9806 100644
--- a/tests/models/test_gemini.py
+++ b/tests/models/test_gemini.py
@@ -544,7 +544,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
@@ -560,15 +560,15 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
             ModelResponse(
                 parts=[TextPart(content='Hello world')],
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
+                finish_reason='stop',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
@@ -592,7 +592,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -655,7 +655,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -675,7 +675,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -698,7 +698,7 @@ async def get_location(loc_name: str) -> str:
                 usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                 model_name='gemini-1.5-flash-123',
                 timestamp=IsNow(tz=timezone.utc),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
@@ -1119,7 +1119,7 @@ async def get_image() -> BinaryContent:
                 ),
                 model_name='gemini-2.5-pro-preview-03-25',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -1149,7 +1149,7 @@ async def get_image() -> BinaryContent:
                 ),
                 model_name='gemini-2.5-pro-preview-03-25',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
@@ -1277,8 +1277,8 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
                     details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
                 ),
                 model_name='gemini-1.5-flash',
+                finish_reason='stop',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
             ),
         ]
     )
diff --git a/tests/models/test_google.py b/tests/models/test_google.py
index ca8a82a73..c8118ef67 100644
--- a/tests/models/test_google.py
+++ b/tests/models/test_google.py
@@ -99,7 +99,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
                 ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
@@ -167,7 +167,7 @@ async def temperature(city: str, date: datetime.date) -> str:
                 ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -193,7 +193,7 @@ async def temperature(city: str, date: datetime.date) -> str:
                 ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -252,7 +252,7 @@ async def get_capital(country: str) -> str:
                 ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
             ModelRequest(
                 parts=[
@@ -279,7 +279,7 @@ async def get_capital(country: str) -> str:
                 ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
@@ -542,7 +542,7 @@ def instructions() -> str:
                 ),
                 model_name='gemini-2.0-flash',
                 timestamp=IsDatetime(),
-                vendor_details={'finish_reason': 'STOP'},
+                finish_reason='stop',
             ),
         ]
     )
diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py
index f7caad399..b15f1da84 100644
--- a/tests/models/test_instrumented.py
+++ b/tests/models/test_instrumented.py
@@ -81,6 +81,8 @@ async def request(
             ],
             usage=Usage(request_tokens=100, response_tokens=200),
             model_name='my_model_123',
+            vendor_id='chatcmpl-123',
+            finish_reason='stop',
         )
 
     @asynccontextmanager
@@ -158,6 +160,8 @@ async def test_instrumented_model(capfire: CaptureLogfire):
                     'logfire.span_type': 'span',
                     'gen_ai.response.model': 'my_model_123',
                     'gen_ai.usage.input_tokens': 100,
+                    'gen_ai.response.id': 'chatcmpl-123',
+                    'gen_ai.response.finish_reasons': ('stop',),
                     'gen_ai.usage.output_tokens': 200,
                 },
             },
@@ -277,6 +281,7 @@ async def test_instrumented_model(capfire: CaptureLogfire):
                     'index': 0,
                     'message': {
                         'role': 'assistant',
+                        'finish_reason': 'stop',
                         'content': 'text1',
                         'tool_calls': [
                             {
@@ -302,7 +307,7 @@ async def test_instrumented_model(capfire: CaptureLogfire):
                 'trace_flags': 1,
             },
             {
-                'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text2'}},
+                'body': {'index': 0, 'message': {'role': 'assistant', 'finish_reason': 'stop', 'content': 'text2'}},
                 'severity_number': 9,
                 'severity_text': None,
                 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'},
@@ -566,6 +571,8 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
                 'logfire.span_type': 'span',
                 'gen_ai.response.model': 'my_model_123',
                 'gen_ai.usage.input_tokens': 100,
+                'gen_ai.response.id': 'chatcmpl-123',
+                'gen_ai.response.finish_reasons': ('stop',),
                 'gen_ai.usage.output_tokens': 200,
                 'events': IsJson(
                     snapshot(
@@ -629,6 +636,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
                                 'index': 0,
                                 'message': {
                                     'role': 'assistant',
+                                    'finish_reason': 'stop',
                                     'content': 'text1',
                                     'tool_calls': [
                                         {
@@ -648,7 +656,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire):
                             {
                                 'event.name': 'gen_ai.choice',
                                 'index': 0,
-                                'message': {'role': 'assistant', 'content': 'text2'},
+                                'message': {'role': 'assistant', 'finish_reason': 'stop', 'content': 'text2'},
                                 'gen_ai.system': 'my_system',
                             },
                         ]
diff --git a/tests/test_agent.py b/tests/test_agent.py
index d59848155..62f0a7e18 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -2356,6 +2356,7 @@ def test_binary_content_all_messages_json():
                 'vendor_id': None,
                 'timestamp': IsStr(),
                 'kind': 'response',
+                'finish_reason': None,
                 'vendor_details': None,
             },
         ]
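
Reviewer note: the normalization rule both model files implement is `_FINISH_REASONS.get(key, key)`, so known vendor reasons map onto the standard set while anything unrecognized passes through verbatim, and a missing reason stays `None`. A minimal sketch of that behavior (the table below is an abridged mirror of the one added in `gemini.py`/`google.py`):

```python
# Abridged mirror of the _FINISH_REASONS table added in this patch.
_FINISH_REASONS = {
    'STOP': 'stop',
    'MAX_TOKENS': 'length',
    'SAFETY': 'content_filter',
    'MALFORMED_FUNCTION_CALL': 'error',
    None: None,  # the model reported no finish reason
}


def normalize(raw: str | None) -> str | None:
    # Unknown vendor-specific reasons fall through unchanged, matching
    # `_FINISH_REASONS.get(finish_reason_key, finish_reason_key)` above.
    return _FINISH_REASONS.get(raw, raw)


assert normalize('STOP') == 'stop'
assert normalize('MAX_TOKENS') == 'length'
assert normalize('SOME_FUTURE_REASON') == 'SOME_FUTURE_REASON'  # passthrough
assert normalize(None) is None
```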
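
Downstream, the normalized value rides on `ModelResponse` and, via `InstrumentedModel`, is emitted as the `gen_ai.response.finish_reasons` span attribute (alongside `gen_ai.response.id` from `vendor_id`). A usage sketch under stated assumptions: the model string and prompt are illustrative, and the final message of a completed run is the last `ModelResponse`:

```python
from pydantic_ai import Agent

# Hypothetical agent; any supported model string works the same way.
agent = Agent('google-gla:gemini-1.5-flash')

result = agent.run_sync('Hello')
response = result.all_messages()[-1]  # the final ModelResponse

# One of 'stop', 'length', 'content_filter', 'tool_calls', 'error',
# a vendor-specific string passed through unchanged, or None.
print(response.finish_reason)
```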