Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Oct 26, 2024
1 parent f5b098d commit 3297a07
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 28 deletions.
72 changes: 45 additions & 27 deletions src/controlflow/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import random
import warnings
from contextlib import contextmanager
from contextlib import AbstractContextManager, contextmanager
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -13,7 +13,14 @@
)

from langchain_core.language_models import BaseChatModel
from pydantic import Field, field_serializer, field_validator
from pydantic import (
ConfigDict,
Field,
PrivateAttr,
field_serializer,
field_validator,
model_validator,
)
from typing_extensions import Self

import controlflow
Expand Down Expand Up @@ -48,27 +55,29 @@ class Agent(ControlFlowModel, abc.ABC):
Class for objects that can be used as agents in a flow
"""

model_config = dict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True)

id: str = Field(None)
id: Optional[str] = Field(default=None)
name: str = Field(
default_factory=lambda: random.choice(AGENT_NAMES),
description="The name of the agent.",
)
description: Optional[str] = Field(
None, description="A description of the agent, visible to other agents."
default=None, description="A description of the agent, visible to other agents."
)
instructions: Optional[str] = Field(
"You are a diligent AI assistant. You complete your tasks efficiently and without error.",
default="You are a diligent AI assistant. You complete your tasks efficiently and without error.",
description="Instructions for the agent, private to this agent.",
)
prompt: Optional[str] = Field(
None,
default=None,
description="A system template for the agent. The template should be formatted as a jinja2 template.",
)
tools: list[Tool] = Field([], description="List of tools available to the agent.")
tools: list[Tool] = Field(
default=[], description="List of tools available to the agent."
)
interactive: bool = Field(
False,
default=False,
description="If True, the agent is given tools for interacting with a human user.",
)
memories: list[Memory] = Field(
Expand All @@ -77,39 +86,48 @@ class Agent(ControlFlowModel, abc.ABC):
)

model: Optional[Union[str, BaseChatModel]] = Field(
None,
default=None,
description="The LangChain BaseChatModel used by the agent. If not provided, the default model will be used. A compatible string can be passed to automatically retrieve the model.",
exclude=True,
)
llm_rules: Optional[LLMRules] = Field(
None,
default=None,
description="The LLM rules for the agent. If not provided, the rules will be inferred from the model (if possible).",
)

_cm_stack: list[contextmanager] = []
_cm_stack: list[AbstractContextManager] = PrivateAttr(default_factory=list)

def __init__(self, instructions: str = None, **kwargs):
if instructions is not None:
kwargs["instructions"] = instructions
@model_validator(mode="before")
def handle_positional_arg(cls, v):
if isinstance(v, str):
return {"instructions": v}
return v

@model_validator(mode="before")
@classmethod
def emit_deprecation_warnings(cls, v):
# deprecated in 0.9
if "user_access" in kwargs:
if "user_access" in v:
warnings.warn(
"The `user_access` argument is deprecated. Use `interactive=True` instead.",
DeprecationWarning,
)
kwargs["interactive"] = kwargs.pop("user_access")
v["interactive"] = v.pop("user_access")
return v

@model_validator(mode="after")
def add_instructions(self) -> Self:
if additional_instructions := get_instructions():
kwargs["instructions"] = (
kwargs.get("instructions")
or "" + "\n" + "\n".join(additional_instructions)
self.instructions = (
self.instructions or "" + "\n" + "\n".join(additional_instructions)
).strip()
return self

super().__init__(**kwargs)

@model_validator(mode="after")
def generate_id(self) -> Self:
if not self.id:
self.id = self._generate_id()
return self

def __hash__(self) -> int:
return id(self)
Expand Down Expand Up @@ -157,7 +175,7 @@ def serialize_for_prompt(self) -> dict:
dct.pop("interactive")
return dct

def get_model(self, tools: list["Tool"] = None) -> BaseChatModel:
def get_model(self, tools: Optional[list["Tool"]] = None) -> BaseChatModel:
"""
Retrieve the LLM model for this agent
"""
Expand Down Expand Up @@ -212,8 +230,8 @@ def run(
self,
objective: str,
*,
turn_strategy: "TurnStrategy" = None,
handlers: list["Handler"] = None,
turn_strategy: Optional["TurnStrategy"] = None,
handlers: Optional[list["Handler"]] = None,
**task_kwargs,
):
return controlflow.run(
Expand All @@ -228,8 +246,8 @@ async def run_async(
self,
objective: str,
*,
turn_strategy: "TurnStrategy" = None,
handlers: list["Handler"] = None,
turn_strategy: Optional["TurnStrategy"] = None,
handlers: Optional[list["Handler"]] = None,
**task_kwargs,
):
return await controlflow.run_async(
Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_agent_loads_instructions_at_creation(self):
with instructions("test instruction"):
agent = Agent()

assert "test instruction" in agent.instructions
assert agent.instructions and "test instruction" in agent.instructions

@pytest.mark.skip(reason="IDs are not stable right now")
def test_stable_id(self):
Expand Down

0 comments on commit 3297a07

Please sign in to comment.