From 4581650bb4ff26d059ab81b6bdd722f8de7a047a Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 21:54:36 +0800 Subject: [PATCH 01/11] response prefill support for anthropic, deepseek and openrouter --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 9 ++++- .../pydantic_ai/agent/__init__.py | 10 +++++ .../pydantic_ai/agent/abstract.py | 20 +++++++++- .../pydantic_ai/models/__init__.py | 2 + .../pydantic_ai/models/anthropic.py | 28 +++++++++---- pydantic_ai_slim/pydantic_ai/models/openai.py | 28 ++++++++++--- .../pydantic_ai/profiles/__init__.py | 3 ++ .../pydantic_ai/profiles/anthropic.py | 2 +- .../pydantic_ai/profiles/openai.py | 4 ++ test_response_prefix_example.py | 39 +++++++++++++++++++ tests/models/test_anthropic.py | 24 ++++++++++++ tests/test_agent.py | 29 ++++++++++++++ 12 files changed, 182 insertions(+), 16 deletions(-) create mode 100644 test_response_prefix_example.py diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index e121ec475a..bc0a0b3838 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -167,6 +167,8 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]): system_prompt_functions: list[_system_prompt.SystemPromptRunner[DepsT]] system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[DepsT]] + response_prefix: str | None = None + async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] ) -> ModelRequestNode[DepsT, NodeRunEndT] | CallToolsNode[DepsT, NodeRunEndT]: @@ -225,7 +227,7 @@ async def run( instructions = await ctx.deps.get_instructions(run_context) next_message = _messages.ModelRequest(parts, instructions=instructions) - return ModelRequestNode[DepsT, NodeRunEndT](request=next_message) + return ModelRequestNode[DepsT, NodeRunEndT](request=next_message, response_prefix=self.response_prefix) async def _handle_message_history_model_response( self, @@ -333,6 +335,7 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], + response_prefix: str | None = None, ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" output_schema = ctx.deps.output_schema @@ -358,6 +361,7 @@ async def _prepare_request_parameters( output_tools=output_tools, output_object=output_object, allow_text_output=allow_text_output, + response_prefix=response_prefix, ) @@ -366,6 +370,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]): """The node that makes a request to the model using the last message in state.message_history.""" request: _messages.ModelRequest + response_prefix: str | None = None _result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None) _did_stream: bool = field(repr=False, init=False, default=False) @@ -442,7 +447,7 @@ async def _prepare_request( message_history = await _process_message_history(ctx.state, ctx.deps.history_processors, run_context) - model_request_parameters = await _prepare_request_parameters(ctx) + model_request_parameters = await _prepare_request_parameters(ctx, self.response_prefix) model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) model_settings = ctx.deps.model_settings diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 615bc86350..0e9619dacf 100644 --- 
a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -476,6 +476,7 @@ async def iter( # noqa: C901 usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -549,6 +550,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -558,6 +560,13 @@ async def main(): model_used = self._get_model(model) del model + # Validate response_prefix support + if response_prefix is not None and not model_used.profile.supports_response_prefix: + raise exceptions.UserError( + f'Model {model_used.model_name} does not support response prefix. ' + 'Response prefix is only supported by certain models like Anthropic Claude and some OpenAI-compatible models.' + ) + deps = self._get_deps(deps) new_message_index = len(message_history) if message_history else 0 output_schema = self._prepare_output_schema(output_type, model_used.profile) @@ -663,6 +672,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: system_prompts=self._system_prompts, system_prompt_functions=self._system_prompt_functions, system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, + response_prefix=response_prefix, ) agent_name = self.name or 'agent' diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 2e18954360..2061b2d7eb 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -127,6 +127,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -145,6 +146,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -162,6 +164,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -194,6 +197,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. 
@@ -214,6 +218,7 @@ async def main(): usage_limits=usage_limits, usage=usage, toolsets=toolsets, + response_prefix=response_prefix, ) as agent_run: async for node in agent_run: if event_stream_handler is not None and ( @@ -241,6 +246,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -259,6 +265,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -276,6 +283,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -307,6 +315,7 @@ def run_sync( infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -328,6 +337,7 @@ def run_sync( infer_name=False, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, ) ) @@ -347,6 +357,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -365,6 +376,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -383,6 +395,7 @@ async def run_stream( # noqa C901 infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async streaming mode. @@ -448,6 +461,7 @@ async def main(): usage=usage, infer_name=False, toolsets=toolsets, + response_prefix=response_prefix, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -557,6 +571,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... 
@overload @@ -574,6 +589,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -592,6 +608,7 @@ async def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -665,7 +682,8 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. - + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. + Returns: The result of the run. """ diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 15638f385b..532e00a06c 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -380,6 +380,8 @@ class ModelRequestParameters: output_tools: list[ToolDefinition] = field(default_factory=list) allow_text_output: bool = True + response_prefix: str | None = None + @cached_property def tool_defs(self) -> dict[str, ToolDefinition]: return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]} diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 509b7fdf59..ed16e82967 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -203,7 +203,7 @@ async def request( response = await self._messages_create( messages, False, cast(AnthropicModelSettings, model_settings or {}), model_request_parameters ) - model_response = self._process_response(response) + model_response = self._process_response(response, model_request_parameters) return model_response @asynccontextmanager @@ -268,6 +268,10 @@ async def _messages_create( system_prompt, anthropic_messages = await self._map_message(messages) + # Add response prefix as assistant message if provided + if model_request_parameters.response_prefix: + anthropic_messages.append({'role': 'assistant', 'content': model_request_parameters.response_prefix}) + try: extra_headers = model_settings.get('extra_headers', {}) for k, v in tool_headers.items(): @@ -296,12 +300,18 @@ async def _messages_create( raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover - def _process_response(self, response: BetaMessage) -> ModelResponse: + def _process_response( + self, response: BetaMessage, model_request_parameters: ModelRequestParameters + ) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" items: list[ModelResponsePart] = [] - for item in response.content: + for i, item in enumerate(response.content): if isinstance(item, BetaTextBlock): - items.append(TextPart(content=item.text)) + content = item.text + # Prepend response prefix to the first text block if provided + if i == 0 and model_request_parameters.response_prefix: + content = model_request_parameters.response_prefix + content + 
items.append(TextPart(content=content)) elif isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock): items.append( BuiltinToolReturnPart( @@ -616,6 +626,7 @@ class AnthropicStreamedResponse(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 current_block: BetaContentBlock | None = None + first_text_delta = True async for event in self._response: if isinstance(event, BetaRawMessageStartEvent): @@ -659,9 +670,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - maybe_event = self._parts_manager.handle_text_delta( - vendor_part_id=event.index, content=event.delta.text - ) + content = event.delta.text + # Prepend response prefix to the first text delta if provided + if first_text_delta and self.model_request_parameters.response_prefix: + content = self.model_request_parameters.response_prefix + content + first_text_delta = False + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=event.index, content=content) if maybe_event is not None: # pragma: no branch yield maybe_event elif isinstance(event.delta, BetaThinkingDelta): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index a057758018..a3b5a150cb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -363,7 +363,7 @@ async def request( response = await self._completions_create( messages, False, cast(OpenAIChatModelSettings, model_settings or {}), model_request_parameters ) - model_response = self._process_response(response) + model_response = self._process_response(response, model_request_parameters) return model_response @asynccontextmanager @@ -419,7 +419,7 @@ async def _completions_create( else: tool_choice = 'auto' - openai_messages = await self._map_messages(messages) + openai_messages = await self._map_messages(messages, model_request_parameters) response_format: chat.completion_create_params.ResponseFormat | None = None if model_request_parameters.output_mode == 'native': @@ -471,7 +471,9 @@ async def _completions_create( raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e raise # pragma: lax no cover - def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse: + def _process_response( + self, response: chat.ChatCompletion | str, model_request_parameters: ModelRequestParameters + ) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" # Although the OpenAI SDK claims to return a Pydantic model (`ChatCompletion`) from the chat completions function: # * it hasn't actually performed validation (presumably they're creating the model with `model_construct` or something?!) 
@@ -523,9 +525,13 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons ] if choice.message.content is not None: + content = choice.message.content + # Prepend response prefix if provided + if model_request_parameters.response_prefix: + content = model_request_parameters.response_prefix + content items.extend( (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part) - for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags) + for part in split_content_into_text_and_thinking(content, self.profile.thinking_tags) ) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: @@ -601,7 +607,9 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete f'`{tool.__class__.__name__}` is not supported by `OpenAIChatModel`. If it should be, please file an issue.' ) - async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCompletionMessageParam]: + async def _map_messages( + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters | None = None + ) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" openai_messages: list[chat.ChatCompletionMessageParam] = [] for message in messages: @@ -641,6 +649,11 @@ async def _map_messages(self, messages: list[ModelMessage]) -> list[chat.ChatCom assert_never(message) if instructions := self._get_instructions(messages): openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')) + + # Add response prefix as assistant message if provided + if model_request_parameters and model_request_parameters.response_prefix: + openai_messages.append({'role': 'assistant', 'content': model_request_parameters.response_prefix}) + return openai_messages @staticmethod @@ -1230,6 +1243,7 @@ class OpenAIStreamedResponse(StreamedResponse): _provider_name: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + first_text_delta = True async for chunk in self._response: self._usage += _map_usage(chunk) @@ -1252,6 +1266,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: + # Prepend response prefix to the first text delta if provided + if first_text_delta and self.model_request_parameters.response_prefix: + content = self.model_request_parameters.response_prefix + content + first_text_delta = False maybe_event = self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, diff --git a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py index 9915ecf04f..87dba8a0c6 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/__init__.py @@ -55,6 +55,9 @@ class ModelProfile: This is currently only used by `OpenAIChatModel`, `HuggingFaceModel`, and `GroqModel`. 
""" + supports_response_prefix: bool = False + """Whether the model supports response prefix (prefill) functionality.""" + @classmethod def from_profile(cls, profile: ModelProfile | None) -> Self: """Build a ModelProfile subclass instance from a ModelProfile instance.""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py index f6a2755819..76e117c718 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/anthropic.py @@ -5,4 +5,4 @@ def anthropic_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for an Anthropic model.""" - return ModelProfile(thinking_tags=('', '')) + return ModelProfile(thinking_tags=('', ''), supports_response_prefix=True) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index da18c0f097..860e138f17 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -77,12 +77,16 @@ def openai_model_profile(model_name: str) -> ModelProfile: # See https://github.com/pydantic/pydantic-ai/issues/974 for more details. openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None + # Enable response prefix for DeepSeek and OpenRouter models + supports_response_prefix = 'deepseek' in model_name.lower() or 'openrouter' in model_name.lower() + return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, supports_json_schema_output=True, supports_json_object_output=True, openai_unsupported_model_settings=openai_unsupported_model_settings, openai_system_prompt_role=openai_system_prompt_role, + supports_response_prefix=supports_response_prefix, openai_chat_supports_web_search=supports_web_search, ) diff --git a/test_response_prefix_example.py b/test_response_prefix_example.py new file mode 100644 index 0000000000..0a6f42e9de --- /dev/null +++ b/test_response_prefix_example.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python3 +"""Example script demonstrating the response prefix feature in Pydantic AI.""" + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + + +def test_response_prefix(): + """Test the response prefix feature with validation.""" + # Test that unsupported models raise an error + agent = Agent(TestModel()) + + try: + agent.run_sync('Hello', response_prefix='Assistant: ') + assert False, 'Should have raised UserError' + except Exception as e: + print(f'✅ Validation works: {e}') + + # Create a mock model that supports response prefix + class MockResponsePrefixModel(TestModel): + @property + def profile(self): # pyright: ignore[reportIncompatibleVariableOverride] + profile = super().profile + profile.supports_response_prefix = True + return profile + + # Create an agent with the mock model + agent = Agent(MockResponsePrefixModel()) + + # Test that the parameter is accepted without error + result = agent.run_sync('Hello', response_prefix='Assistant: ') + print('✅ Response prefix parameter accepted by supported model') + print(f'Response: {result.output}') + + print('✅ Response prefix feature working correctly!') + + +if __name__ == '__main__': + test_response_prefix() diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 8ddfabb3f2..50042e0fb9 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -2718,3 +2718,27 @@ async def test_anthropic_web_search_tool_stream(allow_model_requests: None, anth 
PartDeltaEvent(index=17, delta=TextPartDelta(content_delta=' disruptions affecting North America.')), ] ) + + +async def test_anthropic_response_prefix(allow_model_requests: None): + """Test that Anthropic models correctly handle response prefix.""" + m = AnthropicModel('claude-sonnet-4-0') + agent = Agent(m) + + # Test non-streaming response + result = await agent.run('Hello', response_prefix='Assistant: ') + assert result.output.startswith('Assistant: ') + + # Test streaming response + event_parts: list[Any] = [] + async with agent.iter(user_prompt='Hello', response_prefix='Assistant: ') as agent_run: + async for node in agent_run: + if Agent.is_model_request_node(node): + async with node.stream(agent_run.ctx) as request_stream: + async for event in request_stream: + event_parts.append(event) + + # Check that the first text part starts with the prefix + text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] + assert len(text_parts) > 0 + assert text_parts[0].part.content.startswith('Assistant: ') diff --git a/tests/test_agent.py b/tests/test_agent.py index a48b47da79..0dffe1d1d4 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4802,3 +4802,32 @@ def test_tool_requires_approval_error(): @agent.tool_plain(requires_approval=True) def delete_file(path: str) -> None: pass + + +def test_response_prefix_validation(): + """Test that response_prefix raises an error for unsupported models.""" + # Test with a model that doesn't support response prefix + agent = Agent(TestModel()) + + with pytest.raises(UserError, match='Model test does not support response prefix'): + agent.run_sync('Hello', response_prefix='Assistant: ') + + +def test_response_prefix_parameter_passed(): + """Test that response_prefix parameter is accepted by run methods.""" + # Test with a model that supports response prefix + from pydantic_ai.models.test import TestModel + + # Create a mock model that supports response prefix + class MockResponsePrefixModel(TestModel): + @property + def profile(self): + profile = super().profile + profile.supports_response_prefix = True + return profile + + agent = Agent(MockResponsePrefixModel()) + + # This should not raise an error + result = agent.run_sync('Hello', response_prefix='Assistant: ') + assert result.output is not None From 02ddf5e1d9c61a8aecfea871cf408f2863a78850 Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 22:12:11 +0800 Subject: [PATCH 02/11] fix: lint, test, pyright --- .../pydantic_ai/agent/__init__.py | 2 ++ pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 5 +++++ .../durable_exec/temporal/_agent.py | 20 +++++++++++++++++++ tests/models/test_anthropic.py | 2 +- tests/models/test_fallback.py | 3 +++ tests/models/test_instrumented.py | 5 +++++ tests/models/test_model_request_parameters.py | 1 + tests/test_agent.py | 2 +- tests/test_logfire.py | 3 +++ 9 files changed, 41 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 0e9619dacf..94986d4440 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -442,6 +442,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... 
@overload @@ -459,6 +460,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index e53ead8cef..8a7aab728f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -81,6 +81,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @overload @@ -98,6 +99,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -115,6 +117,7 @@ async def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -188,6 +191,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -204,6 +208,7 @@ async def main(): usage=usage, infer_name=infer_name, toolsets=toolsets, + response_prefix=response_prefix, ) as run: yield run diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index 0b487e39b5..b0e7492b24 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -267,6 +267,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -285,6 +286,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -302,6 +304,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -335,6 +338,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. + response_prefix: Optional prefix to prepend to the model's response. 
Only supported by certain models. Returns: The result of the run. @@ -358,6 +362,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler or self.event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -377,6 +382,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -395,6 +401,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -412,6 +419,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -444,6 +452,7 @@ def run_sync( infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -466,6 +475,7 @@ def run_sync( infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -485,6 +495,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -503,6 +514,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -521,6 +533,7 @@ async def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -551,6 +564,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. 
@@ -575,6 +589,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) as result: yield result @@ -594,6 +609,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -612,6 +628,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -630,6 +647,7 @@ async def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -704,6 +722,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -737,6 +756,7 @@ async def main(): usage=usage, infer_name=infer_name, toolsets=toolsets, + response_prefix=response_prefix, **_deprecated_kwargs, ) as run: yield run diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 50042e0fb9..d2eeaa3c3c 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -2741,4 +2741,4 @@ async def test_anthropic_response_prefix(allow_model_requests: None): # Check that the first text part starts with the prefix text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] assert len(text_parts) > 0 - assert text_parts[0].part.content.startswith('Assistant: ') + assert cast(TextPart, text_parts[0].part).content.startswith('Assistant: ') diff --git a/tests/models/test_fallback.py b/tests/models/test_fallback.py index 484a73ac37..bd7f773574 100644 --- a/tests/models/test_fallback.py +++ b/tests/models/test_fallback.py @@ -139,6 +139,7 @@ def test_first_failed_instrumented(capfire: CaptureLogfire) -> None: 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function:failure_response:,function:success_response:', @@ -238,6 +239,7 @@ async def test_first_failed_instrumented_stream(capfire: CaptureLogfire) -> None 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.span_type': 'span', 'logfire.msg': 'chat fallback:function::failure_response_stream,function::success_response_stream', @@ -344,6 +346,7 @@ def test_all_failed_instrumented(capfire: CaptureLogfire) -> None: 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.json_schema': { 'type': 'object', diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 9ea6fadc47..4d244e890c 100644 --- 
a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -176,6 +176,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.json_schema': { 'type': 'object', @@ -407,6 +408,7 @@ async def test_instrumented_model_stream(capfire: CaptureLogfire): 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.json_schema': { 'type': 'object', @@ -505,6 +507,7 @@ async def test_instrumented_model_stream_break(capfire: CaptureLogfire): 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.json_schema': { 'type': 'object', @@ -623,6 +626,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire, instr 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat gpt-4o', @@ -749,6 +753,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire, instr 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'gen_ai.request.temperature': 1, 'logfire.msg': 'chat gpt-4o', diff --git a/tests/models/test_model_request_parameters.py b/tests/models/test_model_request_parameters.py index 2915796ab1..72951a7c08 100644 --- a/tests/models/test_model_request_parameters.py +++ b/tests/models/test_model_request_parameters.py @@ -14,4 +14,5 @@ def test_model_request_parameters_are_serializable(): 'allow_text_output': True, 'output_tools': [], 'output_object': None, + 'response_prefix': None, } diff --git a/tests/test_agent.py b/tests/test_agent.py index 0dffe1d1d4..5a9810789f 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4821,7 +4821,7 @@ def test_response_prefix_parameter_passed(): # Create a mock model that supports response prefix class MockResponsePrefixModel(TestModel): @property - def profile(self): + def profile(self): # pyright: ignore[reportIncompatibleVariableOverride] profile = super().profile profile.supports_response_prefix = True return profile diff --git a/tests/test_logfire.py b/tests/test_logfire.py index a4a7237ac0..c28fb8795f 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -394,6 +394,7 @@ async def my_ret(x: int) -> str: 'output_tools': [], 'output_object': None, 'allow_text_output': True, + 'response_prefix': None, } ) ), @@ -783,6 +784,7 @@ class MyOutput: } ], 'allow_text_output': False, + 'response_prefix': None, } ) ), @@ -887,6 +889,7 @@ async def test_feedback(capfire: CaptureLogfire) -> None: 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.span_type': 'span', 'logfire.msg': 'chat test', From 91008c23171276d17c59ff8eb28d52a13bfa4880 Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 22:25:36 +0800 Subject: [PATCH 03/11] fix: lint, test, pyright --- docs/agents.md | 6 ++++-- pydantic_ai_slim/pydantic_ai/agent/__init__.py | 3 ++- pydantic_ai_slim/pydantic_ai/agent/abstract.py | 6 ++++-- pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 3 ++- pydantic_ai_slim/pydantic_ai/run.py | 6 ++++-- tests/models/test_anthropic.py | 12 ++++++------ 6 files changed, 22 insertions(+), 14 deletions(-) diff --git a/docs/agents.md b/docs/agents.md index 83d2658cc8..633d307380 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -294,7 +294,8 @@ async def main(): 
timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( @@ -357,7 +358,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 94986d4440..9a52272314 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -522,7 +522,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 2061b2d7eb..7ecb051a1d 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -437,6 +437,7 @@ async def main(): event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. Note that it does _not_ receive any events after the final result is found. + response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. Returns: The result of the run. @@ -652,7 +653,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( @@ -683,7 +685,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models. - + Returns: The result of the run. 
""" diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 8a7aab728f..6fd13a80ee 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -161,7 +161,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 7ed6b848c0..034dc880ed 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -62,7 +62,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( @@ -197,7 +198,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index d2eeaa3c3c..aa4b13f64d 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -2720,18 +2720,18 @@ async def test_anthropic_web_search_tool_stream(allow_model_requests: None, anth ) -async def test_anthropic_response_prefix(allow_model_requests: None): +async def test_anthropic_response_prefix(allow_model_requests: None, anthropic_api_key: str): """Test that Anthropic models correctly handle response prefix.""" - m = AnthropicModel('claude-sonnet-4-0') + m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) agent = Agent(m) # Test non-streaming response - result = await agent.run('Hello', response_prefix='Assistant: ') - assert result.output.startswith('Assistant: ') + result = await agent.run('What is your favorite color?', response_prefix='My favorite color is') + assert result.output.startswith('My favorite color is') # Test streaming response event_parts: list[Any] = [] - async with agent.iter(user_prompt='Hello', response_prefix='Assistant: ') as agent_run: + async with agent.iter(user_prompt='Hello', response_prefix='My favorite color is') as agent_run: async for node in agent_run: if Agent.is_model_request_node(node): async with node.stream(agent_run.ctx) as request_stream: @@ -2741,4 +2741,4 @@ async def test_anthropic_response_prefix(allow_model_requests: None): # Check that the first text part starts with the prefix text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] assert len(text_parts) > 0 - assert cast(TextPart, text_parts[0].part).content.startswith('Assistant: ') + assert cast(TextPart, text_parts[0].part).content.startswith('My favorite color is') From eae47053398840ac64b9fd8345626097506505ec Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 22:44:01 +0800 Subject: [PATCH 04/11] fix: tests --- docs/agents.md | 2 + .../pydantic_ai/agent/__init__.py | 1 + .../pydantic_ai/agent/abstract.py | 1 + pydantic_ai_slim/pydantic_ai/agent/wrapper.py | 1 + .../durable_exec/temporal/_agent.py | 4 +- pydantic_ai_slim/pydantic_ai/run.py | 2 + .../test_anthropic_response_prefix.yaml | 148 ++++++++++++++++++ tests/models/test_anthropic.py | 10 +- 8 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml diff --git a/docs/agents.md b/docs/agents.md index 633d307380..20f20d8470 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -285,6 +285,7 @@ async def main(): 
system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( @@ -349,6 +350,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( diff --git a/pydantic_ai_slim/pydantic_ai/agent/__init__.py b/pydantic_ai_slim/pydantic_ai/agent/__init__.py index 9a52272314..db43f186ce 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/agent/__init__.py @@ -513,6 +513,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( diff --git a/pydantic_ai_slim/pydantic_ai/agent/abstract.py b/pydantic_ai_slim/pydantic_ai/agent/abstract.py index 7ecb051a1d..1de0404c3f 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/agent/abstract.py @@ -644,6 +644,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( diff --git a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py index 6fd13a80ee..f80e5f0e70 100644 --- a/pydantic_ai_slim/pydantic_ai/agent/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/agent/wrapper.py @@ -152,6 +152,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py index b0e7492b24..373c8a9189 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py @@ -683,6 +683,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( @@ -692,7 +693,8 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( diff --git a/pydantic_ai_slim/pydantic_ai/run.py b/pydantic_ai_slim/pydantic_ai/run.py index 034dc880ed..1718394331 100644 --- a/pydantic_ai_slim/pydantic_ai/run.py +++ b/pydantic_ai_slim/pydantic_ai/run.py @@ -53,6 +53,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( @@ -189,6 +190,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml new file mode 100644 index 0000000000..b65acadcfd --- /dev/null +++ b/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml @@ -0,0 +1,148 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '215' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 4096 + messages: + - content: + - text: 'What is the name of color 
#FF0000' + type: text + role: user + - content: It's name is + role: assistant + model: claude-3-5-sonnet-latest + stream: false + uri: https://api.anthropic.com/v1/messages?beta=true + response: + headers: + connection: + - keep-alive + content-length: + - '601' + content-type: + - application/json + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + content: + - text: |2- + Red. #FF0000 is the hexadecimal color code for pure red in RGB color space, where: + - FF represents the maximum value (255) for the red channel + - 00 represents no green + - 00 represents no blue + type: text + id: msg_0139bvQ8xYiX5eFuQo2kt2uJ + model: claude-3-5-sonnet-20241022 + role: assistant + stop_reason: end_turn + stop_sequence: null + type: message + usage: + cache_creation: + ephemeral_1h_input_tokens: 0 + ephemeral_5m_input_tokens: 0 + cache_creation_input_tokens: 0 + cache_read_input_tokens: 0 + input_tokens: 21 + output_tokens: 58 + service_tier: standard + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '186' + content-type: + - application/json + host: + - api.anthropic.com + method: POST + parsed_body: + max_tokens: 4096 + messages: + - content: + - text: Hello + type: text + role: user + - content: It's name is + role: assistant + model: claude-3-5-sonnet-latest + stream: true + uri: https://api.anthropic.com/v1/messages?beta=true + response: + body: + string: |+ + event: message_start + data: {"type":"message_start","message":{"id":"msg_01LrWvsWg9nVGu2ZAabNiKgH","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":3,"service_tier":"standard"}} } + + event: content_block_start + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Claude. It"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" was created by Anthropic and aims to be direct"} } + + event: ping + data: {"type": "ping"} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" and honest in its interactions."} } + + event: ping + data: {"type": "ping"} + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } + + event: ping + data: {"type": "ping"} + + event: ping + data: {"type": "ping"} + + event: message_delta + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":23} } + + event: message_stop + data: {"type":"message_stop" } + + headers: + cache-control: + - no-cache + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + status: + code: 200 + message: OK +version: 1 +... 
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index aa4b13f64d..10e10a8ffb 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -2726,19 +2726,19 @@ async def test_anthropic_response_prefix(allow_model_requests: None, anthropic_a agent = Agent(m) # Test non-streaming response - result = await agent.run('What is your favorite color?', response_prefix='My favorite color is') - assert result.output.startswith('My favorite color is') + result = await agent.run('What is the name of color #FF0000', response_prefix="It's name is") + assert result.output.startswith("It's name is") # Test streaming response event_parts: list[Any] = [] - async with agent.iter(user_prompt='Hello', response_prefix='My favorite color is') as agent_run: + async with agent.iter(user_prompt='Hello', response_prefix="It's name is") as agent_run: async for node in agent_run: if Agent.is_model_request_node(node): async with node.stream(agent_run.ctx) as request_stream: async for event in request_stream: event_parts.append(event) - # Check that the first text part starts with the prefix + # Check that the first text part xpstarts with the prefix text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] assert len(text_parts) > 0 - assert cast(TextPart, text_parts[0].part).content.startswith('My favorite color is') + assert cast(TextPart, text_parts[0].part).content.startswith("It's name is") From 153bc571e21af20d160d78fda74a5f227f7f01e4 Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 22:54:47 +0800 Subject: [PATCH 05/11] test: more coverage --- tests/models/test_openai.py | 48 +++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 2b7e6eb32f..31403c440c 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -22,6 +22,7 @@ ImageUrl, ModelRequest, ModelResponse, + PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, @@ -2929,3 +2930,50 @@ def test_deprecated_openai_model(openai_api_key: str): provider = OpenAIProvider(api_key=openai_api_key) OpenAIModel('gpt-4o', provider=provider) # type: ignore[reportDeprecated] + + +async def test_openai_response_prefix(allow_model_requests: None): + """Test that OpenAI models correctly handle response prefix.""" + c = completion_message( + ChatCompletionMessage(content='Red', role='assistant'), + ) + mock_client = MockOpenAI.create_mock(c) + # Use a model name that supports response prefix (DeepSeek models do) + m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(m) + + # Test non-streaming response + result = await agent.run('What is the name of color #FF0000', response_prefix="It's name is ") + assert result.output == "It's name is Red" + + # Verify that the response prefix was added to the request + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + assert 'messages' in kwargs + messages = kwargs['messages'] + # Should have user message and assistant message with prefix + assert len(messages) == 2 + assert messages[0]['role'] == 'user' + assert messages[1]['role'] == 'assistant' + assert messages[1]['content'] == "It's name is " + + +async def test_openai_response_prefix_stream(allow_model_requests: None): + """Test that OpenAI models correctly handle response prefix in streaming.""" + stream = [text_chunk('Red'), chunk([])] + mock_client = MockOpenAI.create_mock_stream(stream) + # Use a 
model name that supports response prefix (DeepSeek models do) + m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) + agent = Agent(m) + + event_parts: list[Any] = [] + async with agent.iter(user_prompt='What is the name of color #FF0000', response_prefix="It's name is ") as agent_run: + async for node in agent_run: + if Agent.is_model_request_node(node): + async with node.stream(agent_run.ctx) as request_stream: + async for event in request_stream: + event_parts.append(event) + + # Check that the first text part starts with the prefix + text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] + assert len(text_parts) > 0 + assert cast(TextPart, text_parts[0].part).content == "It's name is Red" From 3e9ccbd2ebffaf393bb0ca608c020a306a2ccb17 Mon Sep 17 00:00:00 2001 From: yf-yang Date: Sat, 6 Sep 2025 22:59:58 +0800 Subject: [PATCH 06/11] fix: lint --- tests/models/test_openai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index 31403c440c..b42922268c 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2966,7 +2966,9 @@ async def test_openai_response_prefix_stream(allow_model_requests: None): agent = Agent(m) event_parts: list[Any] = [] - async with agent.iter(user_prompt='What is the name of color #FF0000', response_prefix="It's name is ") as agent_run: + async with agent.iter( + user_prompt='What is the name of color #FF0000', response_prefix="It's name is " + ) as agent_run: async for node in agent_run: if Agent.is_model_request_node(node): async with node.stream(agent_run.ctx) as request_stream: From c58795f525165cae5b60c3b5516206e31d7440db Mon Sep 17 00:00:00 2001 From: yf-yang Date: Thu, 11 Sep 2025 10:46:40 +0800 Subject: [PATCH 07/11] fix: apply code review comments --- .../pydantic_ai/durable_exec/dbos/_agent.py | 26 +++++++++++++++++++ .../pydantic_ai/models/anthropic.py | 16 +++++++----- pydantic_ai_slim/pydantic_ai/models/openai.py | 17 +++++------- .../pydantic_ai/profiles/deepseek.py | 5 +++- .../pydantic_ai/profiles/openai.py | 4 --- .../pydantic_ai/providers/openrouter.py | 4 ++- 6 files changed, 49 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index 9db6c9b109..cc6e904fe7 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -115,6 +115,7 @@ async def wrapped_run_workflow( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: with self._dbos_overrides(): @@ -131,6 +132,7 @@ async def wrapped_run_workflow( infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -152,6 +154,7 @@ def wrapped_run_sync_workflow( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: with self._dbos_overrides(): @@ -168,6 +171,7 @@ def wrapped_run_sync_workflow( infer_name=infer_name, toolsets=toolsets, 
event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -240,6 +244,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -258,6 +263,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -275,6 +281,7 @@ async def run( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -308,6 +315,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. + response_prefix: Optional response prefix to use for this run. Returns: The result of the run. @@ -325,6 +333,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -344,6 +353,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -362,6 +372,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AgentRunResult[RunOutputDataT]: ... def run_sync( @@ -379,6 +390,7 @@ def run_sync( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -411,6 +423,7 @@ def run_sync( infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. + response_prefix: Optional response prefix to use for this run. Returns: The result of the run. @@ -428,6 +441,7 @@ def run_sync( infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) @@ -447,6 +461,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, OutputDataT]]: ... 
@overload @@ -465,6 +480,7 @@ def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, ) -> AbstractAsyncContextManager[StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -483,6 +499,7 @@ async def run_stream( infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, event_stream_handler: EventStreamHandler[AgentDepsT] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -513,6 +530,7 @@ async def main(): infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. event_stream_handler: Optional event stream handler to use for this run. It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager. + response_prefix: Optional response prefix to use for this run. Returns: The result of the run. @@ -537,6 +555,7 @@ async def main(): infer_name=infer_name, toolsets=toolsets, event_stream_handler=event_stream_handler, + response_prefix=response_prefix, **_deprecated_kwargs, ) as result: yield result @@ -556,6 +575,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -574,6 +594,7 @@ def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -592,6 +613,7 @@ async def iter( usage: _usage.RunUsage | None = None, infer_name: bool = True, toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + response_prefix: str | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -627,6 +649,7 @@ async def main(): system_prompts=(), system_prompt_functions=[], system_prompt_dynamic_functions={}, + response_prefix=None, ), ModelRequestNode( request=ModelRequest( @@ -637,6 +660,7 @@ async def main(): ) ] ) + response_prefix=None, ), CallToolsNode( model_response=ModelResponse( @@ -666,6 +690,7 @@ async def main(): usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. toolsets: Optional additional toolsets for this run. + response_prefix: Optional response prefix to use for this run. Returns: The result of the run. 
@@ -688,6 +713,7 @@ async def main(): usage=usage, infer_name=infer_name, toolsets=toolsets, + response_prefix=response_prefix, **_deprecated_kwargs, ) as run: yield run diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index ed16e82967..19a3cb5587 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -626,7 +626,12 @@ class AnthropicStreamedResponse(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 current_block: BetaContentBlock | None = None - first_text_delta = True + + # Handle response prefix by emitting it as the first text event + if response_prefix := self.model_request_parameters.response_prefix: + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix) + if maybe_event is not None: + yield maybe_event async for event in self._response: if isinstance(event, BetaRawMessageStartEvent): @@ -670,12 +675,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - content = event.delta.text - # Prepend response prefix to the first text delta if provided - if first_text_delta and self.model_request_parameters.response_prefix: - content = self.model_request_parameters.response_prefix + content - first_text_delta = False - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=event.index, content=content) + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id=event.index, content=event.delta.text + ) if maybe_event is not None: # pragma: no branch yield maybe_event elif isinstance(event.delta, BetaThinkingDelta): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index a3b5a150cb..9f778df5bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -525,13 +525,9 @@ def _process_response( ] if choice.message.content is not None: - content = choice.message.content - # Prepend response prefix if provided - if model_request_parameters.response_prefix: - content = model_request_parameters.response_prefix + content items.extend( (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part) - for part in split_content_into_text_and_thinking(content, self.profile.thinking_tags) + for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags) ) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: @@ -1243,7 +1239,12 @@ class OpenAIStreamedResponse(StreamedResponse): _provider_name: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - first_text_delta = True + # Handle response prefix by emitting it as the first text event + if response_prefix := self.model_request_parameters.response_prefix: + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix) + if maybe_event is not None: + yield maybe_event + async for chunk in self._response: self._usage += _map_usage(chunk) @@ -1266,10 +1267,6 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content if content is not None: - # Prepend response prefix to the first text delta if provided 
- if first_text_delta and self.model_request_parameters.response_prefix: - content = self.model_request_parameters.response_prefix + content - first_text_delta = False maybe_event = self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, diff --git a/pydantic_ai_slim/pydantic_ai/profiles/deepseek.py b/pydantic_ai_slim/pydantic_ai/profiles/deepseek.py index 92e166964d..aba8b4139b 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/deepseek.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/deepseek.py @@ -5,4 +5,7 @@ def deepseek_model_profile(model_name: str) -> ModelProfile | None: """Get the model profile for a DeepSeek model.""" - return ModelProfile(ignore_streamed_leading_whitespace='r1' in model_name) + return ModelProfile( + ignore_streamed_leading_whitespace='r1' in model_name, + supports_response_prefix=True, + ) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/openai.py b/pydantic_ai_slim/pydantic_ai/profiles/openai.py index 860e138f17..da18c0f097 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/openai.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/openai.py @@ -77,16 +77,12 @@ def openai_model_profile(model_name: str) -> ModelProfile: # See https://github.com/pydantic/pydantic-ai/issues/974 for more details. openai_system_prompt_role = 'user' if model_name.startswith('o1-mini') else None - # Enable response prefix for DeepSeek and OpenRouter models - supports_response_prefix = 'deepseek' in model_name.lower() or 'openrouter' in model_name.lower() - return OpenAIModelProfile( json_schema_transformer=OpenAIJsonSchemaTransformer, supports_json_schema_output=True, supports_json_object_output=True, openai_unsupported_model_settings=openai_unsupported_model_settings, openai_system_prompt_role=openai_system_prompt_role, - supports_response_prefix=supports_response_prefix, openai_chat_supports_web_search=supports_web_search, ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/openrouter.py b/pydantic_ai_slim/pydantic_ai/providers/openrouter.py index 96b0602e36..f434c87948 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/openrouter.py +++ b/pydantic_ai_slim/pydantic_ai/providers/openrouter.py @@ -70,7 +70,9 @@ def model_profile(self, model_name: str) -> ModelProfile | None: # As OpenRouterProvider is always used with OpenAIChatModel, which used to unconditionally use OpenAIJsonSchemaTransformer, # we need to maintain that behavior unless json_schema_transformer is set explicitly - return OpenAIModelProfile(json_schema_transformer=OpenAIJsonSchemaTransformer).update(profile) + return OpenAIModelProfile( + json_schema_transformer=OpenAIJsonSchemaTransformer, supports_response_prefix=True + ).update(profile) @overload def __init__(self) -> None: ... 
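Note for reviewers: a minimal sketch, assuming an OpenAI-compatible chat endpoint such as DeepSeek's, of the request shape produced once a response prefix is set. It mirrors the assertions in test_openai_response_prefix above and is illustrative only, not part of the patches.

# Illustrative only: the prefix travels as a trailing assistant message that
# prefix-aware providers continue instead of opening a new turn.
messages = [
    {'role': 'user', 'content': 'What is the name of color #FF0000'},
    {'role': 'assistant', 'content': "It's name is "},  # the response_prefix
]
# Client-side, the prefix is then prepended to the returned completion ('Red'),
# so result.output == "It's name is Red".
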
From 8ba3d250b758ff60c53ab2fd49e05314dd2db6cc Mon Sep 17 00:00:00 2001 From: yf-yang Date: Thu, 11 Sep 2025 10:58:57 +0800 Subject: [PATCH 08/11] fix: tests --- .../pydantic_ai/durable_exec/dbos/_agent.py | 2 +- tests/models/test_openai.py | 49 ------------------- 2 files changed, 1 insertion(+), 50 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py index cc6e904fe7..fe05858dea 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py @@ -659,7 +659,7 @@ async def main(): timestamp=datetime.datetime(...), ) ] - ) + ), response_prefix=None, ), CallToolsNode( diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index b42922268c..f726e84aa3 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -2930,52 +2930,3 @@ def test_deprecated_openai_model(openai_api_key: str): provider = OpenAIProvider(api_key=openai_api_key) OpenAIModel('gpt-4o', provider=provider) # type: ignore[reportDeprecated] - - -async def test_openai_response_prefix(allow_model_requests: None): - """Test that OpenAI models correctly handle response prefix.""" - c = completion_message( - ChatCompletionMessage(content='Red', role='assistant'), - ) - mock_client = MockOpenAI.create_mock(c) - # Use a model name that supports response prefix (DeepSeek models do) - m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) - agent = Agent(m) - - # Test non-streaming response - result = await agent.run('What is the name of color #FF0000', response_prefix="It's name is ") - assert result.output == "It's name is Red" - - # Verify that the response prefix was added to the request - kwargs = get_mock_chat_completion_kwargs(mock_client)[0] - assert 'messages' in kwargs - messages = kwargs['messages'] - # Should have user message and assistant message with prefix - assert len(messages) == 2 - assert messages[0]['role'] == 'user' - assert messages[1]['role'] == 'assistant' - assert messages[1]['content'] == "It's name is " - - -async def test_openai_response_prefix_stream(allow_model_requests: None): - """Test that OpenAI models correctly handle response prefix in streaming.""" - stream = [text_chunk('Red'), chunk([])] - mock_client = MockOpenAI.create_mock_stream(stream) - # Use a model name that supports response prefix (DeepSeek models do) - m = OpenAIChatModel('deepseek-chat', provider=OpenAIProvider(openai_client=mock_client)) - agent = Agent(m) - - event_parts: list[Any] = [] - async with agent.iter( - user_prompt='What is the name of color #FF0000', response_prefix="It's name is " - ) as agent_run: - async for node in agent_run: - if Agent.is_model_request_node(node): - async with node.stream(agent_run.ctx) as request_stream: - async for event in request_stream: - event_parts.append(event) - - # Check that the first text part starts with the prefix - text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] - assert len(text_parts) > 0 - assert cast(TextPart, text_parts[0].part).content == "It's name is Red" From 344369fd3d08d6c8611c7e70d5ce13da2f064057 Mon Sep 17 00:00:00 2001 From: yf-yang Date: Thu, 11 Sep 2025 11:00:58 +0800 Subject: [PATCH 09/11] fix: lint --- tests/models/test_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index f726e84aa3..2b7e6eb32f 100644 --- 
a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -22,7 +22,6 @@ ImageUrl, ModelRequest, ModelResponse, - PartStartEvent, RetryPromptPart, SystemPromptPart, TextPart, From a370d2311c4b07159a8af0fd679a058026f6e2b6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 16 Sep 2025 18:53:54 +0000 Subject: [PATCH 10/11] Fix snapshot --- tests/models/test_instrumented.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index 4d244e890c..5a350ee866 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -1321,6 +1321,7 @@ async def test_response_cost_error(capfire: CaptureLogfire, monkeypatch: pytest. 'output_object': None, 'output_tools': [], 'allow_text_output': True, + 'response_prefix': None, }, 'logfire.span_type': 'span', 'logfire.msg': 'chat gpt-4o', From b42d6b05c1ffe71f56f861ba6f506ad02d5c1044 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 16 Sep 2025 19:23:54 +0000 Subject: [PATCH 11/11] some tweaks --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 + .../pydantic_ai/models/anthropic.py | 38 ++--- .../pydantic_ai/models/mistral.py | 2 + pydantic_ai_slim/pydantic_ai/models/openai.py | 27 ++-- .../test_anthropic_response_prefix.yaml | 76 +++++---- tests/models/test_anthropic.py | 144 ++++++++++++++++-- 6 files changed, 218 insertions(+), 71 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index f350206f5c..45d9e8925c 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -473,6 +473,8 @@ async def _prepare_request( # See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary message_history = _clean_message_history(message_history) + # TODO: Raise exception if response_prefix is not supported by the ctx.deps.model.profile + model_request_parameters = await _prepare_request_parameters(ctx, self.response_prefix) model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters) diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index c93f837012..f31938d255 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -266,11 +266,7 @@ async def _messages_create( if (allow_parallel_tool_calls := model_settings.get('parallel_tool_calls')) is not None: tool_choice['disable_parallel_tool_use'] = not allow_parallel_tool_calls - system_prompt, anthropic_messages = await self._map_message(messages) - - # Add response prefix as assistant message if provided - if model_request_parameters.response_prefix: - anthropic_messages.append({'role': 'assistant', 'content': model_request_parameters.response_prefix}) + system_prompt, anthropic_messages = await self._map_message(messages, model_request_parameters) try: extra_headers = model_settings.get('extra_headers', {}) @@ -308,9 +304,8 @@ def _process_response( for i, item in enumerate(response.content): if isinstance(item, BetaTextBlock): content = item.text - # Prepend response prefix to the first text block if provided - if i == 0 and model_request_parameters.response_prefix: - content = model_request_parameters.response_prefix + content + if i == 0 and (response_prefix := model_request_parameters.response_prefix): + content = response_prefix + content items.append(TextPart(content=content)) elif 
isinstance(item, BetaWebSearchToolResultBlock | BetaCodeExecutionToolResultBlock): items.append( @@ -410,7 +405,9 @@ def _get_builtin_tools( ) return tools, extra_headers - async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[BetaMessageParam]]: # noqa: C901 + async def _map_message( # noqa: C901 + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters + ) -> tuple[str, list[BetaMessageParam]]: """Just maps a `pydantic_ai.Message` to a `anthropic.types.MessageParam`.""" system_prompt_parts: list[str] = [] anthropic_messages: list[BetaMessageParam] = [] @@ -522,6 +519,12 @@ async def _map_message(self, messages: list[ModelMessage]) -> tuple[str, list[Be anthropic_messages.append(BetaMessageParam(role='assistant', content=assistant_content_params)) else: assert_never(m) + + if response_prefix := model_request_parameters.response_prefix: + anthropic_messages.append( + BetaMessageParam(role='assistant', content=[BetaTextBlockParam(text=response_prefix, type='text')]) + ) + if instructions := self._get_instructions(messages): system_prompt_parts.insert(0, instructions) system_prompt = '\n\n'.join(system_prompt_parts) @@ -627,11 +630,7 @@ class AnthropicStreamedResponse(StreamedResponse): async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 current_block: BetaContentBlock | None = None - # Handle response prefix by emitting it as the first text event - if response_prefix := self.model_request_parameters.response_prefix: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix) - if maybe_event is not None: - yield maybe_event + response_prefix = self.model_request_parameters.response_prefix async for event in self._response: if isinstance(event, BetaRawMessageStartEvent): @@ -640,10 +639,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block - if isinstance(current_block, BetaTextBlock) and current_block.text: - maybe_event = self._parts_manager.handle_text_delta( - vendor_part_id=event.index, content=current_block.text - ) + if isinstance(current_block, BetaTextBlock) and (current_block.text or response_prefix): + text = current_block.text + if response_prefix: + text = response_prefix + text + response_prefix = None + + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=event.index, content=text) if maybe_event is not None: # pragma: no branch yield maybe_event elif isinstance(current_block, BetaThinkingBlock): diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0c749f3c60..07130fa26a 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -558,6 +558,8 @@ def _map_messages(self, messages: list[ModelMessage]) -> list[MistralMessages]: # Insert a dummy assistant message processed_messages.append(MistralAssistantMessage(content=[MistralTextChunk(text='OK')])) + # TODO: Insert response_prefix + return processed_messages def _map_user_prompt(self, part: UserPromptPart) -> MistralUserMessage: diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 9998f31390..d4bcff9413 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -544,10 +544,13 @@ def _process_response( for lp in 
choice.logprobs.content ] - if choice.message.content is not None: + if (content := choice.message.content) is not None: + if response_prefix := model_request_parameters.response_prefix: + content = response_prefix + content + items.extend( (replace(part, id='content', provider_name=self.system) if isinstance(part, ThinkingPart) else part) - for part in split_content_into_text_and_thinking(choice.message.content, self.profile.thinking_tags) + for part in split_content_into_text_and_thinking(content, self.profile.thinking_tags) ) if choice.message.tool_calls is not None: for c in choice.message.tool_calls: @@ -624,7 +627,7 @@ def _get_web_search_options(self, model_request_parameters: ModelRequestParamete ) async def _map_messages( - self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters | None = None + self, messages: list[ModelMessage], model_request_parameters: ModelRequestParameters ) -> list[chat.ChatCompletionMessageParam]: """Just maps a `pydantic_ai.Message` to a `openai.types.ChatCompletionMessageParam`.""" openai_messages: list[chat.ChatCompletionMessageParam] = [] @@ -666,9 +669,9 @@ async def _map_messages( if instructions := self._get_instructions(messages): openai_messages.insert(0, chat.ChatCompletionSystemMessageParam(content=instructions, role='system')) - # Add response prefix as assistant message if provided - if model_request_parameters and model_request_parameters.response_prefix: - openai_messages.append({'role': 'assistant', 'content': model_request_parameters.response_prefix}) + if response_prefix := model_request_parameters.response_prefix: + # TODO: Add prefix=True for DeepSeek? + openai_messages.append(chat.ChatCompletionAssistantMessageParam(role='assistant', content=response_prefix)) return openai_messages @@ -1358,11 +1361,7 @@ class OpenAIStreamedResponse(StreamedResponse): _provider_name: str async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: - # Handle response prefix by emitting it as the first text event - if response_prefix := self.model_request_parameters.response_prefix: - maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=response_prefix) - if maybe_event is not None: - yield maybe_event + response_prefix = self.model_request_parameters.response_prefix async for chunk in self._response: self._usage += _map_usage(chunk) @@ -1385,7 +1384,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # Handle the text part of the response content = choice.delta.content - if content is not None: + if content is not None or response_prefix: + if response_prefix: + content = response_prefix + (content or '') + response_prefix = None + maybe_event = self._parts_manager.handle_text_delta( vendor_part_id='content', content=content, diff --git a/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml b/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml index b65acadcfd..1bc4f0c6f8 100644 --- a/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml +++ b/tests/models/cassettes/test_anthropic/test_anthropic_response_prefix.yaml @@ -8,7 +8,7 @@ interactions: connection: - keep-alive content-length: - - '215' + - '256' content-type: - application/json host: @@ -21,17 +21,20 @@ interactions: - text: 'What is the name of color #FF0000' type: text role: user - - content: It's name is + - content: + - text: Su nombre es + type: text role: assistant - model: claude-3-5-sonnet-latest + model: 
claude-sonnet-4-0 stream: false + system: Be concise. uri: https://api.anthropic.com/v1/messages?beta=true response: headers: connection: - keep-alive content-length: - - '601' + - '518' content-type: - application/json strict-transport-security: @@ -40,14 +43,11 @@ interactions: - chunked parsed_body: content: - - text: |2- - Red. #FF0000 is the hexadecimal color code for pure red in RGB color space, where: - - FF represents the maximum value (255) for the red channel - - 00 represents no green - - 00 represents no blue + - text: ' **rojo** (o "red" en inglés). Este es el valor hexadecimal para el color rojo puro en el espacio de color + RGB.' type: text - id: msg_0139bvQ8xYiX5eFuQo2kt2uJ - model: claude-3-5-sonnet-20241022 + id: msg_01AsJ8x22wZUZK43ebDwD12n + model: claude-sonnet-4-20250514 role: assistant stop_reason: end_turn stop_sequence: null @@ -58,8 +58,8 @@ interactions: ephemeral_5m_input_tokens: 0 cache_creation_input_tokens: 0 cache_read_input_tokens: 0 - input_tokens: 21 - output_tokens: 58 + input_tokens: 24 + output_tokens: 39 service_tier: standard status: code: 200 @@ -73,7 +73,7 @@ interactions: connection: - keep-alive content-length: - - '186' + - '255' content-type: - application/json host: @@ -83,52 +83,64 @@ interactions: max_tokens: 4096 messages: - content: - - text: Hello + - text: 'What is the name of color #FF0000' type: text role: user - - content: It's name is + - content: + - text: Su nombre es + type: text role: assistant - model: claude-3-5-sonnet-latest + model: claude-sonnet-4-0 stream: true + system: Be concise. uri: https://api.anthropic.com/v1/messages?beta=true response: body: string: |+ event: message_start - data: {"type":"message_start","message":{"id":"msg_01LrWvsWg9nVGu2ZAabNiKgH","type":"message","role":"assistant","model":"claude-3-5-sonnet-20241022","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":3,"service_tier":"standard"}} } + data: {"type":"message_start","message":{"id":"msg_01CAZPvhQ5cuSvKdgBBvi7ev","type":"message","role":"assistant","model":"claude-sonnet-4-20250514","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":24,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0},"output_tokens":2,"service_tier":"standard"}} } event: content_block_start - data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } + data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""} } event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" Claude. 
It"} } + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" **r"} } event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" was created by Anthropic and aims to be direct"} } + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"ojo** (o **red**"} } event: ping data: {"type": "ping"} event: content_block_delta - data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" and honest in its interactions."} } + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" en inglés).\n\nEl color #FF0000 es"} } - event: ping - data: {"type": "ping"} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" rojo puro en el sistema hex"} } - event: content_block_stop - data: {"type":"content_block_stop","index":0 } + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"adecimal RGB, donde:\n- FF = 255 ("} } - event: ping - data: {"type": "ping"} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"máximo valor de rojo)\n- 00 ="} } - event: ping - data: {"type": "ping"} + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" 0 (sin"} } + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" verde)\n- 00 = "}} + + event: content_block_delta + data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"0 (sin azul)"} } + + event: content_block_stop + data: {"type":"content_block_stop","index":0 } event: message_delta - data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":23} } + data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":24,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":82} } event: message_stop - data: {"type":"message_stop" } + data: {"type":"message_stop" } headers: cache-control: diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index d726665ec6..8ca02feb63 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -2727,23 +2727,149 @@ async def test_anthropic_web_search_tool_stream(allow_model_requests: None, anth async def test_anthropic_response_prefix(allow_model_requests: None, anthropic_api_key: str): """Test that Anthropic models correctly handle response prefix.""" - m = AnthropicModel('claude-3-5-sonnet-latest', provider=AnthropicProvider(api_key=anthropic_api_key)) - agent = Agent(m) + m = AnthropicModel('claude-sonnet-4-0', provider=AnthropicProvider(api_key=anthropic_api_key)) + agent = Agent(m, instructions='Be concise.') # Test non-streaming response - result = await agent.run('What is the name of color #FF0000', response_prefix="It's name is") - assert result.output.startswith("It's name is") + result = await agent.run('What is the name of color #FF0000', response_prefix='Su nombre es') + assert result.output == snapshot( + 'Su nombre es **rojo** (o "red" en inglés). Este es el valor hexadecimal para el color rojo puro en el espacio de color RGB.' 
+ ) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the name of color #FF0000', + timestamp=IsDatetime(), + ) + ], + instructions='Be concise.', + ), + ModelResponse( + parts=[ + TextPart( + content='Su nombre es **rojo** (o "red" en inglés). Este es el valor hexadecimal para el color rojo puro en el espacio de color RGB.' + ) + ], + usage=RequestUsage( + input_tokens=24, + output_tokens=39, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 24, + 'output_tokens': 39, + }, + ), + model_name='claude-sonnet-4-20250514', + timestamp=IsDatetime(), + provider_name='anthropic', + provider_details={'finish_reason': 'end_turn'}, + provider_response_id='msg_01AsJ8x22wZUZK43ebDwD12n', + finish_reason='stop', + ), + ] + ) # Test streaming response event_parts: list[Any] = [] - async with agent.iter(user_prompt='Hello', response_prefix="It's name is") as agent_run: + async with agent.iter(user_prompt='What is the name of color #FF0000', response_prefix='Su nombre es') as agent_run: async for node in agent_run: if Agent.is_model_request_node(node): async with node.stream(agent_run.ctx) as request_stream: async for event in request_stream: event_parts.append(event) - # Check that the first text part starts with the prefix - text_parts = [p for p in event_parts if isinstance(p, PartStartEvent) and isinstance(p.part, TextPart)] - assert len(text_parts) > 0 - assert cast(TextPart, text_parts[0].part).content.startswith("It's name is") + assert event_parts == snapshot( + [ + PartStartEvent(index=0, part=TextPart(content='Su nombre es')), + FinalResultEvent(tool_name=None, tool_call_id=None), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' **r')), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='ojo** (o **red**')), + PartDeltaEvent( + index=0, + delta=TextPartDelta( + content_delta="""\ + en inglés). + +El color #FF0000 es\ +""" + ), + ), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' rojo puro en el sistema hex')), + PartDeltaEvent( + index=0, + delta=TextPartDelta( + content_delta="""\ +adecimal RGB, donde: +- FF = 255 (\ +""" + ), + ), + PartDeltaEvent( + index=0, + delta=TextPartDelta( + content_delta="""\ +máximo valor de rojo) +- 00 =\ +""" + ), + ), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta=' 0 (sin')), + PartDeltaEvent( + index=0, + delta=TextPartDelta( + content_delta="""\ + verde) +- 00 = \ +""" + ), + ), + PartDeltaEvent(index=0, delta=TextPartDelta(content_delta='0 (sin azul)')), + ] + ) + assert agent_run.result is not None + assert agent_run.result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What is the name of color #FF0000', + timestamp=IsDatetime(), + ) + ], + instructions='Be concise.', + ), + ModelResponse( + parts=[ + TextPart( + content="""\ +Su nombre es **rojo** (o **red** en inglés). + +El color #FF0000 es rojo puro en el sistema hexadecimal RGB, donde: +- FF = 255 (máximo valor de rojo) +- 00 = 0 (sin verde) +- 00 = 0 (sin azul)\ +""" + ) + ], + usage=RequestUsage( + input_tokens=24, + output_tokens=82, + details={ + 'cache_creation_input_tokens': 0, + 'cache_read_input_tokens': 0, + 'input_tokens': 24, + 'output_tokens': 82, + }, + ), + model_name='claude-sonnet-4-20250514', + timestamp=IsDatetime(), + provider_name='anthropic', + provider_details={'finish_reason': 'end_turn'}, + provider_response_id='msg_01CAZPvhQ5cuSvKdgBBvi7ev', + finish_reason='stop', + ), + ] + )
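
Note for reviewers: a minimal end-to-end usage sketch of the feature this series adds, mirroring the cassette-backed Anthropic test above. It assumes ANTHROPIC_API_KEY is set in the environment (AnthropicProvider falls back to it when no provider is passed) and that the model's profile enables supports_response_prefix, as the DeepSeek and OpenRouter profiles do in these patches.

import asyncio

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel


async def main():
    # Model, instructions, prompt, and prefix match the recorded test above.
    agent = Agent(AnthropicModel('claude-sonnet-4-0'), instructions='Be concise.')
    # The model continues from the given prefix, and the prefix is included in the output.
    result = await agent.run('What is the name of color #FF0000', response_prefix='Su nombre es')
    print(result.output)  # e.g. 'Su nombre es **rojo** (o "red" en inglés). ...'


asyncio.run(main())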