Skip to content

Commit

Permalink
Console async printing and optional stats (#4956)
Browse files Browse the repository at this point in the history
* async printing

* Make stats output option
  • Loading branch information
jackgerrits authored Jan 9, 2025
1 parent 79b0b6d commit 02ad98b
Showing 1 changed file with 32 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time
from typing import AsyncGenerator, List, Optional, TypeVar, cast

from aioconsole import aprint # type: ignore
from autogen_core import Image
from autogen_core.models import RequestUsage

Expand All @@ -25,6 +26,7 @@ async def Console(
stream: AsyncGenerator[AgentEvent | ChatMessage | T, None],
*,
no_inline_images: bool = False,
output_stats: bool = True,
) -> T:
"""
Consumes the message stream from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
Expand All @@ -35,6 +37,7 @@ async def Console(
stream (AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None] | AsyncGenerator[AgentEvent | ChatMessage | Response, None]): Message stream to render.
This can be from :meth:`~autogen_agentchat.base.TaskRunner.run_stream` or :meth:`~autogen_agentchat.base.ChatAgent.on_messages_stream`.
no_inline_images (bool, optional): If terminal is iTerm2 will render images inline. Use this to disable this behavior. Defaults to False.
output_stats (bool, optional): If True, will output a summary of the messages and inline token usage info. Defaults to True.
Returns:
last_processed: A :class:`~autogen_agentchat.base.TaskResult` if the stream is from :meth:`~autogen_agentchat.base.TaskRunner.run_stream`
Expand All @@ -49,16 +52,16 @@ async def Console(
async for message in stream:
if isinstance(message, TaskResult):
duration = time.time() - start_time
output = (
f"{'-' * 10} Summary {'-' * 10}\n"
f"Number of messages: {len(message.messages)}\n"
f"Finish reason: {message.stop_reason}\n"
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
sys.stdout.flush()
if output_stats:
output = (
f"{'-' * 10} Summary {'-' * 10}\n"
f"Number of messages: {len(message.messages)}\n"
f"Finish reason: {message.stop_reason}\n"
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")
# mypy ignore
last_processed = message # type: ignore

Expand All @@ -68,26 +71,26 @@ async def Console(
# Print final response.
output = f"{'-' * 10} {message.chat_message.source} {'-' * 10}\n{_message_to_str(message.chat_message, render_image_iterm=render_image_iterm)}\n"
if message.chat_message.models_usage:
output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
if output_stats:
output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens
total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens
sys.stdout.write(output)
sys.stdout.flush()
await aprint(output, end="")

# Print summary.
if message.inner_messages is not None:
num_inner_messages = len(message.inner_messages)
else:
num_inner_messages = 0
output = (
f"{'-' * 10} Summary {'-' * 10}\n"
f"Number of inner messages: {num_inner_messages}\n"
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
sys.stdout.write(output)
sys.stdout.flush()
if output_stats:
if message.inner_messages is not None:
num_inner_messages = len(message.inner_messages)
else:
num_inner_messages = 0
output = (
f"{'-' * 10} Summary {'-' * 10}\n"
f"Number of inner messages: {num_inner_messages}\n"
f"Total prompt tokens: {total_usage.prompt_tokens}\n"
f"Total completion tokens: {total_usage.completion_tokens}\n"
f"Duration: {duration:.2f} seconds\n"
)
await aprint(output, end="")
# mypy ignore
last_processed = message # type: ignore

Expand All @@ -96,11 +99,11 @@ async def Console(
message = cast(AgentEvent | ChatMessage, message) # type: ignore
output = f"{'-' * 10} {message.source} {'-' * 10}\n{_message_to_str(message, render_image_iterm=render_image_iterm)}\n"
if message.models_usage:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
if output_stats:
output += f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]\n"
total_usage.completion_tokens += message.models_usage.completion_tokens
total_usage.prompt_tokens += message.models_usage.prompt_tokens
sys.stdout.write(output)
sys.stdout.flush()
await aprint(output, end="")

if last_processed is None:
raise ValueError("No TaskResult or Response was processed.")
Expand Down

0 comments on commit 02ad98b

Please sign in to comment.