feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (microsoft#5208)

Resolves microsoft#3983

* Introduce a `model_client_stream` parameter in `AssistantAgent` to
enable token-level streaming output.
* Introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent`
to pass streaming chunks to the application via `run_stream` and
`on_messages_stream`. These chunks are not added to the inner messages
list of the final `Response` or `TaskResult`.
* Handle the new message type in `Console`.
ekzhu authored Jan 29, 2025
1 parent 8a0daf8 commit 225eb9d
Showing 13 changed files with 330 additions and 32 deletions.
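A minimal usage sketch of the streaming mode described in the commit message, assuming `OpenAIChatCompletionClient` from `autogen_ext.models.openai` as the model client (any `ChatCompletionClient` whose `create_stream` yields string chunks should behave the same way):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed model client


async def main() -> None:
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_client_stream=True,  # enable token-level streaming output
    )

    # Console prints each chunk as it arrives.
    await Console(agent.run_stream(task="Write a haiku about streams."))

    # Or consume the chunks directly; they are yielded by run_stream but
    # are not included in the final TaskResult.messages.
    async for message in agent.run_stream(task="Write another haiku."):
        if isinstance(message, ModelClientStreamingChunkEvent):
            print(message.content, end="", flush=True)


asyncio.run(main())
```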
@@ -22,13 +22,14 @@
from autogen_core.models import (
AssistantMessage,
ChatCompletionClient,
CreateResult,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel
from typing_extensions import Self

@@ -40,6 +41,7 @@
ChatMessage,
HandoffMessage,
MemoryQueryEvent,
ModelClientStreamingChunkEvent,
MultiModalMessage,
TextMessage,
ToolCallExecutionEvent,
@@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel):
model_context: ComponentModel | None = None
description: str
system_message: str | None = None
model_client_stream: bool
reflect_on_tool_use: bool
tool_call_summary_format: str

@@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
This will limit the number of recent messages sent to the model and can be useful
when the model has a limit on the number of tokens it can process.
Streaming mode:
The assistant agent can be used in streaming mode by setting `model_client_stream=True`.
In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield
:class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response.
The chunk messages will not be included in the final response's inner messages.
Args:
name (str): The name of the agent.
@@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]):
model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. The initial messages will be cleared when the agent is reset.
description (str, optional): The description of the agent.
system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable.
model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode.
:meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`
messages as the model client produces chunks of response. Defaults to `False`.
reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result
to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`.
tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result.
@@ -268,12 +282,14 @@ def __init__(
system_message: (
str | None
) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.",
model_client_stream: bool = False,
reflect_on_tool_use: bool = False,
tool_call_summary_format: str = "{result}",
memory: Sequence[Memory] | None = None,
):
super().__init__(name=name, description=description)
self._model_client = model_client
self._model_client_stream = model_client_stream
self._memory = None
if memory is not None:
if isinstance(memory, list):
@@ -340,7 +356,7 @@ def __init__(

@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
"""The types of messages that the assistant agent produces."""
"""The types of final response messages that the assistant agent produces."""
message_types: List[type[ChatMessage]] = [TextMessage]
if self._handoffs:
message_types.append(HandoffMessage)
@@ -383,9 +399,23 @@ async def on_messages_stream(

# Generate an inference result based on the current model context.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)
model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(model_result, CreateResult)
else:
model_result = await self._model_client.create(
llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token
)

# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
@@ -465,14 +495,34 @@ async def on_messages_stream(
if self._reflect_on_tool_use:
# Generate another inference result based on the tool call and result.
llm_messages = self._system_messages + await self._model_context.get_messages()
model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token)
assert isinstance(model_result.content, str)
reflection_model_result: CreateResult | None = None
if self._model_client_stream:
# Stream the model client.
async for chunk in self._model_client.create_stream(
llm_messages, cancellation_token=cancellation_token
):
if isinstance(chunk, CreateResult):
reflection_model_result = chunk
elif isinstance(chunk, str):
yield ModelClientStreamingChunkEvent(content=chunk, source=self.name)
else:
raise RuntimeError(f"Invalid chunk type: {type(chunk)}")
assert isinstance(reflection_model_result, CreateResult)
else:
reflection_model_result = await self._model_client.create(
llm_messages, cancellation_token=cancellation_token
)
assert isinstance(reflection_model_result.content, str)
# Add the response to the model context.
await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name))
await self._model_context.add_message(
AssistantMessage(content=reflection_model_result.content, source=self.name)
)
# Yield the response.
yield Response(
chat_message=TextMessage(
content=model_result.content, source=self.name, models_usage=model_result.usage
content=reflection_model_result.content,
source=self.name,
models_usage=reflection_model_result.usage,
),
inner_messages=inner_messages,
)
@@ -538,6 +588,7 @@ def _to_config(self) -> AssistantAgentConfig:
system_message=self._system_messages[0].content
if self._system_messages and isinstance(self._system_messages[0].content, str)
else None,
model_client_stream=self._model_client_stream,
reflect_on_tool_use=self._reflect_on_tool_use,
tool_call_summary_format=self._tool_call_summary_format,
)
@@ -553,6 +604,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self:
model_context=None,
description=config.description,
system_message=config.system_message,
model_client_stream=config.model_client_stream,
reflect_on_tool_use=config.reflect_on_tool_use,
tool_call_summary_format=config.tool_call_summary_format,
)
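The `_to_config` / `_from_config` additions mean the new flag survives declarative serialization. A hedged sketch, assuming the `dump_component` / `load_component` helpers from the autogen_core component framework and the same assumed model client as above:

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed model client

agent = AssistantAgent(
    "assistant",
    model_client=OpenAIChatCompletionClient(model="gpt-4o"),
    model_client_stream=True,
)

# Round-trip through the declarative config; model_client_stream is preserved.
component = agent.dump_component()
restored = AssistantAgent.load_component(component)
```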
@@ -9,6 +9,7 @@
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ..state import BaseState
@@ -178,8 +179,11 @@ async def run_stream(
output_messages.append(message.chat_message)
yield TaskResult(messages=output_messages)
else:
output_messages.append(message)
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)

@abstractmethod
async def on_reset(self, cancellation_token: CancellationToken) -> None:
@@ -13,6 +13,7 @@
AgentEvent,
BaseChatMessage,
ChatMessage,
ModelClientStreamingChunkEvent,
TextMessage,
)
from ._base_chat_agent import BaseChatAgent
@@ -150,6 +151,9 @@ async def on_messages_stream(
# Skip the task messages.
continue
yield inner_msg
if isinstance(inner_msg, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
inner_messages.append(inner_msg)
assert result is not None

@@ -1,7 +1,7 @@
import logging
from typing import Any, Dict

from autogen_core.tools import FunctionTool, BaseTool
from autogen_core.tools import BaseTool, FunctionTool
from pydantic import BaseModel, Field, model_validator

from .. import EVENT_LOGGER_NAME
@@ -128,14 +128,27 @@ class MemoryQueryEvent(BaseAgentEvent):
type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent"


class ModelClientStreamingChunkEvent(BaseAgentEvent):
"""An event signaling a text output chunk from a model client in streaming mode."""

content: str
"""The partial text chunk."""

type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent"


ChatMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type")
]
"""Messages for agent-to-agent communication only."""


AgentEvent = Annotated[
ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent,
ToolCallRequestEvent
| ToolCallExecutionEvent
| MemoryQueryEvent
| UserInputRequestedEvent
| ModelClientStreamingChunkEvent,
Field(discriminator="type"),
]
"""Events emitted by agents and teams when they work, not used for agent-to-agent communication."""
@@ -154,4 +167,5 @@ class MemoryQueryEvent(BaseAgentEvent):
"ToolCallSummaryMessage",
"MemoryQueryEvent",
"UserInputRequestedEvent",
"ModelClientStreamingChunkEvent",
]
@@ -21,7 +21,7 @@

from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage
from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
@@ -190,6 +190,9 @@ async def run(
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
Returns:
result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason.
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
@@ -279,16 +282,25 @@ async def run_stream(
cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
of the type :class:`TaskResult` as the last item in the stream. Once the
of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the
team is stopped, the termination condition is reset.
.. note::
If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`,
the message will be yielded in the stream but it will not be included in the
:attr:`~autogen_agentchat.base.TaskResult.messages`.
Args:
task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`.
cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately.
Setting the cancellation token potentially put the team in an inconsistent state,
and it may not reset the termination condition.
To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead.
Returns:
stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream.
Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team:
.. code-block:: python
Expand Down Expand Up @@ -422,6 +434,9 @@ async def stop_runtime() -> None:
if message is None:
break
yield message
if isinstance(message, ModelClientStreamingChunkEvent):
# Skip the model client streaming chunk events.
continue
output_messages.append(message)

# Yield the final result.
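For team-level streaming, a hedged sketch along the lines of the `RoundRobinGroupChat` examples referenced in the docstring above (the team composition, termination condition, and model client are assumptions):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient  # assumed model client


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    writer = AssistantAgent("writer", model_client=model_client, model_client_stream=True)
    critic = AssistantAgent("critic", model_client=model_client, model_client_stream=True)
    team = RoundRobinGroupChat([writer, critic], termination_condition=MaxMessageTermination(4))

    # Chunk events are yielded by run_stream as they arrive, but they are
    # excluded from the final TaskResult.messages.
    result = await Console(team.run_stream(task="Draft a tagline, then critique it."))
    print(len(result.messages))


asyncio.run(main())
```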
@@ -10,7 +10,13 @@

from autogen_agentchat.agents import UserProxyAgent
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent
from autogen_agentchat.messages import (
AgentEvent,
ChatMessage,
ModelClientStreamingChunkEvent,
MultiModalMessage,
UserInputRequestedEvent,
)


def _is_running_in_iterm() -> bool:
@@ -106,6 +112,8 @@ async def Console(

last_processed: Optional[T] = None

streaming_chunks: List[str] = []

async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
@@ -159,13 +167,28 @@ async def Console(
else:
# Cast required for mypy to be happy
message = cast(AgentEvent | ChatMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
if output_stats:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
await aprint(output, end="")
if not streaming_chunks:
# Print message sender.
await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n")
if isinstance(message, ModelClientStreamingChunkEvent):
await aprint(message.content, end="")
streaming_chunks.append(message.content)
else:
if streaming_chunks:
streaming_chunks.clear()
# Chunked messages are already printed, so we just print a newline.
await aprint("", end="\n")
else:
# Print message content.
await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n")
if message.models_usage:
if output_stats:
await aprint(
f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]",
end="\n",
)
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")