diff --git a/src/controlflow/agents/agent.py b/src/controlflow/agents/agent.py index e14ee063..6c4381a9 100644 --- a/src/controlflow/agents/agent.py +++ b/src/controlflow/agents/agent.py @@ -291,8 +291,7 @@ def _run_model( from controlflow.events.events import ( AgentMessage, AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, + ToolResult, ) tools = as_tools(self.get_tools() + tools) @@ -312,12 +311,17 @@ def _run_model( else: response += delta - yield AgentMessageDelta(agent=self, delta=delta, snapshot=response) + yield from AgentMessageDelta( + agent=self, message_delta=delta, message_snapshot=response + ).all_related_events(tools=tools) else: response: AIMessage = model.invoke(messages) - yield AgentMessage(agent=self, message=response) + yield from AgentMessage(agent=self, message=response).all_related_events( + tools=tools + ) + create_markdown_artifact( markdown=f""" {response.content or '(No content)'} @@ -335,9 +339,8 @@ def _run_model( logger.debug(f"Response: {response}") for tool_call in response.tool_calls + response.invalid_tool_calls: - yield ToolCallEvent(agent=self, tool_call=tool_call) result = handle_tool_call(tool_call, tools=tools) - yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_result=result) @prefect_task(task_run_name="Call LLM") async def _run_model_async( @@ -350,8 +353,7 @@ async def _run_model_async( from controlflow.events.events import ( AgentMessage, AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, + ToolResult, ) tools = as_tools(self.get_tools() + tools) @@ -371,12 +373,18 @@ async def _run_model_async( else: response += delta - yield AgentMessageDelta(agent=self, delta=delta, snapshot=response) + for event in AgentMessageDelta( + agent=self, message_delta=delta, message_snapshot=response + ).all_related_events(tools=tools): + yield event else: response: AIMessage = await model.ainvoke(messages) - yield AgentMessage(agent=self, message=response) + for event in AgentMessage(agent=self, message=response).all_related_events( + tools=tools + ): + yield event create_markdown_artifact( markdown=f""" @@ -395,6 +403,5 @@ async def _run_model_async( logger.debug(f"Response: {response}") for tool_call in response.tool_calls + response.invalid_tool_calls: - yield ToolCallEvent(agent=self, tool_call=tool_call) result = await handle_tool_call_async(tool_call, tools=tools) - yield ToolResultEvent(agent=self, tool_call=tool_call, tool_result=result) + yield ToolResult(agent=self, tool_result=result) diff --git a/src/controlflow/events/base.py b/src/controlflow/events/base.py index 1ae915d6..aad788cb 100644 --- a/src/controlflow/events/base.py +++ b/src/controlflow/events/base.py @@ -30,7 +30,7 @@ def to_messages(self, context: "CompileContext") -> list["BaseMessage"]: return [] def __repr__(self) -> str: - return f"{self.event} ({self.timestamp})" + return f"<{self.event} {self.timestamp}>" class UnpersistedEvent(Event): diff --git a/src/controlflow/events/events.py b/src/controlflow/events/events.py index 6e5c6d17..b7c9f413 100644 --- a/src/controlflow/events/events.py +++ b/src/controlflow/events/events.py @@ -1,5 +1,6 @@ from typing import TYPE_CHECKING, Literal, Optional, Union +import pydantic_core from pydantic import ConfigDict, field_validator, model_validator from controlflow.agents.agent import Agent @@ -11,7 +12,8 @@ HumanMessage, ToolMessage, ) -from controlflow.tools.tools import InvalidToolCall, ToolCall, ToolResult +from controlflow.tools.tools import InvalidToolCall, Tool, ToolCall +from controlflow.tools.tools import ToolResult as ToolResultPayload from controlflow.utilities.logging import get_logger if TYPE_CHECKING: @@ -55,7 +57,7 @@ class AgentMessage(Event): message: dict @field_validator("message", mode="before") - def _message(cls, v): + def _as_message_dict(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "ai" @@ -70,6 +72,34 @@ def _finalize(self): def ai_message(self) -> AIMessage: return AIMessage(**self.message) + def to_tool_calls(self, tools: list[Tool]) -> list["AgentToolCall"]: + calls = [] + for tool_call in ( + self.message["tool_calls"] + self.message["invalid_tool_calls"] + ): + tool = next((t for t in tools if t.name == tool_call.get("name")), None) + if tool: + calls.append( + AgentToolCall( + agent=self.agent, + tool_call=tool_call, + tool=tool, + args=tool_call["args"], + agent_message_id=self.message.get("id"), + ) + ) + return calls + + def to_content(self) -> "AgentContent": + return AgentContent( + agent=self.agent, + content=self.message["content"], + agent_message_id=self.message.get("id"), + ) + + def all_related_events(self, tools: list[Tool]) -> list[Event]: + return [self, self.to_content()] + self.to_tool_calls(tools) + def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: return [self.ai_message] @@ -87,11 +117,11 @@ class AgentMessageDelta(UnpersistedEvent): event: Literal["agent-message-delta"] = "agent-message-delta" agent: Agent - delta: dict - snapshot: dict + message_delta: dict + message_snapshot: dict - @field_validator("delta", "snapshot", mode="before") - def _message(cls, v): + @field_validator("message_delta", "message_snapshot", mode="before") + def _as_message_dict(cls, v): if isinstance(v, BaseMessage): v = v.model_dump() v["type"] = "AIMessageChunk" @@ -99,50 +129,125 @@ def _message(cls, v): @model_validator(mode="after") def _finalize(self): - self.delta["name"] = self.agent.name - self.snapshot["name"] = self.agent.name + self.message_delta["name"] = self.agent.name + self.message_snapshot["name"] = self.agent.name return self - @property - def delta_message(self) -> AIMessageChunk: - return AIMessageChunk(**self.delta) + def to_tool_call_deltas(self, tools: list[Tool]) -> list["AgentToolCallDelta"]: + deltas = [] + for call_delta in self.message_delta.get("tool_call_chunks", []): + # First match chunks by index because streaming chunks come in sequence (0,1,2...) + # and this index lets us correlate deltas to their snapshots during streaming + chunk_snapshot = next( + ( + c + for c in self.message_snapshot.get("tool_call_chunks", []) + if c.get("index", -1) == call_delta.get("index", -2) + ), + None, + ) + + if chunk_snapshot and chunk_snapshot.get("id"): + # Once we have the matching chunk, use its ID to find the full tool call + # The full tool calls contain properly parsed arguments (as Python dicts) + # while chunks just contain raw JSON strings + call_snapshot = next( + ( + c + for c in self.message_snapshot["tool_calls"] + if c.get("id") == chunk_snapshot["id"] + ), + None, + ) - @property - def snapshot_message(self) -> AIMessage: - return AIMessage(**self.snapshot | {"type": "ai"}) + if call_snapshot: + tool = next( + (t for t in tools if t.name == call_snapshot.get("name")), None + ) + # Use call_snapshot.args which is already parsed into a Python dict + # This avoids issues with pydantic's more limited JSON parser + deltas.append( + AgentToolCallDelta( + agent=self.agent, + tool_call_delta=call_delta, + tool_call_snapshot=call_snapshot, + tool=tool, + args=call_snapshot.get("args", {}), + agent_message_id=self.message_snapshot.get("id"), + ) + ) + return deltas + + def to_content_delta(self) -> "AgentContentDelta": + return AgentContentDelta( + agent=self.agent, + content_delta=self.message_delta["content"], + content_snapshot=self.message_snapshot["content"], + agent_message_id=self.message_snapshot.get("id"), + ) + def all_related_events(self, tools: list[Tool]) -> list[Event]: + return [self, self.to_content_delta()] + self.to_tool_call_deltas(tools) -class EndTurn(Event): - event: Literal["end-turn"] = "end-turn" + +class AgentContent(UnpersistedEvent): + event: Literal["agent-content"] = "agent-content" agent: Agent - next_agent_name: Optional[str] = None + agent_message_id: Optional[str] = None + content: Union[str, list[Union[str, dict]]] -class ToolCallEvent(Event): +class AgentContentDelta(UnpersistedEvent): + event: Literal["agent-content-delta"] = "agent-content-delta" + agent: Agent + agent_message_id: Optional[str] = None + content_delta: Union[str, list[Union[str, dict]]] + content_snapshot: Union[str, list[Union[str, dict]]] + + +class AgentToolCall(Event): event: Literal["tool-call"] = "tool-call" agent: Agent + agent_message_id: Optional[str] = None tool_call: Union[ToolCall, InvalidToolCall] + tool: Optional[Tool] = None + args: dict = {} -class ToolResultEvent(Event): +class AgentToolCallDelta(UnpersistedEvent): + event: Literal["agent-tool-call-delta"] = "agent-tool-call-delta" + agent: Agent + agent_message_id: Optional[str] = None + tool_call_delta: dict + tool_call_snapshot: dict + tool: Optional[Tool] = None + args: dict = {} + + +class EndTurn(Event): + event: Literal["end-turn"] = "end-turn" + agent: Agent + next_agent_name: Optional[str] = None + + +class ToolResult(Event): event: Literal["tool-result"] = "tool-result" agent: Agent - tool_call: Union[ToolCall, InvalidToolCall] - tool_result: ToolResult + tool_result: ToolResultPayload def to_messages(self, context: "CompileContext") -> list[BaseMessage]: if self.agent.name == context.agent.name: return [ ToolMessage( content=self.tool_result.str_result, - tool_call_id=self.tool_call["id"], + tool_call_id=self.tool_result.tool_call["id"], name=self.agent.name, ) ] else: return OrchestratorMessage( prefix=f'Agent "{self.agent.name}" with ID {self.agent.id} made a tool ' - f'call: {self.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' + f'call: {self.tool_result.tool_call}. The tool{" failed and" if self.tool_result.is_error else " "} ' f'produced this result:', content=self.tool_result.str_result, name=self.agent.name, diff --git a/src/controlflow/events/history.py b/src/controlflow/events/history.py index e62cc660..154d5185 100644 --- a/src/controlflow/events/history.py +++ b/src/controlflow/events/history.py @@ -21,7 +21,7 @@ def get_event_validator() -> TypeAdapter: AgentMessage, EndTurn, OrchestratorMessage, - ToolResultEvent, + ToolResult, UserMessage, ) @@ -30,7 +30,7 @@ def get_event_validator() -> TypeAdapter: UserMessage, AgentMessage, EndTurn, - ToolResultEvent, + ToolResult, Event, ] return TypeAdapter(list[types]) diff --git a/src/controlflow/events/message_compiler.py b/src/controlflow/events/message_compiler.py index aff21195..430f026e 100644 --- a/src/controlflow/events/message_compiler.py +++ b/src/controlflow/events/message_compiler.py @@ -8,8 +8,8 @@ from controlflow.events.base import Event, UnpersistedEvent from controlflow.events.events import ( AgentMessage, - ToolCallEvent, - ToolResultEvent, + AgentToolCall, + ToolResult, ) from controlflow.llm.messages import ( AIMessage, @@ -28,8 +28,8 @@ class CombinedAgentMessage(UnpersistedEvent): event: Literal["combined-agent-message"] = "combined-agent-message" agent_message: AgentMessage - tool_call: list[ToolCallEvent] = [] - tool_results: list[ToolResultEvent] = [] + tool_call: list[AgentToolCall] = [] + tool_results: list[ToolResult] = [] def to_messages(self, context: "CompileContext") -> list[BaseMessage]: messages = [] @@ -213,9 +213,9 @@ def organize_events(self, context: CompileContext) -> list[Event]: event.ai_message.tool_calls + event.ai_message.invalid_tool_calls ): tool_calls[tc["id"]] = combined_event - elif isinstance(event, ToolResultEvent): + elif isinstance(event, ToolResult): combined_event: CombinedAgentMessage = tool_calls.get( - event.tool_call["id"] + event.tool_result.tool_call["id"] ) if combined_event: combined_event.tool_results.append(event) diff --git a/src/controlflow/events/orchestrator_events.py b/src/controlflow/events/orchestrator_events.py index 932fe8de..88370f74 100644 --- a/src/controlflow/events/orchestrator_events.py +++ b/src/controlflow/events/orchestrator_events.py @@ -1,41 +1,46 @@ from dataclasses import Field -from typing import Annotated, Literal +from typing import TYPE_CHECKING, Annotated, Literal from pydantic.functional_serializers import PlainSerializer from controlflow.agents.agent import Agent from controlflow.events.base import UnpersistedEvent -from controlflow.orchestration.orchestrator import Orchestrator + +if TYPE_CHECKING: + from controlflow.orchestration.conditions import RunContext + from controlflow.orchestration.orchestrator import Orchestrator class OrchestratorStart(UnpersistedEvent): event: Literal["orchestrator-start"] = "orchestrator-start" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" + run_context: "RunContext" class OrchestratorEnd(UnpersistedEvent): event: Literal["orchestrator-end"] = "orchestrator-end" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" + run_context: "RunContext" class OrchestratorError(UnpersistedEvent): event: Literal["orchestrator-error"] = "orchestrator-error" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" error: Annotated[Exception, PlainSerializer(lambda x: str(x), return_type=str)] class AgentTurnStart(UnpersistedEvent): event: Literal["agent-turn-start"] = "agent-turn-start" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" agent: Agent class AgentTurnEnd(UnpersistedEvent): event: Literal["agent-turn-end"] = "agent-turn-end" persist: bool = False - orchestrator: Orchestrator + orchestrator: "Orchestrator" agent: Agent diff --git a/src/controlflow/handlers/__init__.py b/src/controlflow/handlers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/controlflow/handlers/callback_handler.py b/src/controlflow/handlers/callback_handler.py new file mode 100644 index 00000000..ec73a19f --- /dev/null +++ b/src/controlflow/handlers/callback_handler.py @@ -0,0 +1,24 @@ +""" +A handler that calls a callback function for each event. +""" + +from typing import TYPE_CHECKING, Any, Callable, Coroutine + +from controlflow.events.base import Event +from controlflow.orchestration.handler import AsyncHandler, Handler + + +class CallbackHandler(Handler): + def __init__(self, callback: Callable[[Event], None]): + self.callback = callback + + def on_event(self, event: Event): + self.callback(event) + + +class AsyncCallbackHandler(AsyncHandler): + def __init__(self, callback: Callable[[Event], Coroutine[Any, Any, None]]): + self.callback = callback + + async def on_event(self, event: Event): + await self.callback(event) diff --git a/src/controlflow/handlers/print_handler.py b/src/controlflow/handlers/print_handler.py new file mode 100644 index 00000000..cc7ddf5d --- /dev/null +++ b/src/controlflow/handlers/print_handler.py @@ -0,0 +1,402 @@ +import datetime +from typing import Optional, Union + +import rich +from pydantic import BaseModel +from rich import box +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.panel import Panel +from rich.spinner import Spinner +from rich.table import Table + +from controlflow.events.events import AgentContentDelta, AgentToolCallDelta, ToolResult +from controlflow.events.orchestrator_events import ( + OrchestratorEnd, + OrchestratorError, + OrchestratorStart, +) +from controlflow.orchestration.handler import Handler +from controlflow.tools.tools import Tool +from controlflow.utilities.rich import console as cf_console + +# Global spinner for consistent animation +RUNNING_SPINNER = Spinner("dots") + + +class DisplayState(BaseModel): + """Base class for content to be displayed.""" + + agent_name: str + first_timestamp: datetime.datetime + + def format_timestamp(self) -> str: + """Format the timestamp for display.""" + local_timestamp = self.first_timestamp.astimezone() + return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) + + +class ContentState(DisplayState): + """State for content being streamed.""" + + content: str = "" + + @staticmethod + def _convert_content_to_str(content) -> str: + """Convert various content formats to a string.""" + if isinstance(content, str): + return content + + if isinstance(content, dict): + return content.get("content", content.get("text", "")) + + if isinstance(content, list): + parts = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + part = item.get("content", item.get("text", "")) + if part: + parts.append(part) + return "\n".join(parts) + + return str(content) + + def update_content(self, new_content) -> None: + """Update content, converting complex content types to string.""" + self.content = self._convert_content_to_str(new_content) + + def render_panel(self) -> Panel: + """Render content as a markdown panel.""" + return Panel( + Markdown(self.content), + title=f"[bold]Agent: {self.agent_name}[/]", + subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", + subtitle_align="right", + border_style="blue", + box=box.ROUNDED, + width=100, + padding=(1, 2), + ) + + +class ToolState(DisplayState): + """State for a tool call and its result.""" + + name: str + args: dict + result: Optional[str] = None + is_error: bool = False + is_complete: bool = False + tool: Optional[Tool] = None + + def get_status_style(self) -> tuple[Union[str, Spinner], str, str]: + """Returns (icon, text style, border style) for current status.""" + if self.is_complete: + if self.is_error: + return "❌", "red", "red" + else: + return "✅", "green", "green3" # Slightly softer green + return ( + RUNNING_SPINNER, + "yellow", + "gray50", + ) # Use shared spinner instance + + def render_completion_tool( + self, show_inputs: bool = False, show_outputs: bool = False + ) -> Panel: + """Special rendering for completion tools.""" + table = Table.grid(padding=0, expand=True) + header = Table.grid(padding=1) + header.add_column(width=2) + header.add_column() + + is_success_tool = self.tool.metadata.get("is_success_tool", False) + is_fail_tool = self.tool.metadata.get("is_fail_tool", False) + task = self.tool.metadata.get("completion_task") + task_name = task.friendly_name() if task else "Unknown Task" + # completion tools store their results on the task, rather than returning them directly + task_result = task.result if task else None + + if not self.is_complete: + icon = RUNNING_SPINNER # Use shared spinner instance + message = f"Working on task: {task_name}" + text_style = "dim" + border_style = "gray50" + else: + if self.is_error: + icon = "❌" + message = f"Error marking task status: {task_name}" + text_style = "red" + border_style = "red" + if show_outputs and self.result: + message += f"\nError: {self.result}" + elif is_fail_tool: + icon = "❌" + message = f"Task failed: {task_name}" + text_style = "red" + border_style = "red" + if show_outputs and task_result: + message += f"\nReason: {task_result}" + else: + icon = "✓" + message = f"Task complete: {task_name}" + text_style = "dim" + border_style = "gray50" + + header.add_row(icon, f"[{text_style}]{message}[/]") + table.add_row(header) + + # Show details (streaming args or final result) + if show_outputs and self.args: + details = Table.grid(padding=(0, 2)) + details.add_column(style="dim", width=9) + details.add_column() + + # If complete and successful, show task_result + if ( + self.is_complete + and not self.is_error + and not is_fail_tool + and task_result + ): + label = "Result" if is_success_tool else "Reason" + details.add_row( + f" {label}:", + f"{task_result}", + ) + # Otherwise show streaming args + else: + label = "Result" if is_success_tool else "Reason" + details.add_row( + f" {label}:", + rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), + ) + table.add_row(details) + + return Panel( + table, + title=f"[bold]Agent: {self.agent_name}[/]", + subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", + subtitle_align="right", + border_style=border_style, + box=box.ROUNDED, + width=100, + padding=(0, 1), + ) + + def render_panel( + self, + show_inputs: bool = True, + show_outputs: bool = True, + ) -> Panel: + """Render tool state as a panel with status indicator.""" + if self.tool and self.tool.metadata.get("is_completion_tool"): + return self.render_completion_tool( + show_inputs=show_inputs, show_outputs=show_outputs + ) + + icon, text_style, border_style = self.get_status_style() + table = Table.grid(padding=0, expand=True) + + header = Table.grid(padding=1) + header.add_column(width=2) + header.add_column() + tool_name = self.name.replace("_", " ").title() + header.add_row(icon, f"[{text_style} bold]{tool_name}[/]") + table.add_row(header) + + if show_inputs or show_outputs: + details = Table.grid(padding=(0, 2)) + details.add_column(style="dim", width=9) + details.add_column() + + if show_inputs and self.args: + details.add_row( + " Input:", + rich.pretty.Pretty(self.args, indent_size=2, expand_all=True), + ) + + if show_outputs and self.is_complete and self.result: + label = "Error" if self.is_error else "Output" + style = "red" if self.is_error else "green3" + details.add_row( + f" {label}:", + f"[{style}]{self.result}[/]", + ) + + table.add_row(details) + + return Panel( + table, + title=f"[bold]Agent: {self.agent_name}[/]", + subtitle=f"[italic]{self.format_timestamp()}[/]", + title_align="left", + subtitle_align="right", + border_style=border_style, + box=box.ROUNDED, + width=100, + padding=(0, 1), + ) + + +class PrintHandler(Handler): + def __init__( + self, + show_completion_tools: bool = True, + show_tool_inputs: bool = True, + show_tool_outputs: bool = True, + show_completion_tool_results: bool = False, + ): + super().__init__() + # Tool display settings + self.show_completion_tools = show_completion_tools + self.show_tool_inputs = show_tool_inputs + self.show_tool_outputs = show_tool_outputs + # Completion tool specific settings + self.show_completion_tool_results = show_completion_tool_results + + self.live: Optional[Live] = None + self.paused_id: Optional[str] = None + self.states: dict[str, DisplayState] = {} + + def update_display(self): + """Render all current state as panels and update display.""" + if not self.live or not self.live.is_started or self.paused_id: + return + + sorted_states = sorted(self.states.values(), key=lambda s: s.first_timestamp) + panels = [] + + for state in sorted_states: + if isinstance(state, ToolState): + is_completion_tool = state.tool and state.tool.metadata.get( + "is_completion_tool" + ) + + # Skip completion tools if disabled + if not self.show_completion_tools and is_completion_tool: + continue + + if is_completion_tool: + panels.append( + state.render_completion_tool( + show_outputs=self.show_completion_tool_results + ) + ) + else: + panels.append( + state.render_panel( + show_inputs=self.show_tool_inputs, + show_outputs=self.show_tool_outputs, + ) + ) + else: + panels.append(state.render_panel()) + + if panels: + self.live.update(Group(*panels), refresh=True) + + def on_agent_content_delta(self, event: AgentContentDelta): + """Handle content delta events by updating content state.""" + if not event.content_delta: + return + if event.agent_message_id not in self.states: + state = ContentState( + agent_name=event.agent.name, + first_timestamp=event.timestamp, + ) + state.update_content(event.content_snapshot) + self.states[event.agent_message_id] = state + else: + state = self.states[event.agent_message_id] + if isinstance(state, ContentState): + state.update_content(event.content_snapshot) + + self.update_display() + + def on_agent_tool_call_delta(self, event: AgentToolCallDelta): + """Handle tool call delta events by updating tool state.""" + # Handle CLI input special case + if event.tool_call_snapshot["name"] == "cli_input": + self.paused_id = event.tool_call_snapshot["id"] + if self.live and self.live.is_started: + self.live.stop() + return + + tool_id = event.tool_call_snapshot["id"] + if tool_id not in self.states: + self.states[tool_id] = ToolState( + agent_name=event.agent.name, + first_timestamp=event.timestamp, + name=event.tool_call_snapshot["name"], + args=event.args, + tool=event.tool, + ) + else: + state = self.states[tool_id] + if isinstance(state, ToolState): + state.args = event.args + + self.update_display() + + def on_tool_result(self, event: ToolResult): + """Handle tool result events by updating tool state.""" + # Handle CLI input resume + if event.tool_result.tool_call["name"] == "cli_input": + if self.paused_id == event.tool_result.tool_call["id"]: + self.paused_id = None + print() + self.live = Live( + console=cf_console, + vertical_overflow="visible", + auto_refresh=True, + ) + self.live.start() + return + + # Skip completion tools if disabled + if ( + not self.show_completion_tools + and event.tool_result.tool + and event.tool_result.tool.metadata.get("is_completion_tool") + ): + return + + tool_id = event.tool_result.tool_call["id"] + if tool_id in self.states: + state = self.states[tool_id] + if isinstance(state, ToolState): + state.is_complete = True + state.is_error = event.tool_result.is_error + state.result = event.tool_result.str_result + + self.update_display() + + def on_orchestrator_start(self, event: OrchestratorStart): + """Initialize live display.""" + self.live = Live( + console=cf_console, + vertical_overflow="visible", + auto_refresh=True, + ) + self.states.clear() + try: + self.live.start() + except rich.errors.LiveError: + pass + + def on_orchestrator_end(self, event: OrchestratorEnd): + """Clean up live display.""" + if self.live and self.live.is_started: + self.live.stop() + + def on_orchestrator_error(self, event: OrchestratorError): + """Clean up live display on error.""" + if self.live and self.live.is_started: + self.live.stop() diff --git a/src/controlflow/handlers/queue_handler.py b/src/controlflow/handlers/queue_handler.py new file mode 100644 index 00000000..0441823d --- /dev/null +++ b/src/controlflow/handlers/queue_handler.py @@ -0,0 +1,56 @@ +""" +A handler that queues events in a queue. +""" + +import asyncio +import queue +from typing import TYPE_CHECKING, Any, Callable, Coroutine + +from controlflow.events.base import Event +from controlflow.events.events import ( + AgentMessage, + AgentMessageDelta, + AgentToolCall, + ToolResult, +) +from controlflow.orchestration.handler import AsyncHandler, Handler + + +class QueueHandler(Handler): + def __init__( + self, queue: queue.Queue = None, event_filter: Callable[[Event], bool] = None + ): + self.queue = queue or queue.Queue() + self.event_filter = event_filter + + def on_event(self, event: Event): + if self.event_filter and not self.event_filter(event): + return + self.queue.put(event) + + +class AsyncQueueHandler(AsyncHandler): + def __init__( + self, queue: asyncio.Queue = None, event_filter: Callable[[Event], bool] = None + ): + self.queue = queue or asyncio.Queue() + self.event_filter = event_filter + + async def on_event(self, event: Event): + if self.event_filter and not self.event_filter(event): + return + await self.queue.put(event) + + +def message_filter(event: Event) -> bool: + return isinstance(event, (AgentMessage, AgentMessageDelta)) + + +def tool_filter(event: Event) -> bool: + return isinstance(event, (AgentToolCall, ToolResult)) + + +def result_filter(event: Event) -> bool: + return isinstance(event, (AgentToolCall, ToolResult)) and event.tool_call[ + "name" + ].startswith("mark_task_") diff --git a/src/controlflow/orchestration/handler.py b/src/controlflow/orchestration/handler.py index 9843a744..bd77a772 100644 --- a/src/controlflow/orchestration/handler.py +++ b/src/controlflow/orchestration/handler.py @@ -7,10 +7,10 @@ from controlflow.events.events import ( AgentMessage, AgentMessageDelta, + AgentToolCall, EndTurn, OrchestratorMessage, - ToolCallEvent, - ToolResultEvent, + ToolResult, UserMessage, ) from controlflow.events.orchestrator_events import ( @@ -54,10 +54,10 @@ def on_agent_message(self, event: "AgentMessage"): def on_agent_message_delta(self, event: "AgentMessageDelta"): pass - def on_tool_call(self, event: "ToolCallEvent"): + def on_tool_call(self, event: "AgentToolCall"): pass - def on_tool_result(self, event: "ToolResultEvent"): + def on_tool_result(self, event: "ToolResult"): pass def on_orchestrator_message(self, event: "OrchestratorMessage"): @@ -70,14 +70,6 @@ def on_end_turn(self, event: "EndTurn"): pass -class CallbackHandler(Handler): - def __init__(self, callback: Callable[[Event], None]): - self.callback = callback - - def on_event(self, event: Event): - self.callback(event) - - class AsyncHandler: async def handle(self, event: Event): """ @@ -112,10 +104,10 @@ async def on_agent_message(self, event: "AgentMessage"): async def on_agent_message_delta(self, event: "AgentMessageDelta"): pass - async def on_tool_call(self, event: "ToolCallEvent"): + async def on_tool_call(self, event: "AgentToolCall"): pass - async def on_tool_result(self, event: "ToolResultEvent"): + async def on_tool_result(self, event: "ToolResult"): pass async def on_orchestrator_message(self, event: "OrchestratorMessage"): diff --git a/src/controlflow/orchestration/orchestrator.py b/src/controlflow/orchestration/orchestrator.py index 94292639..10d04269 100644 --- a/src/controlflow/orchestration/orchestrator.py +++ b/src/controlflow/orchestration/orchestrator.py @@ -1,5 +1,5 @@ import logging -from typing import Callable, Optional, TypeVar, Union +from typing import AsyncIterator, Callable, Iterator, Optional, TypeVar, Union from pydantic import BaseModel, Field, field_validator @@ -8,6 +8,13 @@ from controlflow.events.base import Event from controlflow.events.events import AgentMessageDelta, OrchestratorMessage from controlflow.events.message_compiler import MessageCompiler +from controlflow.events.orchestrator_events import ( + AgentTurnEnd, + AgentTurnStart, + OrchestratorEnd, + OrchestratorError, + OrchestratorStart, +) from controlflow.flows import Flow from controlflow.instructions import get_instructions from controlflow.llm.messages import BaseMessage @@ -72,10 +79,15 @@ def _validate_handlers(cls, v): Returns: list[Handler]: The validated list of handlers. """ - from controlflow.orchestration.print_handler import PrintHandler + from controlflow.handlers.print_handler import PrintHandler if v is None and controlflow.settings.enable_default_print_handler: - v = [PrintHandler()] + v = [ + PrintHandler( + show_completion_tools=controlflow.settings.default_print_handler_show_completion_tools, + show_completion_tool_results=controlflow.settings.default_print_handler_show_completion_tool_results, + ) + ] return v or [] def handle_event(self, event: Event): @@ -85,8 +97,8 @@ def handle_event(self, event: Event): Args: event (Event): The event to handle. """ - if not isinstance(event, AgentMessageDelta): - logger.debug(f"Handling event: {repr(event)}") + from controlflow.events.events import AgentContentDelta + for handler in self.handlers: if isinstance(handler, Handler): handler.handle(event) @@ -163,6 +175,56 @@ def get_memories(self) -> list[Memory]: return memories + def _run_agent_turn( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run a single agent turn, yielding events as they occur.""" + assigned_tasks = self.get_tasks("assigned") + + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + yield OrchestratorMessage( + content=f"Starting task {task.name + ' ' if task.name else ''}(ID {task.id}) " + f"with objective: {task.objective}" + ) + + while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls + for task in assigned_tasks: + if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: + task.mark_failed(reason="Max LLM calls reached for this task.") + + # Check if there are any ready tasks left + if not any(t.is_ready() for t in assigned_tasks): + logger.debug("No `ready` tasks to run") + break + + if run_context.should_end(): + break + + messages = self.compile_messages() + tools = self.get_tools() + + # Run model and yield events + for event in self.agent._run_model( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event + + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 + + run_context.agent_turns += 1 + @prefect_task(task_run_name="Orchestrator.run()") def run( self, @@ -173,9 +235,11 @@ def run( Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, ) -> RunContext: - import controlflow.events.orchestrator_events - - # Create the base termination condition + """ + Run the orchestrator, handling events internally. + Returns the final run context. + """ + # Create run context at the outermost level if run_until is None: run_until = AllComplete() elif not isinstance(run_until, RunEndCondition): @@ -193,6 +257,19 @@ def run( run_context = RunContext(orchestrator=self, run_end_condition=run_until) + for event in self._run( + run_context=run_context, + model_kwargs=model_kwargs, + ): + self.handle_event(event) + return run_context + + def _run( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> Iterator[Event]: + """Run the orchestrator, yielding events as they occur.""" # Initialize the agent if not already set if not self.agent: self.agent = self.turn_strategy.get_next_agent( @@ -200,29 +277,23 @@ def run( ) # Signal the start of orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) - ) + yield OrchestratorStart(orchestrator=self, run_context=run_context) try: while True: if run_context.should_end(): break - self.handle_event( - controlflow.events.orchestrator_events.AgentTurnStart( - orchestrator=self, agent=self.agent - ) - ) - self.run_agent_turn( + yield AgentTurnStart(orchestrator=self, agent=self.agent) + + # Run turn and yield its events + for event in self._run_agent_turn( run_context=run_context, model_kwargs=model_kwargs, - ) - self.handle_event( - controlflow.events.orchestrator_events.AgentTurnEnd( - orchestrator=self, agent=self.agent - ) - ) + ): + yield event + + yield AgentTurnEnd(orchestrator=self, agent=self.agent) # Select the next agent for the following turn if available_agents := self.get_available_agents(): @@ -231,21 +302,12 @@ def run( ) except Exception as exc: - # Handle any exceptions that occur during orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorError( - orchestrator=self, error=exc - ) - ) + # Yield error event if something goes wrong + yield OrchestratorError(orchestrator=self, error=exc) raise finally: # Signal the end of orchestration - self.handle_event( - controlflow.events.orchestrator_events.OrchestratorEnd( - orchestrator=self - ) - ) - return run_context + yield OrchestratorEnd(orchestrator=self, run_context=run_context) @prefect_task async def run_async( @@ -257,9 +319,11 @@ async def run_async( Union[RunEndCondition, Callable[[RunContext], bool]] ] = None, ) -> RunContext: - import controlflow.events.orchestrator_events - - # Create the base termination condition + """ + Run the orchestrator asynchronously, handling events internally. + Returns the final run context. + """ + # Create run context at the outermost level if run_until is None: run_until = AllComplete() elif not isinstance(run_until, RunEndCondition): @@ -277,58 +341,11 @@ async def run_async( run_context = RunContext(orchestrator=self, run_end_condition=run_until) - # Initialize the agent if not already set - if not self.agent: - self.agent = self.turn_strategy.get_next_agent( - None, self.get_available_agents() - ) - - # Signal the start of orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorStart(orchestrator=self) - ) - - try: - while True: - if run_context.should_end(): - break - - await self.handle_event_async( - controlflow.events.orchestrator_events.AgentTurnStart( - orchestrator=self, agent=self.agent - ) - ) - await self.run_agent_turn_async( - run_context=run_context, - model_kwargs=model_kwargs, - ) - await self.handle_event_async( - controlflow.events.orchestrator_events.AgentTurnEnd( - orchestrator=self, agent=self.agent - ) - ) - - # Select the next agent for the following turn - if available_agents := self.get_available_agents(): - self.agent = self.turn_strategy.get_next_agent( - self.agent, available_agents - ) - - except Exception as exc: - # Handle any exceptions that occur during orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorError( - orchestrator=self, error=exc - ) - ) - raise - finally: - # Signal the end of orchestration - await self.handle_event_async( - controlflow.events.orchestrator_events.OrchestratorEnd( - orchestrator=self - ) - ) + async for event in self._run_async( + run_context=run_context, + model_kwargs=model_kwargs, + ): + await self.handle_event_async(event) return run_context @prefect_task(task_run_name="Agent turn: {self.agent.name}") @@ -576,5 +593,105 @@ def get_task_hierarchy(self) -> dict: return hierarchy + async def _run_agent_turn_async( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> AsyncIterator[Event]: + """Async version of _run_agent_turn.""" + assigned_tasks = self.get_tasks("assigned") + + self.turn_strategy.begin_turn() + + # Mark assigned tasks as running + for task in assigned_tasks: + if not task.is_running(): + task.mark_running() + yield OrchestratorMessage( + content=f"Starting task {task.name} (ID {task.id}) " + f"with objective: {task.objective}" + ) + + while not self.turn_strategy.should_end_turn(): + # fail any tasks that have reached their max llm calls + for task in assigned_tasks: + if task.max_llm_calls and task._llm_calls >= task.max_llm_calls: + task.mark_failed(reason="Max LLM calls reached for this task.") + + # Check if there are any ready tasks left + if not any(t.is_ready() for t in assigned_tasks): + logger.debug("No `ready` tasks to run") + break + + if run_context.should_end(): + break + + messages = self.compile_messages() + tools = self.get_tools() + + async for event in self.agent._run_model_async( + messages=messages, + tools=tools, + model_kwargs=model_kwargs, + ): + yield event + + run_context.llm_calls += 1 + for task in assigned_tasks: + task._llm_calls += 1 + + run_context.agent_turns += 1 + + async def _run_async( + self, + run_context: RunContext, + model_kwargs: Optional[dict] = None, + ) -> AsyncIterator[Event]: + """Run the orchestrator asynchronously, yielding events as they occur.""" + # Initialize the agent if not already set + if not self.agent: + self.agent = self.turn_strategy.get_next_agent( + None, self.get_available_agents() + ) + + # Signal the start of orchestration + yield OrchestratorStart(orchestrator=self, run_context=run_context) + + try: + while True: + if run_context.should_end(): + break + + yield AgentTurnStart(orchestrator=self, agent=self.agent) + + # Run turn and yield its events + async for event in self._run_agent_turn_async( + run_context=run_context, + model_kwargs=model_kwargs, + ): + yield event + + yield AgentTurnEnd(orchestrator=self, agent=self.agent) + + # Select the next agent for the following turn + if available_agents := self.get_available_agents(): + self.agent = self.turn_strategy.get_next_agent( + self.agent, available_agents + ) + + except Exception as exc: + # Yield error event if something goes wrong + yield OrchestratorError(orchestrator=self, error=exc) + raise + finally: + # Signal the end of orchestration + yield OrchestratorEnd(orchestrator=self, run_context=run_context) + +# Rebuild all models with forward references after Orchestrator is defined +OrchestratorStart.model_rebuild() +OrchestratorEnd.model_rebuild() +OrchestratorError.model_rebuild() +AgentTurnStart.model_rebuild() +AgentTurnEnd.model_rebuild() RunContext.model_rebuild() diff --git a/src/controlflow/orchestration/print_handler.py b/src/controlflow/orchestration/print_handler.py deleted file mode 100644 index 2d05918c..00000000 --- a/src/controlflow/orchestration/print_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -import datetime -from typing import Union - -import rich -from rich import box -from rich.console import Group -from rich.live import Live -from rich.markdown import Markdown -from rich.panel import Panel -from rich.spinner import Spinner -from rich.table import Table - -import controlflow -from controlflow.events.base import Event -from controlflow.events.events import ( - AgentMessage, - AgentMessageDelta, - ToolCallEvent, - ToolResultEvent, -) -from controlflow.events.orchestrator_events import ( - OrchestratorEnd, - OrchestratorError, - OrchestratorStart, -) -from controlflow.llm.messages import BaseMessage -from controlflow.orchestration.handler import Handler -from controlflow.tools.tools import ToolCall -from controlflow.utilities.rich import console as cf_console - - -class PrintHandler(Handler): - def __init__(self, include_completion_tools: bool = True): - self.events: dict[str, Event] = {} - self.paused_id: str = None - self.include_completion_tools = include_completion_tools - super().__init__() - - def update_live(self, latest: BaseMessage = None): - events = sorted(self.events.items(), key=lambda e: (e[1].timestamp, e[0])) - content = [] - - tool_results = {} # To track tool results by their call ID - - # gather all tool events first - for _, event in events: - if isinstance(event, ToolResultEvent): - tool_results[event.tool_call["id"]] = event - - for _, event in events: - if isinstance(event, (AgentMessageDelta, AgentMessage)): - if formatted := format_event(event, tool_results=tool_results): - content.append(formatted) - - if not content: - return - elif self.live.is_started: - self.live.update(Group(*content), refresh=True) - elif latest: - cf_console.print(format_event(latest)) - - def on_orchestrator_start(self, event: OrchestratorStart): - self.live: Live = Live( - auto_refresh=False, console=cf_console, vertical_overflow="visible" - ) - self.events.clear() - try: - self.live.start() - except rich.errors.LiveError: - pass - - def on_orchestrator_end(self, event: OrchestratorEnd): - self.live.stop() - - def on_orchestrator_error(self, event: OrchestratorError): - self.live.stop() - - def on_agent_message_delta(self, event: AgentMessageDelta): - self.events[event.snapshot_message.id] = event - self.update_live() - - def on_agent_message(self, event: AgentMessage): - self.events[event.ai_message.id] = event - self.update_live() - - def on_tool_call(self, event: ToolCallEvent): - # if collecting input on the terminal, pause the live display - # to avoid overwriting the input prompt - if event.tool_call["name"] == "cli_input": - self.paused_id = event.tool_call["id"] - self.live.stop() - self.events.clear() - - def on_tool_result(self, event: ToolResultEvent): - # skip completion tools if configured to do so - if not self.include_completion_tools and event.tool_result.tool_metadata.get( - "is_completion_tool" - ): - return - - self.events[f"tool-result:{event.tool_call['id']}"] = event - - # # if we were paused, resume the live display - if self.paused_id and self.paused_id == event.tool_call["id"]: - self.paused_id = None - # print newline to avoid odd formatting issues - print() - self.live = Live(auto_refresh=False) - self.live.start() - self.update_live(latest=event) - - -ROLE_COLORS = { - "system": "gray", - "ai": "blue", - "user": "green", -} -ROLE_NAMES = { - "system": "System", - "ai": "Agent", - "user": "User", -} - - -def format_timestamp(timestamp: datetime.datetime) -> str: - local_timestamp = timestamp.astimezone() - return local_timestamp.strftime("%I:%M:%S %p").lstrip("0").rjust(11) - - -def status(icon, text) -> Table: - t = Table.grid(padding=1) - t.add_row(icon, text) - return t - - -def format_event( - event: Union[AgentMessageDelta, AgentMessage], - tool_results: dict[str, ToolResultEvent] = None, -) -> Panel: - title = f"Agent: {event.agent.name}" - - content = [] - if isinstance(event, AgentMessageDelta): - message = event.snapshot_message - elif isinstance(event, AgentMessage): - message = event.ai_message - else: - return - - if message.content: - if isinstance(message.content, str): - content.append(Markdown(str(message.content))) - elif isinstance(message.content, dict): - if "content" in message.content: - content.append(Markdown(str(message.content["content"]))) - elif "text" in message.content: - content.append(Markdown(str(message.content["text"]))) - elif isinstance(message.content, list): - for item in message.content: - if isinstance(item, str): - content.append(Markdown(str(item))) - elif "content" in item: - content.append(Markdown(str(item["content"]))) - elif "text" in item: - content.append(Markdown(str(item["text"]))) - - tool_content = [] - for tool_call in message.tool_calls + message.invalid_tool_calls: - tool_result = (tool_results or {}).get(tool_call["id"]) - if tool_result: - c = format_tool_result(tool_result) - else: - c = format_tool_call(tool_call) - if c: - tool_content.append(c) - - if content and tool_content: - content.append("\n") - - return Panel( - Group(*content, *tool_content), - title=f"[bold]{title}[/]", - subtitle=f"[italic]{format_timestamp(event.timestamp)}[/]", - title_align="left", - subtitle_align="right", - border_style=ROLE_COLORS.get("ai", "red"), - box=box.ROUNDED, - width=100, - expand=True, - padding=(1, 2), - ) - - -def format_tool_call(tool_call: ToolCall) -> Panel: - if controlflow.settings.tools_verbose: - return status( - Spinner("dots"), - f'Tool call: "{tool_call["name"]}"\n\nTool args: {tool_call["args"]}', - ) - return status(Spinner("dots"), f'Tool call: "{tool_call["name"]}"') - - -def format_tool_result(event: ToolResultEvent) -> Panel: - if event.tool_result.is_error: - icon = ":x:" - else: - icon = ":white_check_mark:" - - if controlflow.settings.tools_verbose: - msg = f'Tool call: "{event.tool_call["name"]}"\n\nTool args: {event.tool_call["args"]}\n\nTool result: {event.tool_result.str_result}' - else: - msg = f'Tool call: "{event.tool_call["name"]}"' - return status(icon, msg) diff --git a/src/controlflow/settings.py b/src/controlflow/settings.py index c7f839c3..4f4c8b0f 100644 --- a/src/controlflow/settings.py +++ b/src/controlflow/settings.py @@ -73,7 +73,15 @@ def _validate_pretty_print_agent_events(cls, data: dict) -> dict: enable_default_print_handler: bool = Field( default=True, description="If True, a PrintHandler will be enabled and automatically " - "pretty-print agent events. Note that this may interfere with logging.", + "pretty-print agent events and completion tools.", + ) + default_print_handler_show_completion_tools: bool = Field( + default=True, + description="If True, the default PrintHandler will include completion tools.", + ) + default_print_handler_show_completion_tool_results: bool = Field( + default=False, + description="If True, the default PrintHandler will show the full results of completion tools.", ) # ------------ orchestration settings ------------ diff --git a/src/controlflow/stream.py b/src/controlflow/stream.py new file mode 100644 index 00000000..e8552ebf --- /dev/null +++ b/src/controlflow/stream.py @@ -0,0 +1,209 @@ +# Example usage +# +# # Stream all events +# for event in cf.stream.events("Write a story"): +# print(event) +# +# # Stream just messages +# for event in cf.stream.events("Write a story", events='messages'): +# print(event.content) +# +# # Stream just the result +# for delta, snapshot in cf.stream.result("Write a story"): +# print(f"New: {delta}") +# +# # Stream results from multiple tasks +# for delta, snapshot in cf.stream.result_from_tasks([task1, task2]): +# print(f"New result: {delta}") +# +from typing import Any, AsyncIterator, Callable, Iterator, Literal, Optional, Union + +from controlflow.events.base import Event +from controlflow.events.events import ( + AgentContent, + AgentContentDelta, + AgentMessage, + AgentMessageDelta, + AgentToolCall, + AgentToolCallDelta, + ToolResult, +) +from controlflow.orchestration.handler import AsyncHandler, Handler +from controlflow.orchestration.orchestrator import Orchestrator +from controlflow.tasks.task import Task + +StreamEvents = Union[ + list[str], + Literal["all", "messages", "content", "tools", "completion_tools", "agent_tools"], +] + + +def event_filter(events: StreamEvents) -> Callable[[Event], bool]: + def _event_filter(event: Event) -> bool: + if events == "all": + return True + elif events == "messages": + return isinstance(event, (AgentMessage, AgentMessageDelta)) + elif events == "content": + return isinstance(event, (AgentContent, AgentContentDelta)) + elif events == "tools": + return isinstance(event, (AgentToolCall, AgentToolCallDelta, ToolResult)) + elif events == "completion_tools": + if isinstance(event, (AgentToolCall, AgentToolCallDelta)): + return event.tool and event.tool.metadata.get("is_completion_tool") + elif isinstance(event, ToolResult): + return event.tool_result and event.tool_result.tool.metadata.get( + "is_completion_tool" + ) + return False + elif events == "agent_tools": + if isinstance(event, (AgentToolCall, AgentToolCallDelta)): + return event.tool and event.tool in event.agent.get_tools() + elif isinstance(event, ToolResult): + return ( + event.tool_result + and event.tool_result.tool in event.agent.get_tools() + ) + return False + else: + raise ValueError(f"Invalid event type: {events}") + + return _event_filter + + +# -------------------- BELOW HERE IS THE OLD STUFF -------------------- + + +def events( + objective: str, + *, + events: StreamEvents = "all", + filter_fn: Optional[Callable[[Event], bool]] = None, + **kwargs, +) -> Iterator[Event]: + """ + Stream events from a task execution. + + Args: + objective: The task objective + events: Which events to stream. Can be list of event types or: + 'all' - all events + 'messages' - agent messages + 'tools' - all tool calls/results + 'completion_tools' - only completion tools + filter_fn: Optional additional filter function + **kwargs: Additional arguments passed to Task + + Returns: + Iterator of Event objects + """ + + def get_event_filter(): + if isinstance(events, list): + return lambda e: e.event in events + elif events == "messages": + return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) + elif events == "tools": + return lambda e: isinstance(e, (AgentToolCall, ToolResult)) + elif events == "completion_tools": + return lambda e: ( + isinstance(e, (AgentToolCall, ToolResult)) + and e.tool_call["name"].startswith("mark_task_") + ) + else: # 'all' + return lambda e: True + + event_filter = get_event_filter() + + def event_handler(event: Event): + if event_filter(event) and (not filter_fn or filter_fn(event)): + yield event + + task = Task(objective=objective) + task.run(handlers=[Handler(event_handler)], **kwargs) + + +def result( + objective: str, + **kwargs, +) -> Iterator[tuple[Any, Any]]: + """ + Stream result from a task execution. + + Args: + objective: The task objective + **kwargs: Additional arguments passed to Task + + Returns: + Iterator of (delta, accumulated) result tuples + """ + current_result = None + + def result_handler(event: Event): + nonlocal current_result + if isinstance(event, ToolResult): + if event.tool_call["name"].startswith("mark_task_"): + result = event.tool_result.result # Get actual result value + if result != current_result: # Only yield if changed + current_result = result + yield (result, result) # For now delta == full result + + task = Task(objective=objective) + task.run(handlers=[Handler(result_handler)], **kwargs) + + +def events_from_tasks( + tasks: list[Task], + events: StreamEvents = "all", + filter_fn: Optional[Callable[[Event], bool]] = None, + **kwargs, +) -> Iterator[Event]: + """Stream events from multiple task executions.""" + + def get_event_filter(): + if isinstance(events, list): + return lambda e: e.event in events + elif events == "messages": + return lambda e: isinstance(e, (AgentMessage, AgentMessageDelta)) + elif events == "tools": + return lambda e: isinstance(e, (AgentToolCall, ToolResult)) + elif events == "completion_tools": + return lambda e: ( + isinstance(e, (AgentToolCall, ToolResult)) + and e.tool_call["name"].startswith("mark_task_") + ) + else: # 'all' + return lambda e: True + + event_filter = get_event_filter() + + def event_handler(event: Event): + if event_filter(event) and (not filter_fn or filter_fn(event)): + yield event + + orchestrator = Orchestrator( + tasks=tasks, handlers=[Handler(event_handler)], **kwargs + ) + orchestrator.run() + + +def result_from_tasks( + tasks: list[Task], + **kwargs, +) -> Iterator[tuple[Any, Any]]: + """Stream results from multiple task executions.""" + current_results = {task.id: None for task in tasks} + + def result_handler(event: Event): + if isinstance(event, ToolResult): + if event.tool_call["name"].startswith("mark_task_"): + task_id = event.task.id + result = event.tool_result.result + if result != current_results[task_id]: + current_results[task_id] = result + yield (result, result) + + orchestrator = Orchestrator( + tasks=tasks, handlers=[Handler(result_handler)], **kwargs + ) + orchestrator.run() diff --git a/src/controlflow/tasks/task.py b/src/controlflow/tasks/task.py index ad36127c..2ea43730 100644 --- a/src/controlflow/tasks/task.py +++ b/src/controlflow/tasks/task.py @@ -583,7 +583,11 @@ def get_success_tool(self) -> Tool: """ options = {} instructions = [] - metadata = {"is_completion_tool": True} + metadata = { + "is_completion_tool": True, + "is_success_tool": True, + "completion_task": self, + } result_schema = None # if the result_type is a tuple of options, then we want the LLM to provide @@ -714,6 +718,11 @@ def get_fail_tool(self) -> Tool: failure.""" ), include_return_description=False, + metadata={ + "is_completion_tool": True, + "is_fail_tool": True, + "completion_task": self, + }, ) def fail(reason: str) -> str: self.mark_failed(reason=reason) diff --git a/src/controlflow/tools/tools.py b/src/controlflow/tools/tools.py index 224cfa31..cce08fc4 100644 --- a/src/controlflow/tools/tools.py +++ b/src/controlflow/tools/tools.py @@ -298,16 +298,16 @@ def output_to_string(output: Any) -> str: class ToolResult(ControlFlowModel): - tool_call_id: str + tool_call: Union[ToolCall, InvalidToolCall] + tool: Optional[Tool] = None result: Any = Field(exclude=True, repr=False) str_result: str = Field(repr=False) is_error: bool = False - tool_metadata: dict = {} def handle_tool_call( tool_call: Union[ToolCall, InvalidToolCall], tools: list[Tool] -) -> Any: +) -> ToolResult: """ Given a ToolCall and set of available tools, runs the tool call and returns a ToolResult object @@ -340,15 +340,15 @@ def handle_tool_call( raise exc return ToolResult( - tool_call_id=tool_call["id"], + tool_call=tool_call, + tool=tool, result=fn_output, str_result=output_to_string(fn_output), is_error=is_error, - tool_metadata=tool.metadata if tool else {}, ) -async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any: +async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> ToolResult: """ Given a ToolCall and set of available tools, runs the tool call and returns a ToolResult object @@ -381,9 +381,9 @@ async def handle_tool_call_async(tool_call: ToolCall, tools: list[Tool]) -> Any: raise exc return ToolResult( - tool_call_id=tool_call["id"], + tool_call=tool_call, + tool=tool, result=fn_output, str_result=output_to_string(fn_output), is_error=is_error, - tool_metadata=tool.metadata if tool else {}, ) diff --git a/tests/utilities/test_testing.py b/tests/utilities/test_testing.py index 2a4e78fb..380bdd5f 100644 --- a/tests/utilities/test_testing.py +++ b/tests/utilities/test_testing.py @@ -37,15 +37,15 @@ def test_record_task_events(default_fake_llm): assert response == events[1].ai_message assert events[3].event == "tool-result" - assert events[3].tool_call == { + assert events[3].tool_result.tool_call == { "name": "mark_task_12345_successful", "args": {"task_result": "Hello!"}, "id": "call_ZEPdV8mCgeBe5UHjKzm6e3pe", "type": "tool_call", } - assert events[3].tool_result.model_dump() == dict( - tool_call_id="call_ZEPdV8mCgeBe5UHjKzm6e3pe", - str_result='Task #12345 ("say hello") marked successful.', - is_error=False, - tool_metadata={"is_completion_tool": True}, - ) + tool_result = events[3].tool_result.model_dump() + assert tool_result["tool_call"]["id"] == "call_ZEPdV8mCgeBe5UHjKzm6e3pe" + assert tool_result["str_result"] == 'Task #12345 ("say hello") marked successful.' + assert not tool_result["is_error"] + assert tool_result["tool"]["metadata"]["is_completion_tool"] + assert tool_result["tool"]["metadata"]["is_success_tool"]