From 225eb9d0b205576dba1cfc97856d05ccce5ab71c Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 28 Jan 2025 18:49:02 -0800 Subject: [PATCH] feat: introduce ModelClientStreamingChunkEvent for streaming model output and update handling in agents and console (#5208) Resolves #3983 * introduce `model_client_stream` parameter in `AssistantAgent` to enable token-level streaming output. * introduce `ModelClientStreamingChunkEvent` as a type of `AgentEvent` to pass the streaming chunks to the application via `run_stream` and `on_messages_stream`. Although this will not affect the inner messages list in the final `Response` or `TaskResult`. * handle this new message type in `Console`. --- .../agents/_assistant_agent.py | 70 +++++++++-- .../agents/_base_chat_agent.py | 6 +- .../agents/_society_of_mind_agent.py | 4 + .../src/autogen_agentchat/base/_handoff.py | 2 +- .../src/autogen_agentchat/messages.py | 16 ++- .../teams/_group_chat/_base_group_chat.py | 19 ++- .../src/autogen_agentchat/ui/_console.py | 39 ++++-- .../tests/test_assistant_agent.py | 66 ++++++++++- .../tutorial/agents.ipynb | 111 ++++++++++++++++++ .../replay/_replay_chat_completion_client.py | 7 +- .../src/autogen_ext/ui/_rich_console.py | 4 + .../models/test_chat_completion_cache.py | 4 +- .../test_reply_chat_completion_client.py | 14 ++- 13 files changed, 330 insertions(+), 32 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index b1144d9c4667..6c75214c0baa 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -22,13 +22,14 @@ from autogen_core.models import ( AssistantMessage, ChatCompletionClient, + CreateResult, FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, SystemMessage, UserMessage, ) -from autogen_core.tools import FunctionTool, BaseTool +from autogen_core.tools import BaseTool, FunctionTool from pydantic import BaseModel from typing_extensions import Self @@ -40,6 +41,7 @@ ChatMessage, HandoffMessage, MemoryQueryEvent, + ModelClientStreamingChunkEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -62,6 +64,7 @@ class AssistantAgentConfig(BaseModel): model_context: ComponentModel | None = None description: str system_message: str | None = None + model_client_stream: bool reflect_on_tool_use: bool tool_call_summary_format: str @@ -126,6 +129,14 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): This will limit the number of recent messages sent to the model and can be useful when the model has a limit on the number of tokens it can process. + Streaming mode: + + The assistant agent can be used in streaming mode by setting `model_client_stream=True`. + In this mode, the :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield + :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` + messages as the model client produces chunks of response. + The chunk messages will not be included in the final response's inner messages. + Args: name (str): The name of the agent. @@ -138,6 +149,9 @@ class AssistantAgent(BaseChatAgent, Component[AssistantAgentConfig]): model_context (ChatCompletionContext | None, optional): The model context for storing and retrieving :class:`~autogen_core.models.LLMMessage`. It can be preloaded with initial messages. 
The initial messages will be cleared when the agent is reset. description (str, optional): The description of the agent. system_message (str, optional): The system message for the model. If provided, it will be prepended to the messages in the model context when making an inference. Set to `None` to disable. + model_client_stream (bool, optional): If `True`, the model client will be used in streaming mode. + :meth:`on_messages_stream` and :meth:`BaseChatAgent.run_stream` methods will also yield :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` + messages as the model client produces chunks of response. Defaults to `False`. reflect_on_tool_use (bool, optional): If `True`, the agent will make another model inference using the tool call and result to generate a response. If `False`, the tool call result will be returned as the response. Defaults to `False`. tool_call_summary_format (str, optional): The format string used to create a tool call summary for every tool call result. @@ -268,12 +282,14 @@ def __init__( system_message: ( str | None ) = "You are a helpful AI assistant. Solve tasks using your tools. Reply with TERMINATE when the task has been completed.", + model_client_stream: bool = False, reflect_on_tool_use: bool = False, tool_call_summary_format: str = "{result}", memory: Sequence[Memory] | None = None, ): super().__init__(name=name, description=description) self._model_client = model_client + self._model_client_stream = model_client_stream self._memory = None if memory is not None: if isinstance(memory, list): @@ -340,7 +356,7 @@ def __init__( @property def produced_message_types(self) -> Sequence[type[ChatMessage]]: - """The types of messages that the assistant agent produces.""" + """The types of final response messages that the assistant agent produces.""" message_types: List[type[ChatMessage]] = [TextMessage] if self._handoffs: message_types.append(HandoffMessage) @@ -383,9 +399,23 @@ async def on_messages_stream( # Generate an inference result based on the current model context. llm_messages = self._system_messages + await self._model_context.get_messages() - model_result = await self._model_client.create( - llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token - ) + model_result: CreateResult | None = None + if self._model_client_stream: + # Stream the model client. + async for chunk in self._model_client.create_stream( + llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token + ): + if isinstance(chunk, CreateResult): + model_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent(content=chunk, source=self.name) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + assert isinstance(model_result, CreateResult) + else: + model_result = await self._model_client.create( + llm_messages, tools=self._tools + self._handoff_tools, cancellation_token=cancellation_token + ) # Add the response to the model context. await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name)) @@ -465,14 +495,34 @@ async def on_messages_stream( if self._reflect_on_tool_use: # Generate another inference result based on the tool call and result. 
llm_messages = self._system_messages + await self._model_context.get_messages() - model_result = await self._model_client.create(llm_messages, cancellation_token=cancellation_token) - assert isinstance(model_result.content, str) + reflection_model_result: CreateResult | None = None + if self._model_client_stream: + # Stream the model client. + async for chunk in self._model_client.create_stream( + llm_messages, cancellation_token=cancellation_token + ): + if isinstance(chunk, CreateResult): + reflection_model_result = chunk + elif isinstance(chunk, str): + yield ModelClientStreamingChunkEvent(content=chunk, source=self.name) + else: + raise RuntimeError(f"Invalid chunk type: {type(chunk)}") + assert isinstance(reflection_model_result, CreateResult) + else: + reflection_model_result = await self._model_client.create( + llm_messages, cancellation_token=cancellation_token + ) + assert isinstance(reflection_model_result.content, str) # Add the response to the model context. - await self._model_context.add_message(AssistantMessage(content=model_result.content, source=self.name)) + await self._model_context.add_message( + AssistantMessage(content=reflection_model_result.content, source=self.name) + ) # Yield the response. yield Response( chat_message=TextMessage( - content=model_result.content, source=self.name, models_usage=model_result.usage + content=reflection_model_result.content, + source=self.name, + models_usage=reflection_model_result.usage, ), inner_messages=inner_messages, ) @@ -538,6 +588,7 @@ def _to_config(self) -> AssistantAgentConfig: system_message=self._system_messages[0].content if self._system_messages and isinstance(self._system_messages[0].content, str) else None, + model_client_stream=self._model_client_stream, reflect_on_tool_use=self._reflect_on_tool_use, tool_call_summary_format=self._tool_call_summary_format, ) @@ -553,6 +604,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self: model_context=None, description=config.description, system_message=config.system_message, + model_client_stream=config.model_client_stream, reflect_on_tool_use=config.reflect_on_tool_use, tool_call_summary_format=config.tool_call_summary_format, ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 97b9de76242c..b2e4e4fde48c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -9,6 +9,7 @@ AgentEvent, BaseChatMessage, ChatMessage, + ModelClientStreamingChunkEvent, TextMessage, ) from ..state import BaseState @@ -178,8 +179,11 @@ async def run_stream( output_messages.append(message.chat_message) yield TaskResult(messages=output_messages) else: - output_messages.append(message) yield message + if isinstance(message, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. 
+ continue + output_messages.append(message) @abstractmethod async def on_reset(self, cancellation_token: CancellationToken) -> None: diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py index c43c472c2915..c4fa5bb32cba 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py @@ -13,6 +13,7 @@ AgentEvent, BaseChatMessage, ChatMessage, + ModelClientStreamingChunkEvent, TextMessage, ) from ._base_chat_agent import BaseChatAgent @@ -150,6 +151,9 @@ async def on_messages_stream( # Skip the task messages. continue yield inner_msg + if isinstance(inner_msg, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. + continue inner_messages.append(inner_msg) assert result is not None diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py index dc7905ff79db..e5bfa98d3b02 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_handoff.py @@ -1,7 +1,7 @@ import logging from typing import Any, Dict -from autogen_core.tools import FunctionTool, BaseTool +from autogen_core.tools import BaseTool, FunctionTool from pydantic import BaseModel, Field, model_validator from .. import EVENT_LOGGER_NAME diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 25d9e732d335..17249e674854 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -128,6 +128,15 @@ class MemoryQueryEvent(BaseAgentEvent): type: Literal["MemoryQueryEvent"] = "MemoryQueryEvent" +class ModelClientStreamingChunkEvent(BaseAgentEvent): + """An event signaling a text output chunk from a model client in streaming mode.""" + + content: str + """The partial text chunk.""" + + type: Literal["ModelClientStreamingChunkEvent"] = "ModelClientStreamingChunkEvent" + + ChatMessage = Annotated[ TextMessage | MultiModalMessage | StopMessage | ToolCallSummaryMessage | HandoffMessage, Field(discriminator="type") ] @@ -135,7 +144,11 @@ class MemoryQueryEvent(BaseAgentEvent): AgentEvent = Annotated[ - ToolCallRequestEvent | ToolCallExecutionEvent | MemoryQueryEvent | UserInputRequestedEvent, + ToolCallRequestEvent + | ToolCallExecutionEvent + | MemoryQueryEvent + | UserInputRequestedEvent + | ModelClientStreamingChunkEvent, Field(discriminator="type"), ] """Events emitted by agents and teams when they work, not used for agent-to-agent communication.""" @@ -154,4 +167,5 @@ class MemoryQueryEvent(BaseAgentEvent): "ToolCallSummaryMessage", "MemoryQueryEvent", "UserInputRequestedEvent", + "ModelClientStreamingChunkEvent", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 61e3783a80e5..b49766b5ffcc 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ 
-21,7 +21,7 @@ from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TaskResult, Team, TerminationCondition -from ...messages import AgentEvent, BaseChatMessage, ChatMessage, TextMessage +from ...messages import AgentEvent, BaseChatMessage, ChatMessage, ModelClientStreamingChunkEvent, TextMessage from ...state import TeamState from ._chat_agent_container import ChatAgentContainer from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination @@ -190,6 +190,9 @@ async def run( and it may not reset the termination condition. To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead. + Returns: + result: The result of the task as :class:`~autogen_agentchat.base.TaskResult`. The result contains the messages produced by the team and the stop reason. + Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team: @@ -279,9 +282,15 @@ async def run_stream( cancellation_token: CancellationToken | None = None, ) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]: """Run the team and produces a stream of messages and the final result - of the type :class:`TaskResult` as the last item in the stream. Once the + of the type :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. Once the team is stopped, the termination condition is reset. + .. note:: + + If an agent produces :class:`~autogen_agentchat.messages.ModelClientStreamingChunkEvent`, + the message will be yielded in the stream but it will not be included in the + :attr:`~autogen_agentchat.base.TaskResult.messages`. + Args: task (str | ChatMessage | Sequence[ChatMessage] | None): The task to run the team with. Can be a string, a single :class:`ChatMessage` , or a list of :class:`ChatMessage`. cancellation_token (CancellationToken | None): The cancellation token to kill the task immediately. @@ -289,6 +298,9 @@ async def run_stream( and it may not reset the termination condition. To gracefully stop the team, use :class:`~autogen_agentchat.conditions.ExternalTermination` instead. + Returns: + stream: an :class:`~collections.abc.AsyncGenerator` that yields :class:`~autogen_agentchat.messages.AgentEvent`, :class:`~autogen_agentchat.messages.ChatMessage`, and the final result :class:`~autogen_agentchat.base.TaskResult` as the last item in the stream. + Example using the :class:`~autogen_agentchat.teams.RoundRobinGroupChat` team: .. code-block:: python @@ -422,6 +434,9 @@ async def stop_runtime() -> None: if message is None: break yield message + if isinstance(message, ModelClientStreamingChunkEvent): + # Skip the model client streaming chunk events. + continue output_messages.append(message) # Yield the final result. 
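The `run_stream` changes above mean a caller can consume token chunks as they arrive while still receiving a clean `TaskResult` at the end (chunk events are yielded but never appended to `output_messages`). The sketch below is not part of the patch; it shows that consumption pattern for a single `AssistantAgent`, and the same loop applies to a team's `run_stream`. It assumes a configured OpenAI API key, and the model name and task string are placeholders.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),
        model_client_stream=True,  # Yield ModelClientStreamingChunkEvent in the stream.
    )
    async for item in agent.run_stream(task="Name two cities in South America."):
        if isinstance(item, ModelClientStreamingChunkEvent):
            # Partial text: render it incrementally, without a trailing newline.
            print(item.content, end="", flush=True)
        elif isinstance(item, TaskResult):
            # Chunk events are not included in TaskResult.messages.
            print(f"\nMessages in result: {len(item.messages)}")


asyncio.run(main())
```

The `Console` changes in the next hunk implement essentially this pattern: chunks are printed inline as they arrive, and the full message that follows them only triggers a newline so the text is not printed twice.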
diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py index 39968c10b75e..1f80166d32a1 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py @@ -10,7 +10,13 @@ from autogen_agentchat.agents import UserProxyAgent from autogen_agentchat.base import Response, TaskResult -from autogen_agentchat.messages import AgentEvent, ChatMessage, MultiModalMessage, UserInputRequestedEvent +from autogen_agentchat.messages import ( + AgentEvent, + ChatMessage, + ModelClientStreamingChunkEvent, + MultiModalMessage, + UserInputRequestedEvent, +) def _is_running_in_iterm() -> bool: @@ -106,6 +112,8 @@ async def Console( last_processed: Optional[T] = None + streaming_chunks: List[str] = [] + async for message in stream: if isinstance(message, TaskResult): duration = time.time() - start_time @@ -159,13 +167,28 @@ async def Console( else: # Cast required for mypy to be happy message = cast(AgentEvent | ChatMessage, message) # type: ignore - output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n" - if message.models_usage: - if output_stats: - output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n" - total_usage.completion_tokens += message.models_usage.completion_tokens - total_usage.prompt_tokens += message.models_usage.prompt_tokens - await aprint(output, end="") + if not streaming_chunks: + # Print message sender. + await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n") + if isinstance(message, ModelClientStreamingChunkEvent): + await aprint(message.content, end="") + streaming_chunks.append(message.content) + else: + if streaming_chunks: + streaming_chunks.clear() + # Chunked messages are already printed, so we just print a newline. + await aprint("", end="\n") + else: + # Print message content. 
+ await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n") + if message.models_usage: + if output_stats: + await aprint( + f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]", + end="\n", + ) + total_usage.completion_tokens += message.models_usage.completion_tokens + total_usage.prompt_tokens += message.models_usage.prompt_tokens if last_processed is None: raise ValueError("No TaskResult or Response was processed.") diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index d50a96409c7e..6c2852c667cd 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -11,6 +11,7 @@ ChatMessage, HandoffMessage, MemoryQueryEvent, + ModelClientStreamingChunkEvent, MultiModalMessage, TextMessage, ToolCallExecutionEvent, @@ -20,10 +21,11 @@ from autogen_core import FunctionCall, Image from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult from autogen_core.model_context import BufferedChatCompletionContext -from autogen_core.models import FunctionExecutionResult, LLMMessage +from autogen_core.models import CreateResult, FunctionExecutionResult, LLMMessage, RequestUsage from autogen_core.models._model_client import ModelFamily from autogen_core.tools import FunctionTool from autogen_ext.models.openai import OpenAIChatCompletionClient +from autogen_ext.models.replay import ReplayChatCompletionClient from openai.resources.chat.completions import AsyncCompletions from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk @@ -776,3 +778,65 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N ) agent3_config = agent3.dump_component() assert agent3_config.provider == "autogen_agentchat.agents.AssistantAgent" + + +@pytest.mark.asyncio +async def test_model_client_stream() -> None: + mock_client = ReplayChatCompletionClient( + [ + "Response to message 3", + ] + ) + agent = AssistantAgent( + "test_agent", + model_client=mock_client, + model_client_stream=True, + ) + chunks: List[str] = [] + async for message in agent.run_stream(task="task"): + if isinstance(message, TaskResult): + assert message.messages[-1].content == "Response to message 3" + elif isinstance(message, ModelClientStreamingChunkEvent): + chunks.append(message.content) + assert "".join(chunks) == "Response to message 3" + + +@pytest.mark.asyncio +async def test_model_client_stream_with_tool_calls() -> None: + mock_client = ReplayChatCompletionClient( + [ + CreateResult( + content=[ + FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'), + FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'), + ], + finish_reason="function_calls", + usage=RequestUsage(prompt_tokens=10, completion_tokens=5), + cached=False, + ), + "Example response 2 to task", + ] + ) + mock_client._model_info["function_calling"] = True # pyright: ignore + agent = AssistantAgent( + "test_agent", + model_client=mock_client, + model_client_stream=True, + reflect_on_tool_use=True, + tools=[_pass_function, _echo_function], + ) + chunks: List[str] = [] + async for message in agent.run_stream(task="task"): + if isinstance(message, TaskResult): + assert message.messages[-1].content == "Example response 2 to task" + 
assert message.messages[1].content == [ + FunctionCall(id="1", name="_pass_function", arguments=r'{"input": "task"}'), + FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'), + ] + assert message.messages[2].content == [ + FunctionExecutionResult(call_id="1", content="pass"), + FunctionExecutionResult(call_id="3", content="task"), + ] + elif isinstance(message, ModelClientStreamingChunkEvent): + chunks.append(message.content) + assert "".join(chunks) == "Example response 2 to task" diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb index 3a07f68fee3e..d1b90aa787a9 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/agents.ipynb @@ -403,6 +403,117 @@ "await Console(agent.run_stream(task=\"I am happy.\"))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming Tokens\n", + "\n", + "You can stream the tokens generated by the model client by setting `model_client_stream=True`.\n", + "This will cause the agent to yield {py:class}`~autogen_agentchat.messages.ModelClientStreamingChunkEvent` messages\n", + "in {py:meth}`~autogen_agentchat.agents.BaseChatAgent.on_messages_stream` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream`.\n", + "\n", + "The underlying model API must support streaming tokens for this to work.\n", + "Please check with your model provider to see if this is supported." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' South' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Buenos' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Aires' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Argentina' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' São' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Paulo' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Brazil' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content='.' 
type='ModelClientStreamingChunkEvent'\n", + "Response(chat_message=TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in South America are Buenos Aires in Argentina and São Paulo in Brazil.', type='TextMessage'), inner_messages=[])\n" + ] + } + ], + "source": [ + "model_client = OpenAIChatCompletionClient(model=\"gpt-4o\")\n", + "\n", + "streaming_assistant = AssistantAgent(\n", + " name=\"assistant\",\n", + " model_client=model_client,\n", + " system_message=\"You are a helpful assistant.\",\n", + " model_client_stream=True, # Enable streaming tokens.\n", + ")\n", + "\n", + "# Use an async function and asyncio.run() in a script.\n", + "async for message in streaming_assistant.on_messages_stream( # type: ignore\n", + " [TextMessage(content=\"Name two cities in South America\", source=\"user\")],\n", + " cancellation_token=CancellationToken(),\n", + "):\n", + " print(message)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can see the streaming chunks in the output above.\n", + "The chunks are generated by the model client and are yielded by the agent as they are received.\n", + "The final response, the concatenation of all the chunks, is yielded right after the last chunk.\n", + "\n", + "Similarly, {py:meth}`~autogen_agentchat.agents.BaseChatAgent.run_stream` will also yield the same streaming chunks,\n", + "followed by a full text message right after the last chunk." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "source='user' models_usage=None content='Name two cities in North America.' type='TextMessage'\n", + "source='assistant' models_usage=None content='Two' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' cities' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' North' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' America' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' are' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' New' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' York' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' City' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' the' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' United' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' States' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' and' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Toronto' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' in' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content=' Canada' type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=None content='.' 
type='ModelClientStreamingChunkEvent'\n", + "source='assistant' models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0) content='Two cities in North America are New York City in the United States and Toronto in Canada.' type='TextMessage'\n", + "TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Name two cities in North America.', type='TextMessage'), TextMessage(source='assistant', models_usage=RequestUsage(prompt_tokens=0, completion_tokens=0), content='Two cities in North America are New York City in the United States and Toronto in Canada.', type='TextMessage')], stop_reason=None)\n" + ] + } + ], + "source": [ + "async for message in streaming_assistant.run_stream(task=\"Name two cities in North America.\"): # type: ignore\n", + " print(message)" + ] + }, { "cell_type": "markdown", "metadata": {}, diff --git a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py index 5ae7b6b665eb..b5c178d774bf 100644 --- a/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py +++ b/python/packages/autogen-ext/src/autogen_ext/models/replay/_replay_chat_completion_client.py @@ -185,6 +185,9 @@ async def create_stream( yield token + " " else: yield token + yield CreateResult( + finish_reason="stop", content=response, usage=self._cur_usage, cached=self._cached_bool_value + ) self._update_total_usage() else: self._cur_usage = RequestUsage( @@ -226,7 +229,7 @@ def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> t total_tokens += len(tokens) all_tokens.extend(tokens) else: - logger.warning("Token count has been done only on string content", RuntimeWarning) + logger.warning("Token count has been done only on string content") elif isinstance(messages, Sequence): for message in messages: if isinstance(message.content, str): # type: ignore [reportAttributeAccessIssue, union-attr] @@ -234,7 +237,7 @@ def _tokenize(self, messages: Union[str, LLMMessage, Sequence[LLMMessage]]) -> t total_tokens += len(tokens) all_tokens.extend(tokens) else: - logger.warning("Token count has been done only on string content", RuntimeWarning) + logger.warning("Token count has been done only on string content") return all_tokens, total_tokens def _update_total_usage(self) -> None: diff --git a/python/packages/autogen-ext/src/autogen_ext/ui/_rich_console.py b/python/packages/autogen-ext/src/autogen_ext/ui/_rich_console.py index 58940fe340d2..1951205e8ed5 100644 --- a/python/packages/autogen-ext/src/autogen_ext/ui/_rich_console.py +++ b/python/packages/autogen-ext/src/autogen_ext/ui/_rich_console.py @@ -16,6 +16,7 @@ from autogen_agentchat.messages import ( AgentEvent, ChatMessage, + ModelClientStreamingChunkEvent, MultiModalMessage, UserInputRequestedEvent, ) @@ -185,6 +186,9 @@ async def RichConsole( elif isinstance(message, UserInputRequestedEvent): if user_input_manager is not None: user_input_manager.notify_event_received(message.request_id) + elif isinstance(message, ModelClientStreamingChunkEvent): + # TODO: Handle model client streaming chunk events. 
+ pass else: # Cast required for mypy to be happy message = cast(AgentEvent | ChatMessage, message) # type: ignore diff --git a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py index 5a023b02dff7..ea6d84b2b0b6 100644 --- a/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py +++ b/python/packages/autogen-ext/tests/models/test_chat_completion_cache.py @@ -107,14 +107,14 @@ async def test_cache_create_stream() -> None: async for completion in cached_client.create_stream( [system_prompt, UserMessage(content=prompts[0], source="user")] ): - original_streamed_results.append(completion) + original_streamed_results.append(copy.copy(completion)) total_usage0 = copy.copy(cached_client.total_usage()) cached_completion_results: List[Union[str, CreateResult]] = [] async for completion in cached_client.create_stream( [system_prompt, UserMessage(content=prompts[0], source="user")] ): - cached_completion_results.append(completion) + cached_completion_results.append(copy.copy(completion)) total_usage1 = copy.copy(cached_client.total_usage()) assert total_usage1.prompt_tokens == total_usage0.prompt_tokens diff --git a/python/packages/autogen-ext/tests/models/test_reply_chat_completion_client.py b/python/packages/autogen-ext/tests/models/test_reply_chat_completion_client.py index 7c3fe584b656..b600d2f7d2ac 100644 --- a/python/packages/autogen-ext/tests/models/test_reply_chat_completion_client.py +++ b/python/packages/autogen-ext/tests/models/test_reply_chat_completion_client.py @@ -67,12 +67,16 @@ async def test_replay_chat_completion_client_create_stream() -> None: reply_model_client = ReplayChatCompletionClient(messages) for i in range(num_messages): - result: List[str] = [] + chunks: List[str] = [] + result: CreateResult | None = None async for completion in reply_model_client.create_stream([UserMessage(content="dummy", source="_")]): - text = completion.content if isinstance(completion, CreateResult) else completion - assert isinstance(text, str) - result.append(text) - assert "".join(result) == messages[i] + if isinstance(completion, CreateResult): + result = completion + else: + assert isinstance(completion, str) + chunks.append(completion) + assert result is not None + assert "".join(chunks) == messages[i] == result.content with pytest.raises(ValueError, match="No more mock responses available"): await reply_model_client.create([UserMessage(content="dummy", source="_")])
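
As an end-to-end illustration of the pieces above, here is a minimal offline sketch, not part of the patch, modeled on `test_model_client_stream`: it drives the new streaming path with `ReplayChatCompletionClient` (which now terminates its stream with a `CreateResult`), so no model API key is required. The canned response text and task string are arbitrary.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import TaskResult
from autogen_agentchat.messages import ModelClientStreamingChunkEvent
from autogen_ext.models.replay import ReplayChatCompletionClient


async def main() -> None:
    # The replay client streams the canned response word by word.
    model_client = ReplayChatCompletionClient(["Streaming works end to end."])
    agent = AssistantAgent(
        "assistant",
        model_client=model_client,
        model_client_stream=True,
    )
    chunks: list[str] = []
    final_text = ""
    async for item in agent.run_stream(task="Say something."):
        if isinstance(item, ModelClientStreamingChunkEvent):
            chunks.append(item.content)
        elif isinstance(item, TaskResult):
            # The last message is the agent's final TextMessage.
            final_text = item.messages[-1].content
    # The concatenated chunks match the final message content.
    assert "".join(chunks) == final_text == "Streaming works end to end."


asyncio.run(main())
```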