-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Support overriding agent instructions #2926
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
28b8e8b
4b36d26
b1c0892
fd49ab2
a1d25cd
2ad2a99
a36c1bf
e7d531c
fa10e67
d51624b
696c4ba
01c6f47
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -65,7 +65,7 @@ | |||||
from ..toolsets.combined import CombinedToolset | ||||||
from ..toolsets.function import FunctionToolset | ||||||
from ..toolsets.prepared import PreparedToolset | ||||||
from .abstract import AbstractAgent, EventStreamHandler, RunOutputDataT | ||||||
from .abstract import AbstractAgent, EventStreamHandler, Instructions, RunOutputDataT | ||||||
from .wrapper import WrapperAgent | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
@@ -163,10 +163,7 @@ def __init__( | |||||
model: models.Model | models.KnownModelName | str | None = None, | ||||||
*, | ||||||
output_type: OutputSpec[OutputDataT] = str, | ||||||
instructions: str | ||||||
| _system_prompt.SystemPromptFunc[AgentDepsT] | ||||||
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | ||||||
| None = None, | ||||||
instructions: Instructions = None, | ||||||
system_prompt: str | Sequence[str] = (), | ||||||
deps_type: type[AgentDepsT] = NoneType, | ||||||
name: str | None = None, | ||||||
|
@@ -192,10 +189,7 @@ def __init__( | |||||
model: models.Model | models.KnownModelName | str | None = None, | ||||||
*, | ||||||
output_type: OutputSpec[OutputDataT] = str, | ||||||
instructions: str | ||||||
| _system_prompt.SystemPromptFunc[AgentDepsT] | ||||||
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | ||||||
| None = None, | ||||||
instructions: Instructions = None, | ||||||
system_prompt: str | Sequence[str] = (), | ||||||
deps_type: type[AgentDepsT] = NoneType, | ||||||
name: str | None = None, | ||||||
|
@@ -219,10 +213,7 @@ def __init__( | |||||
model: models.Model | models.KnownModelName | str | None = None, | ||||||
*, | ||||||
output_type: OutputSpec[OutputDataT] = str, | ||||||
instructions: str | ||||||
| _system_prompt.SystemPromptFunc[AgentDepsT] | ||||||
| Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] | ||||||
| None = None, | ||||||
instructions: Instructions = None, | ||||||
system_prompt: str | Sequence[str] = (), | ||||||
deps_type: type[AgentDepsT] = NoneType, | ||||||
name: str | None = None, | ||||||
|
@@ -320,16 +311,7 @@ def __init__( | |||||
self._output_schema = _output.OutputSchema[OutputDataT].build(output_type, default_mode=default_output_mode) | ||||||
self._output_validators = [] | ||||||
|
||||||
self._instructions = '' | ||||||
self._instructions_functions = [] | ||||||
if isinstance(instructions, str | Callable): | ||||||
instructions = [instructions] | ||||||
for instruction in instructions or []: | ||||||
if isinstance(instruction, str): | ||||||
self._instructions += instruction + '\n' | ||||||
else: | ||||||
self._instructions_functions.append(_system_prompt.SystemPromptRunner(instruction)) | ||||||
self._instructions = self._instructions.strip() or None | ||||||
self._instructions, self._instructions_functions = self._instructions_literal_and_functions(instructions) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking we could store |
||||||
|
||||||
self._system_prompts = (system_prompt,) if isinstance(system_prompt, str) else tuple(system_prompt) | ||||||
self._system_prompt_functions = [] | ||||||
|
@@ -369,11 +351,44 @@ def __init__( | |||||
self._override_tools: ContextVar[ | ||||||
_utils.Option[Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]]] | ||||||
] = ContextVar('_override_tools', default=None) | ||||||
self._override_instructions: ContextVar[_utils.Option[Instructions]] = ContextVar( | ||||||
'_override_instructions', default=None | ||||||
) | ||||||
|
||||||
self._enter_lock = Lock() | ||||||
self._entered_count = 0 | ||||||
self._exit_stack = None | ||||||
|
||||||
def _get_instructions_literal_and_functions( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's just call this |
||||||
self, | ||||||
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]: | ||||||
instructions, instructions_functions = self._instructions, self._instructions_functions | ||||||
if override_instructions := self._override_instructions.get(): | ||||||
instructions, instructions_functions = self._instructions_literal_and_functions(override_instructions.value) | ||||||
return instructions, instructions_functions | ||||||
|
||||||
def _instructions_literal_and_functions( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With my suggestion above, we should not need this as a separate method anymore, so we can move its contents into the get method. |
||||||
self, | ||||||
instructions: Instructions, | ||||||
) -> tuple[str | None, list[_system_prompt.SystemPromptRunner[AgentDepsT]]]: | ||||||
literal_parts: list[str] = [] | ||||||
functions: list[_system_prompt.SystemPromptRunner[AgentDepsT]] = [] | ||||||
|
||||||
if isinstance(instructions, str | Callable): | ||||||
instructions = [instructions] | ||||||
|
||||||
for instruction in instructions or []: | ||||||
if isinstance(instruction, str): | ||||||
literal_parts.append(instruction) | ||||||
elif callable(instruction): | ||||||
func = cast(_system_prompt.SystemPromptFunc[AgentDepsT], instruction) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't love this, but not sure how to appease the type checker. This was marked as "unknown" otherwise. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it work if we change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nope, i think because we have an explicit:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mwildehahn Hmm ok, once this PR is otherwise ready I'll have a look to see if I can clean up the typing here a bit. |
||||||
functions.append(_system_prompt.SystemPromptRunner[AgentDepsT](func)) | ||||||
else: # pragma: no cover | ||||||
raise ValueError(f'Invalid instruction: {instruction}') | ||||||
|
||||||
literal = '\n'.join(literal_parts).strip() or None | ||||||
return literal, functions | ||||||
|
||||||
@staticmethod | ||||||
def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: | ||||||
"""Set the instrumentation options for all agents where `instrument` is not set.""" | ||||||
|
@@ -592,9 +607,10 @@ async def main(): | |||||
usage_limits = usage_limits or _usage.UsageLimits() | ||||||
|
||||||
async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: | ||||||
literal, functions = self._get_instructions_literal_and_functions() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With the changes I suggested above, this can be:
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And can we move this out of the function, and then use the same |
||||||
parts = [ | ||||||
self._instructions, | ||||||
*[await func.run(run_context) for func in self._instructions_functions], | ||||||
literal, | ||||||
*[await func.run(run_context) for func in functions], | ||||||
] | ||||||
|
||||||
model_profile = model_used.profile | ||||||
|
@@ -632,11 +648,13 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: | |||||
get_instructions=get_instructions, | ||||||
instrumentation_settings=instrumentation_settings, | ||||||
) | ||||||
|
||||||
instructions_for_node, instructions_functions_for_node = self._get_instructions_literal_and_functions() | ||||||
start_node = _agent_graph.UserPromptNode[AgentDepsT]( | ||||||
user_prompt=user_prompt, | ||||||
deferred_tool_results=deferred_tool_results, | ||||||
instructions=self._instructions, | ||||||
instructions_functions=self._instructions_functions, | ||||||
instructions=instructions_for_node, | ||||||
instructions_functions=instructions_functions_for_node, | ||||||
mwildehahn marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
system_prompts=self._system_prompts, | ||||||
system_prompt_functions=self._system_prompt_functions, | ||||||
system_prompt_dynamic_functions=self._system_prompt_dynamic_functions, | ||||||
|
@@ -720,8 +738,9 @@ def override( | |||||
model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, | ||||||
toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, | ||||||
tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] | _utils.Unset = _utils.UNSET, | ||||||
instructions: Instructions | _utils.Unset = _utils.UNSET, | ||||||
) -> Iterator[None]: | ||||||
"""Context manager to temporarily override agent dependencies, model, toolsets, or tools. | ||||||
"""Context manager to temporarily override agent dependencies, model, toolsets, tools, and instructions. | ||||||
|
||||||
This is particularly useful when testing. | ||||||
You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). | ||||||
|
@@ -731,6 +750,7 @@ def override( | |||||
model: The model to use instead of the model passed to the agent run. | ||||||
toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. | ||||||
tools: The tools to use instead of the tools registered with the agent. | ||||||
instructions: The instructions to use instead of the instructions registered with the agent. | ||||||
""" | ||||||
if _utils.is_set(deps): | ||||||
deps_token = self._override_deps.set(_utils.Some(deps)) | ||||||
|
@@ -752,6 +772,11 @@ def override( | |||||
else: | ||||||
tools_token = None | ||||||
|
||||||
if _utils.is_set(instructions): | ||||||
instructions_token = self._override_instructions.set(_utils.Some(instructions)) | ||||||
else: | ||||||
instructions_token = None | ||||||
|
||||||
try: | ||||||
yield | ||||||
finally: | ||||||
|
@@ -763,6 +788,8 @@ def override( | |||||
self._override_toolsets.reset(toolsets_token) | ||||||
if tools_token is not None: | ||||||
self._override_tools.reset(tools_token) | ||||||
if instructions_token is not None: | ||||||
self._override_instructions.reset(instructions_token) | ||||||
|
||||||
@overload | ||||||
def instructions( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove this file, it should've gotten cleaned up automatically 🤔