-
Notifications
You must be signed in to change notification settings - Fork 965
Add id and finish_reason to OpenTelemetry instrumentation (closes #886) #1882
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
2c57b2d
cdfec33
3e26c73
dd0e70c
e15b836
6889401
179a3b4
cea7136
e57605f
cbbe508
c178c7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,6 +68,22 @@ | |
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list. | ||
""" | ||
|
||
_FINISH_REASONS = { | ||
'STOP': 'stop', | ||
'MAX_TOKENS': 'length', | ||
'SAFETY': 'content_filter', | ||
'RECITATION': 'content_filter', | ||
'LANGUAGE': 'content_filter', | ||
'BLOCKLIST': 'content_filter', | ||
'PROHIBITED_CONTENT': 'content_filter', | ||
'SPII': 'content_filter', | ||
'MALFORMED_FUNCTION_CALL': 'error', # or 'tool_calls' if you prefer | ||
'OTHER': 'error', | ||
'FINISH_REASON_UNSPECIFIED': 'error', # unspecified is still a model stop reason | ||
'IMAGE_SAFETY': 'content_filter', | ||
None: None, | ||
} | ||
|
||
|
||
class GeminiModelSettings(ModelSettings, total=False): | ||
"""Settings used for a Gemini model request. | ||
|
@@ -251,30 +267,22 @@ async def _make_request( | |
yield r | ||
|
||
def _process_response(self, response: _GeminiResponse) -> ModelResponse: | ||
vendor_details: dict[str, Any] | None = None | ||
|
||
if len(response['candidates']) != 1: | ||
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover | ||
finish_reason_key = response['candidates'][0].get('finish_reason', None) | ||
finish_reason = _FINISH_REASONS.get(finish_reason_key, finish_reason_key) | ||
if 'content' not in response['candidates'][0]: | ||
if response['candidates'][0].get('finish_reason') == 'SAFETY': | ||
if finish_reason_key == 'SAFETY': | ||
raise UnexpectedModelBehavior('Safety settings triggered', str(response)) | ||
else: | ||
raise UnexpectedModelBehavior( # pragma: no cover | ||
'Content field missing from Gemini response', str(response) | ||
) | ||
parts = response['candidates'][0]['content']['parts'] | ||
vendor_id = response.get('vendor_id', None) | ||
finish_reason = response['candidates'][0].get('finish_reason') | ||
if finish_reason: | ||
vendor_details = {'finish_reason': finish_reason} | ||
usage = _metadata_as_usage(response) | ||
usage.requests = 1 | ||
return _process_response_from_parts( | ||
parts, | ||
response.get('model_version', self._model_name), | ||
usage, | ||
vendor_id=vendor_id, | ||
vendor_details=vendor_details, | ||
parts, response.get('model_version', self._model_name), usage, finish_reason | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It looks like we accidentally dropped the `vendor_id` argument here. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Also for backward compatibility reasons, let's keep the untouched `vendor_details`. |
||
) | ||
|
||
async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse: | ||
|
@@ -623,8 +631,7 @@ def _process_response_from_parts( | |
parts: Sequence[_GeminiPartUnion], | ||
model_name: GeminiModelName, | ||
usage: usage.Usage, | ||
vendor_id: str | None, | ||
vendor_details: dict[str, Any] | None = None, | ||
finish_reason: str | None = None, | ||
) -> ModelResponse: | ||
items: list[ModelResponsePart] = [] | ||
for part in parts: | ||
|
@@ -636,9 +643,7 @@ def _process_response_from_parts( | |
raise UnexpectedModelBehavior( | ||
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' | ||
) | ||
return ModelResponse( | ||
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details | ||
) | ||
return ModelResponse(parts=items, usage=usage, model_name=model_name, finish_reason=finish_reason) | ||
|
||
|
||
class _GeminiFunctionCall(TypedDict): | ||
|
@@ -760,7 +765,25 @@ class _GeminiCandidates(TypedDict): | |
"""See <https://ai.google.dev/api/generate-content#v1beta.Candidate>.""" | ||
|
||
content: NotRequired[_GeminiContent] | ||
finish_reason: NotRequired[Annotated[Literal['STOP', 'MAX_TOKENS', 'SAFETY'], pydantic.Field(alias='finishReason')]] | ||
finish_reason: NotRequired[ | ||
Annotated[ | ||
Literal[ | ||
'STOP', | ||
'MAX_TOKENS', | ||
'SAFETY', | ||
'MALFORMED_FUNCTION_CALL', | ||
'FINISH_REASON_UNSPECIFIED', | ||
'RECITATION', | ||
'LANGUAGE', | ||
'BLOCKLIST', | ||
'PROHIBITED_CONTENT', | ||
'SPII', | ||
'OTHER', | ||
'IMAGE_SAFETY', | ||
], | ||
pydantic.Field(alias='finishReason'), | ||
] | ||
] | ||
""" | ||
See <https://ai.google.dev/api/generate-content#FinishReason>, lots of other values are possible, | ||
but let's wait until we see them and know what they mean to add them here. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
from contextlib import asynccontextmanager | ||
from dataclasses import dataclass, field | ||
from datetime import datetime | ||
from typing import Any, Literal, Union, cast, overload | ||
from typing import Literal, Union, cast, overload | ||
from uuid import uuid4 | ||
|
||
from typing_extensions import assert_never | ||
|
@@ -91,6 +91,21 @@ | |
allow any name in the type hints. | ||
See [the Gemini API docs](https://ai.google.dev/gemini-api/docs/models/gemini#model-variations) for a full list. | ||
""" | ||
_FINISH_REASONS = { | ||
'STOP': 'stop', | ||
'MAX_TOKENS': 'length', | ||
'SAFETY': 'content_filter', | ||
'RECITATION': 'content_filter', | ||
'LANGUAGE': 'content_filter', | ||
'BLOCKLIST': 'content_filter', | ||
'PROHIBITED_CONTENT': 'content_filter', | ||
'SPII': 'content_filter', | ||
'MALFORMED_FUNCTION_CALL': 'error', # or 'tool_calls' if you prefer | ||
'OTHER': 'error', | ||
'FINISH_REASON_UNSPECIFIED': 'error', | ||
'IMAGE_SAFETY': 'content_filter', | ||
None: None, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as above, I think we can drop this `None: None` entry. |
||
} | ||
|
||
|
||
class GoogleModelSettings(ModelSettings, total=False): | ||
|
@@ -275,6 +290,9 @@ async def _generate_content( | |
def _process_response(self, response: GenerateContentResponse) -> ModelResponse: | ||
if not response.candidates or len(response.candidates) != 1: | ||
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover | ||
finish_reason_key = response.candidates[0].finish_reason or None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In the original code, we were calling `.value` on the `finish_reason` enum. |
||
finish_reason = _FINISH_REASONS.get(finish_reason_key, finish_reason_key) | ||
|
||
if response.candidates[0].content is None or response.candidates[0].content.parts is None: | ||
if response.candidates[0].finish_reason == 'SAFETY': | ||
raise UnexpectedModelBehavior('Safety settings triggered', str(response)) | ||
|
@@ -284,14 +302,14 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse: | |
) # pragma: no cover | ||
parts = response.candidates[0].content.parts or [] | ||
vendor_id = response.response_id or None | ||
vendor_details: dict[str, Any] | None = None | ||
finish_reason = response.candidates[0].finish_reason | ||
if finish_reason: # pragma: no branch | ||
vendor_details = {'finish_reason': finish_reason.value} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As above, let's keep this one for backward compatibility. |
||
usage = _metadata_as_usage(response) | ||
usage.requests = 1 | ||
return _process_response_from_parts( | ||
parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details | ||
parts, | ||
response.model_version or self._model_name, | ||
usage, | ||
vendor_id=vendor_id, | ||
finish_reason=finish_reason, | ||
) | ||
|
||
async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse: | ||
|
@@ -443,7 +461,7 @@ def _process_response_from_parts( | |
model_name: GoogleModelName, | ||
usage: usage.Usage, | ||
vendor_id: str | None, | ||
vendor_details: dict[str, Any] | None = None, | ||
finish_reason: str | None = None, | ||
) -> ModelResponse: | ||
items: list[ModelResponsePart] = [] | ||
for part in parts: | ||
|
@@ -460,7 +478,7 @@ def _process_response_from_parts( | |
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}' | ||
) | ||
return ModelResponse( | ||
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details | ||
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, finish_reason=finish_reason | ||
) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can drop this line and it'll still work as expected.