Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added dbostest.sqlite
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this file, it should've gotten cleaned up automatically 🤔

Binary file not shown.
83 changes: 55 additions & 28 deletions pydantic_ai_slim/pydantic_ai/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from ..toolsets.combined import CombinedToolset
from ..toolsets.function import FunctionToolset
from ..toolsets.prepared import PreparedToolset
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT
from .wrapper import WrapperAgent

if TYPE_CHECKING:
Expand Down Expand Up @@ -163,10 +163,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand All @@ -192,10 +189,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand All @@ -219,10 +213,7 @@ def __init__(
model: models.Model | models.KnownModelName | str | None = None,
*,
output_type: OutputSpec[OutputDataT] = str,
instructions: str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None = None,
instructions: Instructions = None,
system_prompt: str | Sequence[str] = (),
deps_type: type[AgentDepsT] = NoneType,
name: str | None = None,
Expand Down Expand Up @@ -320,16 +311,7 @@ def __init__(
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode)
self._output_validators = []

self._instructions = ''
self._instructions_functions = []
if isinstance(instructions, str | Callable):
instructions = [instructions]
for instruction in instructions or []:
if isinstance(instruction, str):
self._instructions += instruction + '\n'
else:
self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction))
self._instructions = self._instructions.strip() or None
self._instructions, self._instructions_functions = self._instructions_literal_and_functions(instructions)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking we could store self._instructions = instructions, and then call _get_instructions_literal_and_functions where we need it. I don't think we need the 2 private vars


self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt)
self._system_prompt_functions = []
Expand Down Expand Up @@ -369,11 +351,44 @@ def __init__(
self._override_tools: ContextVar[
_utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]]
] = ContextVar('_override_tools', default=None)
self._override_instructions: ContextVar[_utils.Option[Instructions]] = ContextVar(
'_override_instructions', default=None
)

self._enter_lock = Lock()
self._entered_count = 0
self._exit_stack = None

def _get_instructions_literal_and_functions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's just call this _get_instructions

self,
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]:
instructions, instructions_functions = self._instructions, self._instructions_functions
if override_instructions := self._override_instructions.get():
instructions, instructions_functions = self._instructions_literal_and_functions(override_instructions.value)
return instructions, instructions_functions

def _instructions_literal_and_functions(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With my suggestion above, we should not need this as a separate method anymore, so we can move its contents into the get method.

self,
instructions: Instructions,
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]:
literal_parts: list[str] = []
functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []

if isinstance(instructions, str | Callable):
instructions = [instructions]

for instruction in instructions or []:
if isinstance(instruction, str):
literal_parts.append(instruction)
elif callable(instruction):
func = cast(_system_prompt.SystemPromptFunc[AgentDepsT], instruction)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love this, but not sure how to appease the type checker. This was marked as "unknown" otherwise.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work if we change elif callable(instruction): to just else:, like we had in the original code?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope, i think because we have an explicit:

functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = []

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mwildehahn Hmm ok, once this PR is otherwise ready I'll have a look to see if I can clean up the typing here a bit.

functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func))
else: # pragma: no cover
raise ValueError(f'Invalid instruction: {instruction}')

literal = '\n'.join(literal_parts).strip() or None
return literal, functions

@staticmethod
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
"""Set the instrumentation options for all agents where `instrument` is not set."""
Expand Down Expand Up @@ -592,9 +607,10 @@ async def main():
usage_limits = usage_limits or _usage.UsageLimits()

async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
literal, functions = self._get_instructions_literal_and_functions()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the changes I suggested above, this can be:

Suggested change
literal, functions = self._get_instructions_literal_and_functions()
instructions, instructions_functions = self._get_instructions()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And can we move this out of the function, and then use the same instructions and instructions_functions when we build the UserPromptNode below, to save calling this method twice?

parts = [
self._instructions,
*[await func.run(run_context) for func in self._instructions_functions],
literal,
*[await func.run(run_context) for func in functions],
]

model_profile = model_used.profile
Expand Down Expand Up @@ -632,11 +648,13 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None:
get_instructions=get_instructions,
instrumentation_settings=instrumentation_settings,
)

instructions_for_node, instructions_functions_for_node = self._get_instructions_literal_and_functions()
start_node = _agent_graph.UserPromptNode[AgentDepsT](
user_prompt=user_prompt,
deferred_tool_results=deferred_tool_results,
instructions=self._instructions,
instructions_functions=self._instructions_functions,
instructions=instructions_for_node,
instructions_functions=instructions_functions_for_node,
system_prompts=self._system_prompts,
system_prompt_functions=self._system_prompt_functions,
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions,
Expand Down Expand Up @@ -720,8 +738,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -731,6 +750,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if _utils.is_set(deps):
deps_token = self._override_deps.set(_utils.Some(deps))
Expand All @@ -752,6 +772,11 @@ def override(
else:
tools_token = None

if _utils.is_set(instructions):
instructions_token = self._override_instructions.set(_utils.Some(instructions))
else:
instructions_token = None

try:
yield
finally:
Expand All @@ -763,6 +788,8 @@ def override(
self._override_toolsets.reset(toolsets_token)
if tools_token is not None:
self._override_tools.reset(tools_token)
if instructions_token is not None:
self._override_instructions.reset(instructions_token)

@overload
def instructions(
Expand Down
13 changes: 12 additions & 1 deletion pydantic_ai_slim/pydantic_ai/agent/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from .. import (
_agent_graph,
_system_prompt,
_utils,
exceptions,
messages as _messages,
Expand Down Expand Up @@ -60,6 +61,14 @@
"""A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools."""


Instructions = (
str
| _system_prompt.SystemPromptFunc[AgentDepsT]
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]]
| None
)


class AbstractAgent(Generic[AgentDepsT, OutputDataT], ABC):
"""Abstract superclass for [`Agent`][pydantic_ai.agent.Agent], [`WrapperAgent`][pydantic_ai.agent.WrapperAgent], and your own custom agent implementations."""

Expand Down Expand Up @@ -681,8 +690,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -692,6 +702,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
raise NotImplementedError
yield
Expand Down
14 changes: 11 additions & 3 deletions pydantic_ai_slim/pydantic_ai/agent/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ToolFuncEither,
)
from ..toolsets import AbstractToolset
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT


class WrapperAgent(AbstractAgent[AgentDepsT, OutputDataT]):
Expand Down Expand Up @@ -214,8 +214,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -225,6 +226,13 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
with self.wrapped.override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with self.wrapped.override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
13 changes: 11 additions & 2 deletions pydantic_ai_slim/pydantic_ai/durable_exec/dbos/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
usage as _usage,
)
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
from pydantic_ai.agent.abstract import Instructions
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -704,8 +705,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -715,11 +717,18 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if _utils.is_set(model) and not isinstance(model, (DBOSModel)):
raise UserError(
'Non-DBOS model cannot be contextually overridden inside a DBOS workflow, it must be set at agent creation time.'
)

with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with super().override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
15 changes: 12 additions & 3 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
models,
usage as _usage,
)
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, RunOutputDataT, WrapperAgent
from pydantic_ai.agent import AbstractAgent, AgentRun, AgentRunResult, EventStreamHandler, WrapperAgent
from pydantic_ai.agent.abstract import Instructions, RunOutputDataT
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import Model
from pydantic_ai.output import OutputDataT, OutputSpec
Expand Down Expand Up @@ -748,8 +749,9 @@ def override(
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET,
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET,
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET,
instructions: Instructions | _utils.Unset = _utils.UNSET,
) -> Iterator[None]:
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools.
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions.

This is particularly useful when testing.
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures).
Expand All @@ -759,6 +761,7 @@ def override(
model: The model to use instead of the model passed to the agent run.
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run.
tools: The tools to use instead of the tools registered with the agent.
instructions: The instructions to use instead of the instructions registered with the agent.
"""
if workflow.in_workflow():
if _utils.is_set(model):
Expand All @@ -774,5 +777,11 @@ def override(
'Tools cannot be contextually overridden inside a Temporal workflow, they must be set at agent creation time.'
)

with super().override(deps=deps, model=model, toolsets=toolsets, tools=tools):
with super().override(
deps=deps,
model=model,
toolsets=toolsets,
tools=tools,
instructions=instructions,
):
yield
Loading