Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ async def main():
system_prompts=(),
system_prompt_functions=[],
system_prompt_dynamic_functions={},
response_prefix=None,
),
ModelRequestNode(
request=ModelRequest(
Expand All @@ -293,7 +294,8 @@ async def main():
timestamp=datetime.datetime(...),
)
]
)
),
response_prefix=None,
),
CallToolsNode(
model_response=ModelResponse(
Expand Down Expand Up @@ -346,6 +348,7 @@ async def main():
system_prompts=(),
system_prompt_functions=[],
system_prompt_dynamic_functions={},
response_prefix=None,
),
ModelRequestNode(
request=ModelRequest(
Expand All @@ -355,7 +358,8 @@ async def main():
timestamp=datetime.datetime(...),
)
]
)
),
response_prefix=None,
),
CallToolsNode(
model_response=ModelResponse(
Expand Down
10 changes: 8 additions & 2 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ class UserPromptNode(AgentNode[DepsT, NodeRunEndT]):
_: dataclasses.KW_ONLY

deferred_tool_results: DeferredToolResults | None = None
response_prefix: str | None = None

instructions: str | None = None
instructions_functions: list[_system_prompt.SystemPromptRunner[DepsT]] = dataclasses.field(default_factory=list)
Expand Down Expand Up @@ -247,7 +248,7 @@ async def run( # noqa: C901

next_message.instructions = await ctx.deps.get_instructions(run_context)

return ModelRequestNode[DepsT, NodeRunEndT](request=next_message)
return ModelRequestNode[DepsT, NodeRunEndT](request=next_message, response_prefix=self.response_prefix)

async def _handle_deferred_tool_results( # noqa: C901
self,
Expand Down Expand Up @@ -348,6 +349,7 @@ async def _sys_parts(self, run_context: RunContext[DepsT]) -> list[_messages.Mod

async def _prepare_request_parameters(
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]],
response_prefix: str | None = None,
) -> models.ModelRequestParameters:
"""Build tools and create an agent model."""
output_schema = ctx.deps.output_schema
Expand All @@ -373,6 +375,7 @@ async def _prepare_request_parameters(
output_tools=output_tools,
output_object=output_object,
allow_text_output=allow_text_output,
response_prefix=response_prefix,
)


Expand All @@ -381,6 +384,7 @@ class ModelRequestNode(AgentNode[DepsT, NodeRunEndT]):
"""The node that makes a request to the model using the last message in state.message_history."""

request: _messages.ModelRequest
response_prefix: str | None = None

_result: CallToolsNode[DepsT, NodeRunEndT] | None = field(repr=False, init=False, default=None)
_did_stream: bool = field(repr=False, init=False, default=False)
Expand Down Expand Up @@ -469,7 +473,9 @@ async def _prepare_request(
# See `tests/test_tools.py::test_parallel_tool_return_with_deferred` for an example where this is necessary
message_history = _clean_message_history(message_history)

model_request_parameters = await _prepare_request_parameters(ctx)
# TODO: Raise an exception if response_prefix is not supported by ctx.deps.model.profile

model_request_parameters = await _prepare_request_parameters(ctx, self.response_prefix)
model_request_parameters = ctx.deps.model.customize_request_parameters(model_request_parameters)

model_settings = ctx.deps.model_settings
Expand Down
16 changes: 15 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...

@overload
Expand All @@ -455,6 +456,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

@asynccontextmanager
Expand All @@ -472,6 +474,7 @@ async def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

Expand Down Expand Up @@ -505,6 +508,7 @@ async def main():
system_prompts=(),
system_prompt_functions=[],
system_prompt_dynamic_functions={},
response_prefix=None,
),
ModelRequestNode(
request=ModelRequest(
Expand All @@ -514,7 +518,8 @@ async def main():
timestamp=datetime.datetime(...),
)
]
)
),
response_prefix=None,
),
CallToolsNode(
model_response=ModelResponse(
Expand Down Expand Up @@ -544,6 +549,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand All @@ -553,6 +559,13 @@ async def main():
model_used = self._get_model(model)
del model

# Validate response_prefix support
if response_prefix is not None and not model_used.profile.supports_response_prefix:
raise exceptions.UserError(
f'Model {model_used.model_name} does not support response prefix. '
'Response prefix is only supported by certain models like Anthropic Claude and some OpenAI-compatible models.'
)

deps = self._get_deps(deps)
output_schema = self._prepare_output_schema(output_type, model_used.profile)

Expand Down Expand Up @@ -640,6 +653,7 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
response_prefix=response_prefix,
)

agent_name = self.name or 'agent'
Expand Down
23 changes: 22 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ async def run(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[OutputDataT]: ...

@overload
Expand All @@ -145,6 +146,7 @@ async def run(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[RunOutputDataT]: ...

async def run(
Expand All @@ -162,6 +164,7 @@ async def run(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[Any]:
"""Run the agent with a user prompt in async mode.

Expand Down Expand Up @@ -194,6 +197,7 @@ async def main():
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand All @@ -214,6 +218,7 @@ async def main():
usage_limits=usage_limits,
usage=usage,
toolsets=toolsets,
response_prefix=response_prefix,
) as agent_run:
async for node in agent_run:
if event_stream_handler is not None and (
Expand Down Expand Up @@ -241,6 +246,7 @@ def run_sync(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[OutputDataT]: ...

@overload
Expand All @@ -259,6 +265,7 @@ def run_sync(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[RunOutputDataT]: ...

def run_sync(
Expand All @@ -276,6 +283,7 @@ def run_sync(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AgentRunResult[Any]:
"""Synchronously run the agent with a user prompt.

Expand Down Expand Up @@ -307,6 +315,7 @@ def run_sync(
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand All @@ -328,6 +337,7 @@ def run_sync(
infer_name=False,
toolsets=toolsets,
event_stream_handler=event_stream_handler,
response_prefix=response_prefix,
)
)

Expand All @@ -347,6 +357,7 @@ def run_stream(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ...

@overload
Expand All @@ -365,6 +376,7 @@ def run_stream(
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ...

@asynccontextmanager
Expand All @@ -383,6 +395,7 @@ async def run_stream( # noqa C901
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
event_stream_handler: EventStreamHandler[AgentDepsT] | None = None,
response_prefix: str | None = None,
) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]:
"""Run the agent with a user prompt in async streaming mode.

Expand Down Expand Up @@ -424,6 +437,7 @@ async def main():
event_stream_handler: Optional handler for events from the model's streaming response and the agent's execution of tools to use for this run.
It will receive all the events up until the final result is found, which you can then read or stream from inside the context manager.
Note that it does _not_ receive any events after the final result is found.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand All @@ -448,6 +462,7 @@ async def main():
usage=usage,
infer_name=False,
toolsets=toolsets,
response_prefix=response_prefix,
) as agent_run:
first_node = agent_run.next_node # start with the first node
assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node
Expand Down Expand Up @@ -558,6 +573,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...

@overload
Expand All @@ -575,6 +591,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

@asynccontextmanager
Expand All @@ -593,6 +610,7 @@ async def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

Expand Down Expand Up @@ -626,6 +644,7 @@ async def main():
system_prompts=(),
system_prompt_functions=[],
system_prompt_dynamic_functions={},
response_prefix=None,
),
ModelRequestNode(
request=ModelRequest(
Expand All @@ -635,7 +654,8 @@ async def main():
timestamp=datetime.datetime(...),
)
]
)
),
response_prefix=None,
),
CallToolsNode(
model_response=ModelResponse(
Expand Down Expand Up @@ -665,6 +685,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand Down
9 changes: 8 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ...

@overload
Expand All @@ -98,6 +99,7 @@ def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ...

@asynccontextmanager
Expand All @@ -115,6 +117,7 @@ async def iter(
usage: _usage.RunUsage | None = None,
infer_name: bool = True,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None,
response_prefix: str | None = None,
) -> AsyncIterator[AgentRun[AgentDepsT, Any]]:
"""A contextmanager which can be used to iterate over the agent graph's nodes as they are executed.

Expand Down Expand Up @@ -148,6 +151,7 @@ async def main():
system_prompts=(),
system_prompt_functions=[],
system_prompt_dynamic_functions={},
response_prefix=None,
),
ModelRequestNode(
request=ModelRequest(
Expand All @@ -157,7 +161,8 @@ async def main():
timestamp=datetime.datetime(...),
)
]
)
),
response_prefix=None,
),
CallToolsNode(
model_response=ModelResponse(
Expand Down Expand Up @@ -187,6 +192,7 @@ async def main():
usage: Optional usage to start with, useful for resuming a conversation or agents used in tools.
infer_name: Whether to try to infer the agent name from the call frame if it's not set.
toolsets: Optional additional toolsets for this run.
response_prefix: Optional prefix to prepend to the model's response. Only supported by certain models.

Returns:
The result of the run.
Expand All @@ -203,6 +209,7 @@ async def main():
usage=usage,
infer_name=infer_name,
toolsets=toolsets,
response_prefix=response_prefix,
) as run:
yield run

Expand Down
Loading
Loading