Skip to content

Commit

Permalink
Merge pull request #36 from jlowin/flow
Browse files Browse the repository at this point in the history
Automatically resolve tasks within flows
  • Loading branch information
jlowin authored May 15, 2024
2 parents edd84fb + d85b6a2 commit 06f7ac5
Show file tree
Hide file tree
Showing 6 changed files with 345 additions and 100 deletions.
28 changes: 28 additions & 0 deletions examples/task_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from controlflow import Task, flow


@flow
def book_ideas():
genre = Task("pick a genre", str)
ideas = Task(
"generate three short ideas for a book",
list[str],
context=dict(genre=genre),
)
abstract = Task(
"pick one idea and write an abstract",
str,
context=dict(ideas=ideas, genre=genre),
)
title = Task(
"pick a title",
str,
context=dict(abstract=abstract),
)

return dict(genre=genre, ideas=ideas, abstract=abstract, title=title)


if __name__ == "__main__":
result = book_ideas()
print(result)
5 changes: 3 additions & 2 deletions src/controlflow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .settings import settings

from .core.flow import Flow, reset_global_flow as _reset_global_flow, flow
from .core.task import Task, task
from .core.flow import Flow, reset_global_flow as _reset_global_flow
from .core.task import Task
from .core.agent import Agent
from .core.controller.controller import Controller
from .instructions import instructions
from .decorators import flow, task

Flow.model_rebuild()
Task.model_rebuild()
Expand Down
60 changes: 0 additions & 60 deletions src/controlflow/core/flow.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
import functools
import inspect
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Union

import prefect
from marvin.beta.assistants import Thread
from openai.types.beta.threads import Message
from pydantic import Field, field_validator

import controlflow
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.marvin import patch_marvin
from controlflow.utilities.types import ControlFlowModel, ToolType

if TYPE_CHECKING:
Expand Down Expand Up @@ -95,59 +91,3 @@ def get_flow_messages(limit: int = None) -> list[Message]:
"""
flow = get_flow()
return flow.thread.get_messages(limit=limit)


def flow(
fn=None,
*,
thread: Thread = None,
instructions: str = None,
tools: list[ToolType] = None,
agents: list["Agent"] = None,
):
"""
A decorator that runs a function as a Flow
"""

if fn is None:
return functools.partial(
flow,
thread=thread,
tools=tools,
agents=agents,
)

sig = inspect.signature(fn)

@functools.wraps(fn)
def wrapper(
*args,
flow_kwargs: dict = None,
**kwargs,
):
# first process callargs
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()

flow_kwargs = flow_kwargs or {}

if thread is not None:
flow_kwargs.setdefault("thread", thread)
if tools is not None:
flow_kwargs.setdefault("tools", tools)
if agents is not None:
flow_kwargs.setdefault("agents", agents)

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with ctx(flow=flow_obj), patch_marvin():
with controlflow.instructions(instructions):
return p_fn(*args, **kwargs)

return wrapper
79 changes: 41 additions & 38 deletions src/controlflow/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
GenericAlias,
Literal,
TypeVar,
Expand All @@ -31,6 +29,7 @@
from controlflow.utilities.context import ctx
from controlflow.utilities.logging import get_logger
from controlflow.utilities.prefect import wrap_prefect_tool
from controlflow.utilities.tasks import collect_tasks, visit_task_collection
from controlflow.utilities.types import (
NOTSET,
AssistantTool,
Expand All @@ -53,32 +52,6 @@ class TaskStatus(Enum):
SKIPPED = "skipped"


def visit_task_collection(
val: Any, fn: Callable, recursion_limit: int = 3, _counter: int = 0
) -> list["Task"]:
if _counter >= recursion_limit:
return val

if isinstance(val, dict):
result = {}
for key, value in list(val.items()):
result[key] = visit_task_collection(
value, fn=fn, recursion_limit=recursion_limit, _counter=_counter + 1
)
elif isinstance(val, (list, set, tuple)):
result = []
for item in val:
result.append(
visit_task_collection(
item, fn=fn, recursion_limit=recursion_limit, _counter=_counter + 1
)
)
elif isinstance(val, Task):
return fn(val)

return val


class Task(ControlFlowModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4().hex[:5]))
objective: str = Field(
Expand Down Expand Up @@ -109,6 +82,7 @@ class Task(ControlFlowModel):
error: Union[str, None] = None
tools: list[ToolType] = []
user_access: bool = False
is_auto_completed_by_subtasks: bool = False
created_at: datetime.datetime = Field(default_factory=datetime.datetime.now)
_parent: "Union[Task, None]" = None
_downstreams: list["Task"] = []
Expand Down Expand Up @@ -191,16 +165,16 @@ def _turn_list_into_literal_result_type(cls, v):

@model_validator(mode="after")
def _finalize(self):
# create dependencies to tasks passed in as context
tasks = []

def visitor(task):
tasks.append(task)
return task
# validate correlated settings
if self.result_type is not None and self.is_auto_completed_by_subtasks:
raise ValueError(
"Tasks with a result type cannot be auto-completed by their subtasks."
)

visit_task_collection(self.context, visitor)
# create dependencies to tasks passed in as context
context_tasks = collect_tasks(self.context)

for task in tasks:
for task in context_tasks:
if task not in self.depends_on:
self.depends_on.append(task)
return self
Expand Down Expand Up @@ -283,7 +257,7 @@ def run_once(self, agent: "Agent" = None):

controller.run_once()

def run(self, max_iterations: int = NOTSET) -> T:
def run(self, raise_on_error: bool = True, max_iterations: int = NOTSET) -> T:
"""
Runs the task with provided agents until it is complete.
Expand All @@ -304,7 +278,7 @@ def run(self, max_iterations: int = NOTSET) -> T:
counter += 1
if self.is_successful():
return self.result
elif self.is_failed():
elif self.is_failed() and raise_on_error:
raise ValueError(f"{self.friendly_name()} failed: {self.error}")

@contextmanager
Expand Down Expand Up @@ -394,6 +368,9 @@ def get_tools(self) -> list[ToolType]:
tools.append(marvin.utilities.tools.tool_from_function(talk_to_human))
return [wrap_prefect_tool(t) for t in tools]

def dependencies(self):
return self.depends_on + self.subtasks

def mark_successful(self, result: T = None, validate: bool = True):
if validate:
if any(t.is_incomplete() for t in self.depends_on):
Expand All @@ -418,15 +395,41 @@ def mark_successful(self, result: T = None, validate: bool = True):

self.result = result
self.status = TaskStatus.SUCCESSFUL

# attempt to complete the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_successful(validate=True)

return f"{self.friendly_name()} marked successful. Updated task definition: {self.model_dump()}"

def mark_failed(self, message: Union[str, None] = None):
self.error = message
self.status = TaskStatus.FAILED

# attempt to fail the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_failed()

return f"{self.friendly_name()} marked failed. Updated task definition: {self.model_dump()}"

def mark_skipped(self):
self.status = TaskStatus.SKIPPED
# attempt to complete the parent, if appropriate
if (
self._parent
and self._parent.is_auto_completed_by_subtasks
and all_complete(self._parent.dependencies())
):
self._parent.mark_successful(validate=False)

return f"{self.friendly_name()} marked skipped. Updated task definition: {self.model_dump()}"


Expand Down
Loading

0 comments on commit 06f7ac5

Please sign in to comment.