Add decorator tests for flow #38

Merged: 1 commit, May 15, 2024
3,393 changes: 0 additions & 3,393 deletions all_files.md

This file was deleted.

16 changes: 15 additions & 1 deletion src/controlflow/core/controller/controller.py
@@ -167,6 +167,20 @@ async def _run_agent(
controller=self, agent=agent, tasks=tasks, thread=thread
)

    def choose_agent(
        self,
        agents: list[Agent],
        tasks: list[Task],
        history: list = None,
        instructions: list[str] = None,
    ) -> Agent:
        return marvin_moderator(
            agents=agents,
            tasks=tasks,
            history=history,
            instructions=instructions,
        )

@expose_sync_method("run_once")
async def run_once_async(self):
"""
@@ -190,7 +204,7 @@ async def run_once_async(self):
elif len(agents) == 1:
agent = agents[0]
else:
agent = marvin_moderator(
agent = self.choose_agent(
agents=agents,
tasks=tasks,
history=get_flow_messages(),
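The net effect of this hunk is that agent selection now goes through a Controller.choose_agent method instead of calling marvin_moderator inline, which gives tests a single seam to override. As a rough sketch of how that seam can be used (the fixture name here is hypothetical; tests/fixtures/mocks.py below does the same thing with a Mock so calls can also be asserted):

# Sketch only: force deterministic agent selection by patching the new seam.
import pytest


@pytest.fixture
def always_pick_first_agent(monkeypatch):
    # whenever the controller asks which agent should act, return the first one
    monkeypatch.setattr(
        "controlflow.core.controller.controller.Controller.choose_agent",
        lambda self, agents, **kwargs: agents[0],
    )
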
21 changes: 12 additions & 9 deletions src/controlflow/decorators.py
@@ -82,23 +82,19 @@ def wrapper(
if agents is not None:
flow_kwargs.setdefault("agents", agents)

p_fn = prefect.flow(fn)

flow_obj = Flow(**flow_kwargs, context=bound.arguments)

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with flow_obj, patch_marvin():
# create a function to wrap as a Prefect flow
@prefect.flow
def wrapped_flow(*args, **kwargs):
with Task(
fn.__name__,
instructions="Complete all subtasks of this task.",
is_auto_completed_by_subtasks=True,
context=bound.arguments,
) as parent_task:
with controlflow.instructions(instructions):
result = p_fn(*args, **kwargs)
result = fn(*args, **kwargs)

# ensure all subtasks are completed
parent_task.run()
@@ -107,7 +103,14 @@ def wrapper(
# resolve any returned tasks; this will raise on failure
result = resolve_tasks(result)

return result
return result

logger.info(
f'Executing AI flow "{fn.__name__}" on thread "{flow_obj.thread.id}"'
)

with flow_obj, patch_marvin():
return wrapped_flow(*args, **kwargs)

return wrapper

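After this change the user's function is no longer pre-wrapped with prefect.flow; instead a wrapped_flow Prefect flow is built around the parent Task and instructions context and invoked inside the ControlFlow Flow context, so subtask completion (parent_task.run()) and resolve_tasks happen within the Prefect flow run. A minimal usage sketch, mirroring the tests added below (the task description and result are illustrative, and resolving to a plain value assumes the mocked controller those tests use):

# Illustrative sketch of the decorator's behavior; values are hypothetical.
from controlflow import Task
from controlflow.decorators import flow


@flow
def greet():
    # a Task returned from the flow body is resolved before the flow returns
    return Task("say hello", result="hello")


# under the mocked controller used in the tests below, this resolves to "hello"
assert greet() == "hello"
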
12 changes: 11 additions & 1 deletion tests/conftest.py
@@ -1,10 +1,20 @@
import pytest
from controlflow.settings import temporary_settings
from prefect.testing.utilities import prefect_test_harness

from .fixtures import *


@pytest.fixture(autouse=True, scope="session")
def temp_settings():
def temp_controlflow_settings():
    with temporary_settings(enable_global_flow=False, max_task_iterations=3):
        yield


@pytest.fixture(autouse=True, scope="session")
def prefect_test_fixture():
"""
Run Prefect against temporary sqlite database
"""
with prefect_test_harness():
yield
66 changes: 63 additions & 3 deletions tests/fixtures/mocks.py
@@ -1,7 +1,11 @@
from typing import Any
from unittest.mock import AsyncMock, Mock, patch

import pytest
from controlflow.core.agent import Agent
from controlflow.core.task import Task, TaskStatus
from controlflow.utilities.user_access import talk_to_human
from marvin.settings import temporary_settings as temporary_marvin_settings

# @pytest.fixture(autouse=True)
# def mock_talk_to_human():
@@ -20,10 +24,18 @@


@pytest.fixture
def mock_run(monkeypatch):
def prevent_openai_calls():
"""Prevent any calls to the OpenAI API from being made."""
with temporary_marvin_settings(openai__api_key="unset"):
yield


@pytest.fixture
def mock_run(monkeypatch, prevent_openai_calls):
"""
This fixture mocks the calls to OpenAI. Use it in a test and assign any desired side effects (like completing a task)
to the mock object's `.side_effect` attribute.
This fixture mocks the calls to the OpenAI Assistants API. Use it in a test
and assign any desired side effects (like completing a task) to the mock
object's `.side_effect` attribute.

For example:

Expand All @@ -41,3 +53,51 @@ def side_effect():
MockRun = AsyncMock()
monkeypatch.setattr("controlflow.core.controller.controller.Run.run_async", MockRun)
yield MockRun


@pytest.fixture
def mock_controller_run_agent(monkeypatch, prevent_openai_calls):
    MockRunAgent = AsyncMock()
    MockThreadGetMessages = Mock()

    async def _run_agent(agent: Agent, tasks: list[Task] = None, thread=None):
        for task in tasks:
            if agent in task.agents:
                # we can't call mark_successful because we don't know the result
                task.status = TaskStatus.SUCCESSFUL

    MockRunAgent.side_effect = _run_agent

    def get_messages(*args, **kwargs):
        return []

    MockThreadGetMessages.side_effect = get_messages

    monkeypatch.setattr(
        "controlflow.core.controller.controller.Controller._run_agent", MockRunAgent
    )
    monkeypatch.setattr(
        "marvin.beta.assistants.Thread.get_messages", MockThreadGetMessages
    )
    yield MockRunAgent


@pytest.fixture
def mock_controller_choose_agent(monkeypatch):
    MockChooseAgent = Mock()

    def choose_agent(agents, **kwargs):
        return agents[0]

    MockChooseAgent.side_effect = choose_agent

    monkeypatch.setattr(
        "controlflow.core.controller.controller.Controller.choose_agent",
        MockChooseAgent,
    )
    yield MockChooseAgent


@pytest.fixture
def mock_controller(mock_controller_choose_agent, mock_controller_run_agent):
    pass
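Taken together, mock_controller composes the two controller mocks (and, through mock_controller_run_agent, the prevent_openai_calls guard), so a decorated flow can run end to end in tests: the first eligible agent is always selected and every task assigned to it is marked successful. The collapsed portion of the mock_run docstring contains the original inline example, which is not reproduced here; as a purely hypothetical sketch of the pattern it describes (the task and side effect are illustrative, not the collapsed original):

# Hypothetical usage sketch for the mock_run fixture; not the collapsed example.
from controlflow.core.task import Task, TaskStatus


def test_mocked_openai_run(mock_run):
    task = Task("say hello")

    def pretend_agent_finished():
        # emulate what a real OpenAI-backed run would have accomplished
        task.status = TaskStatus.SUCCESSFUL

    mock_run.side_effect = pretend_agent_finished
    # ...run the task/flow; each mocked Run.run_async call fires the side effect
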
54 changes: 54 additions & 0 deletions tests/test_decorators.py
@@ -0,0 +1,54 @@
import pytest
from controlflow import Task
from controlflow.decorators import flow


@pytest.mark.usefixtures("mock_controller")
class TestFlowDecorator:
    def test_flow_decorator(self):
        @flow
        def test_flow():
            return 1

        result = test_flow()
        assert result == 1

    def test_flow_decorator_runs_all_tasks(self):
        tasks: list[Task] = []

        @flow
        def test_flow():
            task = Task(
                "say hello",
                result_type=str,
                result="Task completed successfully",
            )
            tasks.append(task)

        result = test_flow()
        assert result is None
        assert tasks[0].is_successful()
        assert tasks[0].result == "Task completed successfully"

    def test_flow_decorator_resolves_all_tasks(self):
        @flow
        def test_flow():
            task1 = Task("say hello", result="hello")
            task2 = Task("say goodbye", result="goodbye")
            task3 = Task("say goodnight", result="goodnight")
            return dict(a=task1, b=[task2], c=dict(x=dict(y=[[task3]])))

        result = test_flow()
        assert result == dict(
            a="hello", b=["goodbye"], c=dict(x=dict(y=[["goodnight"]]))
        )

    def test_manually_run_task_in_flow(self):
        @flow
        def test_flow():
            task = Task("say hello", result="hello")
            task.run()
            return task.result

        result = test_flow()
        assert result == "hello"