Skip to content

Commit 8ded8a9

Browse files
authored
add overwrite mechanism for stream_options (#465)
Fixes issue #442. Below is an example of overwriting `include_usage`: ``` result = Runner.run_streamed( agent, "Write a haiku about recursion in programming.", run_config=RunConfig( model_provider=CUSTOM_MODEL_PROVIDER, model_settings=ModelSettings(include_usage=True) ), ) ```
1 parent 84fb734 commit 8ded8a9

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

src/agents/model_settings.py

+4
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ class ModelSettings:
     """Whether to store the generated model response for later retrieval.
     Defaults to True if not provided."""

+    include_usage: bool | None = None
+    """Whether to include usage chunk.
+    Defaults to True if not provided."""
+
     def resolve(self, override: ModelSettings | None) -> ModelSettings:
         """Produce a new ModelSettings by overlaying any non-None values from the
         override on top of this instance."""

src/agents/models/openai_chatcompletions.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,8 @@ async def _fetch_response(
         reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
         store = _Converter.get_store_param(self._get_client(), model_settings)

+        stream_options = _Converter.get_stream_options_param(self._get_client(), model_settings)
+
         ret = await self._get_client().chat.completions.create(
             model=self.model,
             messages=converted_messages,
@@ -534,7 +536,7 @@ async def _fetch_response(
             response_format=response_format,
             parallel_tool_calls=parallel_tool_calls,
             stream=stream,
-            stream_options={"include_usage": True} if stream else NOT_GIVEN,
+            stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
             extra_headers=_HEADERS,
@@ -568,12 +570,27 @@ def _get_client(self) -> AsyncOpenAI:


 class _Converter:
+    @classmethod
+    def is_openai(cls, client: AsyncOpenAI):
+        return str(client.base_url).startswith("https://api.openai.com")
+
     @classmethod
     def get_store_param(cls, client: AsyncOpenAI, model_settings: ModelSettings) -> bool | None:
         # Match the behavior of Responses where store is True when not given
-        default_store = True if str(client.base_url).startswith("https://api.openai.com") else None
+        default_store = True if cls.is_openai(client) else None
         return model_settings.store if model_settings.store is not None else default_store

+    @classmethod
+    def get_stream_options_param(
+        cls, client: AsyncOpenAI, model_settings: ModelSettings
+    ) -> dict[str, bool] | None:
+        default_include_usage = True if cls.is_openai(client) else None
+        include_usage = model_settings.include_usage if model_settings.include_usage is not None \
+            else default_include_usage
+        stream_options = {"include_usage": include_usage} if include_usage is not None else None
+        return stream_options
+
     @classmethod
     def convert_tool_choice(
         cls, tool_choice: Literal["auto", "required", "none"] | str | None

tests/test_openai_chatcompletions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def __init__(self, completions: DummyCompletions) -> None:
     # Check OpenAI client was called for streaming
     assert completions.kwargs["stream"] is True
     assert completions.kwargs["store"] is NOT_GIVEN
-    assert completions.kwargs["stream_options"] == {"include_usage": True}
+    assert completions.kwargs["stream_options"] is NOT_GIVEN
     # Response is a proper openai Response
     assert isinstance(response, Response)
     assert response.id == FAKE_RESPONSES_ID

0 commit comments

Comments
 (0)