diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index fb2de47d..daedb857 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -2,7 +2,7 @@ import logging import random import warnings -from contextlib import contextmanager +from contextlib import AbstractContextManager, contextmanager from typing import ( TYPE_CHECKING, Any, @@ -13,7 +13,14 @@ ) from langchain_core.language_models import BaseChatModel -from pydantic import Field, field_serializer, field_validator +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + field_serializer, + field_validator, + model_validator, +) from typing_extensions import Self import controlflow @@ -48,27 +55,29 @@ class Agent(ControlFlowModel, abc.ABC): Class for objects that can be used as agents in a flow """ - model_config = dict(arbitrary_types_allowed=True) + model_config = ConfigDict(arbitrary_types_allowed=True) - id: str = Field(None) + id: Optional[str] = Field(default=None) name: str = Field( default_factory=lambda: random.choice(AGENT_NAMES), description="The name of the agent.", ) description: Optional[str] = Field( - None, description="A description of the agent, visible to other agents." + default=None, description="A description of the agent, visible to other agents." ) instructions: Optional[str] = Field( - "You are a diligent AI assistant. You complete your tasks efficiently and without error.", + default="You are a diligent AI assistant. You complete your tasks efficiently and without error.", description="Instructions for the agent, private to this agent.", ) prompt: Optional[str] = Field( - None, + default=None, description="A system template for the agent. The template should be formatted as a jinja2 template.", ) - tools: list[Tool] = Field([], description="List of tools available to the agent.") + tools: list[Tool] = Field( + default=[], description="List of tools available to the agent." + ) interactive: bool = Field( - False, + default=False, description="If True, the agent is given tools for interacting with a human user.", ) memories: list[Memory] = Field( @@ -77,39 +86,48 @@ class Agent(ControlFlowModel, abc.ABC): ) model: Optional[Union[str, BaseChatModel]] = Field( - None, + default=None, description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used. A compatible string can be passed to automatically retrieve the model.", exclude=True, ) llm_rules: Optional[LLMRules] = Field( - None, + default=None, description="The LLM rules for the agent. If not provided, the rules will be inferred from the model (if possible).", ) - _cm_stack: list[contextmanager] = [] + _cm_stack: list[AbstractContextManager] = PrivateAttr(default_factory=list) - def __init__(self, instructions: str = None, **kwargs): - if instructions is not None: - kwargs["instructions"] = instructions + @model_validator(mode="before") + def handle_positional_arg(cls, v): + if isinstance(v, str): + return {"instructions": v} + return v + @model_validator(mode="before") + @classmethod + def emit_deprecation_warnings(cls, v): # deprecated in 0.9 - if "user_access" in kwargs: + if "user_access" in v: warnings.warn( "The `user_access` argument is deprecated. Use `interactive=True` instead.", DeprecationWarning, ) - kwargs["interactive"] = kwargs.pop("user_access") + v["interactive"] = v.pop("user_access") + return v + @model_validator(mode="after") + def add_instructions(self) -> Self: if additional_instructions := get_instructions(): - kwargs["instructions"] = ( - kwargs.get("instructions") - or "" + "\n" + "\n".join(additional_instructions) + self.instructions = ( + self.instructions or "" + "\n" + "\n".join(additional_instructions) ).strip() + return self - super().__init__(**kwargs) - + @model_validator(mode="after") + def generate_id(self) -> Self: if not self.id: self.id = self._generate_id() + return self def __hash__(self) -> int: return id(self) @@ -157,7 +175,7 @@ def serialize_for_prompt(self) -> dict: dct.pop("interactive") return dct - def get_model(self, tools: list["Tool"] = None) -> BaseChatModel: + def get_model(self, tools: Optional[list["Tool"]] = None) -> BaseChatModel: """ Retrieve the LLM model for this agent """ @@ -212,8 +230,8 @@ def run( self, objective: str, *, - turn_strategy: "TurnStrategy" = None, - handlers: list["Handler"] = None, + turn_strategy: Optional["TurnStrategy"] = None, + handlers: Optional[list["Handler"]] = None, **task_kwargs, ): return controlflow.run( @@ -228,8 +246,8 @@ async def run_async( self, objective: str, *, - turn_strategy: "TurnStrategy" = None, - handlers: list["Handler"] = None, + turn_strategy: Optional["TurnStrategy"] = None, + handlers: Optional[list["Handler"]] = None, **task_kwargs, ): return await controlflow.run_async( diff --git a/tests/agents/test_agents.py b/tests/agents/test_agents.py index cc420886..5d0dfeef 100644 --- a/tests/agents/test_agents.py +++ b/tests/agents/test_agents.py @@ -55,7 +55,7 @@ def test_agent_loads_instructions_at_creation(self): with instructions("test instruction"): agent = Agent() - assert "test instruction" in agent.instructions + assert agent.instructions and "test instruction" in agent.instructions @pytest.mark.skip(reason="IDs are not stable right now") def test_stable_id(self):