Skip to content

Commit

Permalink
Merge pull request #39 from jlowin/decorator
Browse files Browse the repository at this point in the history
Remove flow task; add threadmessage construct
  • Loading branch information
jlowin authored May 16, 2024
2 parents bdefd5f + 09ec56c commit fb7f379
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 126 deletions.
4 changes: 2 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .settings import settings

from .core.flow import Flow, reset_global_flow as _reset_global_flow
from .core.flow import Flow, reset_global_flow
from .core.task import Task
from .core.agent import Agent
from .core.controller.controller import Controller
Expand All @@ -11,4 +11,4 @@
Task.model_rebuild()
Agent.model_rebuild()

_reset_global_flow()
reset_global_flow()
32 changes: 27 additions & 5 deletions src/controlflow/core/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import math
from typing import Any, Union

import marvin.utilities
Expand All @@ -13,6 +14,7 @@
from prefect.context import FlowRunContext
from pydantic import BaseModel, Field, field_validator, model_validator

import controlflow
from controlflow.core.agent import Agent
from controlflow.core.controller.moderators import marvin_moderator
from controlflow.core.flow import Flow, get_flow, get_flow_messages
Expand All @@ -24,6 +26,7 @@
create_python_artifact,
wrap_prefect_tool,
)
from controlflow.utilities.tasks import any_incomplete
from controlflow.utilities.types import FunctionTool, Thread

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -60,6 +63,7 @@ class Controller(BaseModel, ExposeSyncMethodsMixin):
context: dict = {}
graph: Graph = None
model_config: dict = dict(extra="forbid")
_iteration: int = 0

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -171,14 +175,12 @@ def choose_agent(
self,
agents: list[Agent],
tasks: list[Task],
history: list = None,
instructions: list[str] = None,
iterations: int = 0,
) -> Agent:
return marvin_moderator(
agents=agents,
tasks=tasks,
history=history,
instructions=instructions,
iteration=self._iteration,
)

@expose_sync_method("run_once")
Expand All @@ -189,6 +191,9 @@ async def run_once_async(self):
# get the tasks to run
tasks = self.graph.upstream_dependencies(self.tasks, include_tasks=True)

if all(t.is_complete() for t in tasks):
return

# get the agents
agent_candidates = {a for t in tasks for a in t.agents if t.is_ready()}
if self.agents:
Expand All @@ -211,7 +216,24 @@ async def run_once_async(self):
instructions=get_instructions(),
)

return await self._run_agent(agent, tasks=tasks)
await self._run_agent(agent, tasks=tasks)
self._iteration += 1

@expose_sync_method("run")
async def run_async(self):
"""
Run the controller until all tasks are complete.
"""
max_task_iterations = controlflow.settings.max_task_iterations or math.inf
start_iteration = self._iteration
while any_incomplete(self.tasks):
await self.run_once_async()
if self._iteration > start_iteration + max_task_iterations * len(
self.tasks
):
raise ValueError(
f"Task iterations exceeded maximum of {max_task_iterations} for each task."
)


class AgentHandler(PrintHandler):
Expand Down
26 changes: 18 additions & 8 deletions src/controlflow/core/controller/instruction_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,18 +98,28 @@ class TasksTemplate(Template):
You can only mark a task successful when all of its dependencies and
subtasks have been completed. Subtasks may be marked as skipped without
providing a result. All else equal, prioritize older tasks over newer ones.
providing a result. All else equal, prioritize older tasks over newer
ones.
### Providing a result
Tasks may optionally request a typed result. Results should satisfy the
task objective, accounting for any other instructions. If a task does
not require a result, you must still complete the objective by posting
messages or using other tools before marking the task as complete.
Tasks may require a typed result (the `result_type`). Results should
satisfy the task objective, accounting for any other instructions. If a
task does not require a result (`result_type=None`), you must still
complete its stated objective by posting messages or using other tools
before marking the task as complete.
Try not to write out long results in messages that other agents can
read, and then repeating the entire result when marking the task as
complete. Other agents can see your task results when it is their turn.
#### Re-using a message
You can reuse the contents of any message as a task's result by
providing a special `ThreadMessage` object when marking a task
successful. Only do this if the thread message can be converted into the
task's result_type. Indicate the number of messages ago that the message
was posted (defaults to 1). Also provide any characters to strip from the
start or end of the message, to make sure that the result doesn't reveal
any internal details (for example, always remove your name prefix and
irrelevant comments from the beginning or end of the response such as
"I'll mark the task complete now.").
"""
tasks: list[Task]
Expand Down
68 changes: 12 additions & 56 deletions src/controlflow/core/controller/moderators.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,29 @@
import itertools
from typing import TYPE_CHECKING, Any, Generator

import marvin
from pydantic import BaseModel, Field

from controlflow.core.agent import Agent
from controlflow.core.flow import Flow, get_flow_messages
from controlflow.core.flow import get_flow_messages
from controlflow.core.task import Task
from controlflow.instructions import get_instructions

if TYPE_CHECKING:
from controlflow.core.agent import Agent


def round_robin(agents: list[Agent], tasks: list[Task]) -> Generator[Any, Any, Agent]:
"""
Given a list of potential agents, delegate the tasks in a round-robin fashion.
"""
cycle = itertools.cycle(agents)
while True:
yield next(cycle)


class BaseModerator(BaseModel):
def __call__(
self, agents: list[Agent], tasks: list[Task]
) -> Generator[Any, Any, Agent]:
yield from self.run(agents=agents, tasks=tasks)


class AgentModerator(BaseModerator):
agent: Agent
participate: bool = Field(
False,
description="If True, the moderator can participate in the conversation. Default is False.",
)

def __init__(self, agent: Agent, **kwargs):
super().__init__(agent=agent, **kwargs)

def run(self, agents: list[Agent], tasks: list[Task]) -> Generator[Any, Any, Agent]:
while True:
history = get_flow_messages()

with Flow():
task = Task(
"Choose the next agent that should speak.",
instructions="""
You are acting as a moderator. Choose the next agent to
speak. Complete the task and stay silent. Do not post
any messages, even to confirm marking the task as
successful.
""",
result_type=[a.name for a in agents],
context=dict(agents=agents, history=history, tasks=tasks),
agents=[self.agent],
parent=None,
)
agent_name = task.run()
yield next(a for a in agents if a.name == agent_name)
def round_robin(
agents: list[Agent],
tasks: list[Task],
context: dict = None,
iteration: int = 0,
) -> Agent:
return agents[iteration % len(agents)]


def marvin_moderator(
agents: list[Agent],
tasks: list[Task],
history: list = None,
instructions: list[str] = None,
context: dict = None,
iteration: int = 0,
model: str = None,
) -> Agent:
history = get_flow_messages()
instructions = get_instructions()
context = context or {}
context.update(tasks=tasks, history=history, instructions=instructions)
agent = marvin.classify(
Expand Down
5 changes: 1 addition & 4 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,13 @@ def __exit__(self, *exc_info):
def get_flow() -> Flow:
"""
Loads the flow from the context.
Will error if no flow is found in the context, unless the global flow is
enabled in settings
"""
flow: Union[Flow, None] = ctx.get("flow")
if not flow:
if controlflow.settings.enable_global_flow:
return GLOBAL_FLOW
else:
raise ValueError("No flow found in context.")
raise ValueError("No flow found in context and global flow is disabled.")
return flow


Expand Down
Loading

0 comments on commit fb7f379

Please sign in to comment.