From 442df1839774a6f848b379a5926594515dbe9267 Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Wed, 5 Feb 2025 13:02:43 -0500 Subject: [PATCH 01/15] feat: dotnet runtime tests (#5342) --- .../AgentRuntimeTests.cs | 83 +++++++++++++++++++ .../Microsoft.AutoGen.Core.Tests/TestAgent.cs | 30 +++++++ 2 files changed, 113 insertions(+) create mode 100644 dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs new file mode 100644 index 000000000000..812d47c2d207 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentRuntimeTests.cs +using FluentAssertions; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AutoGen.Core.Tests; + +[Trait("Category", "UnitV2")] +public class AgentRuntimeTests() +{ + // Agent will not deliver to self when runtime.DeliverToSelf is false (default) + [Fact] + public async Task RuntimeAgentPublishToSelfDefaultNoSendTest() + { + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync();
+ + // Agent has default messages and could not publish to self + agent.Text.Source.Should().Be("DefaultTopic"); + agent.Text.Content.Should().Be("DefaultContent"); + } + + // Agent delivery to self will succeed when runtime.DeliverToSelf is true + [Fact] + public async Task RuntimeAgentPublishToSelfDeliverToSelfTrueTest() + { + var runtime = new InProcessRuntime(); + runtime.DeliverToSelf = true; + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync(); + + // Agent successfully published to self + agent.Text.Source.Should().Be("TestTopic"); + agent.Text.Content.Should().Be("SelfMessage"); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs index af752ae0e610..b6dadc833be2 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs @@ -72,3 +72,33 @@ public SubscribedAgent(AgentId id, { } } + +/// +/// The test agent showing an agent that subscribes to itself. +/// +[TypeSubscription("TestTopic")] +public class SubscribedSelfPublishAgent(AgentId id, + IAgentRuntime runtime, + Logger?
logger = null) : BaseAgent(id, runtime, "Test Agent", logger), + IHandle, + IHandle +{ + public async ValueTask HandleAsync(string item, MessageContext messageContext) + { + TextMessage strToText = new TextMessage + { + Source = "TestTopic", + Content = item + }; + // This will publish the new message type which will be handled by the TextMessage handler + await this.PublishMessageAsync(strToText, new TopicId("TestTopic")); + } + public ValueTask HandleAsync(TextMessage item, MessageContext messageContext) + { + _text = item; + return ValueTask.CompletedTask; + } + + private TextMessage _text = new TextMessage { Source = "DefaultTopic", Content = "DefaultContent" }; + public TextMessage Text => _text; +} From be3c60baab101dc8677127a66aac83768f350b40 Mon Sep 17 00:00:00 2001 From: gagb Date: Wed, 5 Feb 2025 10:35:52 -0800 Subject: [PATCH 02/15] docs: add blog link to README for updates and resources (#5368) #5080 image --- README.md | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 034421c310f6..3b5a42706e8c 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,14 @@ [![LinkedIn](https://img.shields.io/badge/LinkedIn-Company?style=flat&logo=linkedin&logoColor=white)](https://www.linkedin.com/company/105812540) [![Discord](https://img.shields.io/badge/discord-chat-green?logo=discord)](https://aka.ms/autogen-discord) [![Documentation](https://img.shields.io/badge/Documentation-AutoGen-blue?logo=read-the-docs)](https://microsoft.github.io/autogen/) +[![Blog](https://img.shields.io/badge/Blog-AutoGen-blue?logo=blogger)](https://devblogs.microsoft.com/autogen/) +
+ Important: This is the official project. We are not affiliated with any fork or startup. See our statement. +
+ # AutoGen **AutoGen** is a framework for creating multi-agent AI applications that can act autonomously or work alongside humans. @@ -130,7 +135,7 @@ With AutoGen you get to join and contribute to a thriving ecosystem. We host wee Interested in contributing? See [CONTRIBUTING.md](./CONTRIBUTING.md) for guidelines on how to get started. We welcome contributions of all kinds, including bug fixes, new features, and documentation improvements. Join our community and help us make AutoGen better! -Have questions? Check out our [Frequently Asked Questions (FAQ)](./FAQ.md) for answers to common queries. If you don't find what you're looking for, feel free to ask in our [GitHub Discussions](https://github.com/microsoft/autogen/discussions) or join our [Discord server](https://aka.ms/autogen-discord) for real-time support. +Have questions? Check out our [Frequently Asked Questions (FAQ)](./FAQ.md) for answers to common queries. If you don't find what you're looking for, feel free to ask in our [GitHub Discussions](https://github.com/microsoft/autogen/discussions) or join our [Discord server](https://aka.ms/autogen-discord) for real-time support. You can also read our [blog](https://devblogs.microsoft.com/autogen/) for updates. ## Legal Notices From 172a16a6150de2bc8e903ecb62c87903927533b3 Mon Sep 17 00:00:00 2001 From: Eitan Yarmush Date: Wed, 5 Feb 2025 19:07:27 -0500 Subject: [PATCH 03/15] Memory component base (#5380) ## Why are these changes needed? Currently the way to accomplish RAG behavior with agent chat, specifically assistant agents is with the memory interface, however there is no way to configure it via the declarative API. ## Related issue number ## Checks - [ ] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR. 
- [ ] I've made sure all auto checks have passed. --------- Co-authored-by: Victor Dibia --- .../agents/_assistant_agent.py | 3 ++ .../tests/test_assistant_agent.py | 12 ++++++-- .../src/autogen_core/memory/_base_memory.py | 5 +++- .../src/autogen_core/memory/_list_memory.py | 30 +++++++++++++++++-- .../autogen-core/tests/test_memory.py | 30 ++++++++++++++++++- 5 files changed, 72 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 4d3aecf724b3..b3dda7175ae6 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -62,6 +62,7 @@ class AssistantAgentConfig(BaseModel): tools: List[ComponentModel] | None handoffs: List[HandoffBase | str] | None = None model_context: ComponentModel | None = None + memory: List[ComponentModel] | None = None description: str system_message: str | None = None model_client_stream: bool = False @@ -591,6 +592,7 @@ def _to_config(self) -> AssistantAgentConfig: tools=[tool.dump_component() for tool in self._tools], handoffs=list(self._handoffs.values()), model_context=self._model_context.dump_component(), + memory=[memory.dump_component() for memory in self._memory] if self._memory else None, description=self.description, system_message=self._system_messages[0].content if self._system_messages and isinstance(self._system_messages[0].content, str) @@ -609,6 +611,7 @@ def _from_config(cls, config: AssistantAgentConfig) -> Self: tools=[BaseTool.load_component(tool) for tool in config.tools] if config.tools else None, handoffs=config.handoffs, model_context=None, + memory=[Memory.load_component(memory) for memory in config.memory] if config.memory else None, description=config.description, system_message=config.system_message, 
model_client_stream=config.model_client_stream, diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 3db9924368fd..040f3e23e6ca 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -18,7 +18,7 @@ ToolCallRequestEvent, ToolCallSummaryMessage, ) -from autogen_core import FunctionCall, Image +from autogen_core import ComponentModel, FunctionCall, Image from autogen_core.memory import ListMemory, Memory, MemoryContent, MemoryMimeType, MemoryQueryResult from autogen_core.model_context import BufferedChatCompletionContext from autogen_core.models import ( @@ -754,7 +754,12 @@ async def test_run_with_memory(monkeypatch: pytest.MonkeyPatch) -> None: "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), memory=[memory2] ) - result = await agent.run(task="test task") + # Test dump and load component with memory + agent_config: ComponentModel = agent.dump_component() + assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent" + agent2 = AssistantAgent.load_component(agent_config) + + result = await agent2.run(task="test task") assert len(result.messages) > 0 memory_event = next((msg for msg in result.messages if isinstance(msg, MemoryQueryEvent)), None) assert memory_event is not None @@ -795,9 +800,10 @@ async def test_assistant_agent_declarative(monkeypatch: pytest.MonkeyPatch) -> N "test_agent", model_client=OpenAIChatCompletionClient(model=model, api_key=""), model_context=model_context, + memory=[ListMemory(name="test_memory")], ) - agent_config = agent.dump_component() + agent_config: ComponentModel = agent.dump_component() assert agent_config.provider == "autogen_agentchat.agents.AssistantAgent" agent2 = AssistantAgent.load_component(agent_config) diff --git a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py 
b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py index 2ae79b4106bb..577c23fda826 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_base_memory.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, ConfigDict from .._cancellation_token import CancellationToken +from .._component_config import ComponentBase from .._image import Image from ..model_context import ChatCompletionContext @@ -49,7 +50,7 @@ class UpdateContextResult(BaseModel): memories: MemoryQueryResult -class Memory(ABC): +class Memory(ABC, ComponentBase[BaseModel]): """Protocol defining the interface for memory implementations. A memory is the storage for data that can be used to enrich or modify the model context. @@ -64,6 +65,8 @@ class Memory(ABC): See :class:`~autogen_core.memory.ListMemory` for an example implementation. """ + component_type = "memory" + @abstractmethod async def update_context( self, diff --git a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py index f28bab048298..5ad086550081 100644 --- a/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py +++ b/python/packages/autogen-core/src/autogen_core/memory/_list_memory.py @@ -1,12 +1,25 @@ from typing import Any, List +from pydantic import BaseModel +from typing_extensions import Self + from .._cancellation_token import CancellationToken +from .._component_config import Component from ..model_context import ChatCompletionContext from ..models import SystemMessage from ._base_memory import Memory, MemoryContent, MemoryQueryResult, UpdateContextResult -class ListMemory(Memory): +class ListMemoryConfig(BaseModel): + """Configuration for ListMemory component.""" + + name: str | None = None + """Optional identifier for this memory instance.""" + memory_contents: List[MemoryContent] = [] + """List of memory contents stored in 
this memory instance.""" + + +class ListMemory(Memory, Component[ListMemoryConfig]): """Simple chronological list-based memory implementation. This memory implementation stores contents in a list and retrieves them in @@ -53,9 +66,13 @@ async def main() -> None: """ - def __init__(self, name: str | None = None) -> None: + component_type = "memory" + component_provider_override = "autogen_core.memory.ListMemory" + component_config_schema = ListMemoryConfig + + def __init__(self, name: str | None = None, memory_contents: List[MemoryContent] | None = None) -> None: self._name = name or "default_list_memory" - self._contents: List[MemoryContent] = [] + self._contents: List[MemoryContent] = memory_contents if memory_contents is not None else [] @property def name(self) -> str: @@ -146,3 +163,10 @@ async def clear(self) -> None: async def close(self) -> None: """Cleanup resources if needed.""" pass + + @classmethod + def _from_config(cls, config: ListMemoryConfig) -> Self: + return cls(name=config.name, memory_contents=config.memory_contents) + + def _to_config(self) -> ListMemoryConfig: + return ListMemoryConfig(name=self.name, memory_contents=self._contents) diff --git a/python/packages/autogen-core/tests/test_memory.py b/python/packages/autogen-core/tests/test_memory.py index 04054e1b2250..ce98aaffb97f 100644 --- a/python/packages/autogen-core/tests/test_memory.py +++ b/python/packages/autogen-core/tests/test_memory.py @@ -1,7 +1,7 @@ from typing import Any import pytest -from autogen_core import CancellationToken +from autogen_core import CancellationToken, ComponentModel from autogen_core.memory import ( ListMemory, Memory, @@ -23,6 +23,34 @@ def test_memory_protocol_attributes() -> None: assert hasattr(Memory, "close") +def test_memory_component_load_config_from_base_model() -> None: + """Test that Memory component can be loaded from a BaseModel.""" + config = ComponentModel( + provider="autogen_core.memory.ListMemory", + config={ + "name": "test_memory", + 
"memory_contents": [MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)], + }, + ) + memory = Memory.load_component(config) + assert isinstance(memory, ListMemory) + assert memory.name == "test_memory" + assert len(memory.content) == 1 + + +def test_memory_component_dump_config_to_base_model() -> None: + """Test that Memory component can be dumped to a BaseModel.""" + memory = ListMemory( + name="test_memory", memory_contents=[MemoryContent(content="test", mime_type=MemoryMimeType.TEXT)] + ) + config = memory.dump_component() + assert isinstance(config, ComponentModel) + assert config.provider == "autogen_core.memory.ListMemory" + assert config.component_type == "memory" + assert config.config["name"] == "test_memory" + assert len(config.config["memory_contents"]) == 1 + + def test_memory_abc_implementation() -> None: """Test that Memory ABC is properly implemented.""" From 7947464e4ae2fa58da4d68e39f283bd7247283c4 Mon Sep 17 00:00:00 2001 From: Wei Jen Lu Date: Thu, 6 Feb 2025 00:52:32 +0000 Subject: [PATCH 04/15] Fixed example code in doc:Custom Agents (#5381) The Tuple class is never used in CountDownAgent class. ## Why are these changes needed? ## Related issue number ## Checks - [x] I've included any doc changes needed for https://microsoft.github.io/autogen/. See https://microsoft.github.io/autogen/docs/Contribute#documentation to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. 
Co-authored-by: Victor Dibia --- .../agentchat-user-guide/tutorial/custom-agents.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb index 6e2dfa90b494..5b8c4e7f24fa 100644 --- a/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/agentchat-user-guide/tutorial/custom-agents.ipynb @@ -49,7 +49,7 @@ } ], "source": [ - "from typing import AsyncGenerator, List, Sequence, Tuple\n", + "from typing import AsyncGenerator, List, Sequence\n", "\n", "from autogen_agentchat.agents import BaseChatAgent\n", "from autogen_agentchat.base import Response\n", @@ -310,4 +310,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} From d86540e9cd8b3f6270a5f7c83bbf2a67529426a9 Mon Sep 17 00:00:00 2001 From: afourney Date: Wed, 5 Feb 2025 16:57:46 -0800 Subject: [PATCH 05/15] Fix summarize_page in a text-only context, and for unknown models. (#5388) WebSurfer's summarize_page was failing when the model was text-only, or unknown. 
--- .../web_surfer/_multimodal_web_surfer.py | 49 +++++++++++++------ 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py index c83bf98d2309..b848ae8f59ba 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py @@ -65,6 +65,8 @@ from ._types import InteractiveRegion, UserContent from .playwright_controller import PlaywrightController +DEFAULT_CONTEXT_SIZE = 128000 + class MultimodalWebSurferConfig(BaseModel): name: str @@ -855,35 +857,50 @@ async def _summarize_page( buffer = "" # for line in re.split(r"([\r\n]+)", page_markdown): for line in page_markdown.splitlines(): - message = UserMessage( - # content=[ + trial_message = UserMessage( content=prompt + buffer + line, - # ag_image, - # ], source=self.name, ) - remaining = self._model_client.remaining_tokens(messages + [message]) - if remaining > self.SCREENSHOT_TOKENS: - buffer += line - else: + try: + remaining = self._model_client.remaining_tokens(messages + [trial_message]) + except KeyError: + # Use the default if the model isn't found + remaining = DEFAULT_CONTEXT_SIZE - self._model_client.count_tokens(messages + [trial_message]) + + if self._model_client.model_info["vision"] and remaining <= 0: break + if self._model_client.model_info["vision"] and remaining <= self.SCREENSHOT_TOKENS: + break + + buffer += line + # Nothing to do buffer = buffer.strip() if len(buffer) == 0: return "Nothing to summarize." 
# Append the message - messages.append( - UserMessage( - content=[ - prompt + buffer, - ag_image, - ], - source=self.name, + if self._model_client.model_info["vision"]: + # Multimodal + messages.append( + UserMessage( + content=[ + prompt + buffer, + ag_image, + ], + source=self.name, + ) + ) + else: + # Text only + messages.append( + UserMessage( + content=prompt + buffer, + source=self.name, + ) ) - ) # Generate the response response = await self._model_client.create(messages, cancellation_token=cancellation_token) From ac74305913e4b06f44a85e0c9097b2f4dc131312 Mon Sep 17 00:00:00 2001 From: afourney Date: Wed, 5 Feb 2025 20:17:24 -0800 Subject: [PATCH 06/15] Ensure descriptions appear each on one line. Fix web_surfer's desc (#5390) Some agent descriptions were split over multiple lines in the M1 orchestrator. This PR ensures that each description appears on one, and only one, line. This makes it easier for smaller models to understand. --- .../_magentic_one/_magentic_one_orchestrator.py | 15 +++++++-------- .../agents/web_surfer/_multimodal_web_surfer.py | 4 ++-- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py index b328ba308bd9..f0f927ac5a1c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_magentic_one/_magentic_one_orchestrator.py @@ -1,5 +1,6 @@ import json import logging +import re from typing import Any, Dict, List, Mapping from autogen_core import AgentId, CancellationToken, DefaultTopicId, MessageContext, event, rpc @@ -80,14 +81,12 @@ def __init__( self._plan = "" self._n_rounds = 0 self._n_stalls = 0 - self._team_description = "\n".join( - [ -
f"{topic_type}: {description}".strip() - for topic_type, description in zip( - self._participant_topic_types, self._participant_descriptions, strict=True - ) - ] - ) + + # Produce a team description. Each agent should appear on a single line. + self._team_description = "" + for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True): + self._team_description += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + self._team_description = self._team_description.strip() def _get_task_ledger_facts_prompt(self, task: str) -> str: return ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT.format(task=task) diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py index b848ae8f59ba..90ddbdb2b9c0 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py @@ -176,9 +176,9 @@ async def main() -> None: DEFAULT_DESCRIPTION = """ A helpful assistant with access to a web browser. - Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, etc., filling in form fields, etc.). + Ask them to perform web searches, open pages, and interact with content (e.g., clicking links, scrolling the viewport, filling in form fields, etc.). It can also summarize the entire page, or answer questions based on the content of the page. - It can also be asked to sleep and wait for pages to load, in cases where the pages seem to be taking a while to load. + It can also be asked to sleep and wait for pages to load, in cases where the page seems not yet fully loaded.
""" DEFAULT_START_PAGE = "https://www.bing.com/" From cf798aef3f198a56e9ce7e6e2ba0f6abd87c42f4 Mon Sep 17 00:00:00 2001 From: afourney Date: Wed, 5 Feb 2025 22:17:18 -0800 Subject: [PATCH 07/15] Various web surfer fixes. (#5393) This PR fixes: A prompting bug when no control had focus. Awkward prompt phrasing. Renamed page_down to scroll_down to better match other prompting and agent descriptions. --- .../agents/web_surfer/_multimodal_web_surfer.py | 14 ++++++++------ .../src/autogen_ext/agents/web_surfer/_prompts.py | 4 ++-- .../agents/web_surfer/_tool_definitions.py | 8 ++++---- .../agents/web_surfer/playwright_controller.py | 6 +++--- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py index 90ddbdb2b9c0..73c0799fff38 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py @@ -53,9 +53,9 @@ TOOL_CLICK, TOOL_HISTORY_BACK, TOOL_HOVER, - TOOL_PAGE_DOWN, - TOOL_PAGE_UP, TOOL_READ_PAGE_AND_ANSWER, + TOOL_SCROLL_DOWN, + TOOL_SCROLL_UP, TOOL_SLEEP, TOOL_SUMMARIZE_PAGE, TOOL_TYPE, @@ -466,11 +466,11 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo # We can scroll up if viewport["pageTop"] > 5: - tools.append(TOOL_PAGE_UP) + tools.append(TOOL_SCROLL_UP) # Can scroll down if (viewport["pageTop"] + viewport["height"] + 5) < viewport["scrollHeight"]: - tools.append(TOOL_PAGE_DOWN) + tools.append(TOOL_SCROLL_DOWN) # Focus hint focused = await self._playwright_controller.get_focused_rect_id(self._page) @@ -479,6 +479,8 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo name = self._target_name(focused, rects) if name: name = f"(and name '{name}') " + else: + name = "" role = "control" 
try: @@ -613,10 +615,10 @@ async def _execute_tool( self._last_download = None if reset_prior_metadata and self._prior_metadata_hash is not None: self._prior_metadata_hash = None - elif name == "page_up": + elif name == "scroll_up": action_description = "I scrolled up one page in the browser." await self._playwright_controller.page_up(self._page) - elif name == "page_down": + elif name == "scroll_down": action_description = "I scrolled down one page in the browser." await self._playwright_controller.page_down(self._page) diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py index 050b267e0114..59a0a7c95d5e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py @@ -8,7 +8,7 @@ {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text might be most appropriate, or hovering over element) + - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate) - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - on some other website entirely (in which case actions like performing a new web search might be the best option) """ @@ -29,7 +29,7 @@ {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text might be most appropriate, or hovering over element) + - the contents of the current viewport (in which case actions like clicking links, clicking buttons, 
inputting text, or hovering over an element might be most appropriate) - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - on some other website entirely (in which case actions like performing a new web search might be the best option) """ diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_tool_definitions.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_tool_definitions.py index fd2928248596..c80f6de31589 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_tool_definitions.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_tool_definitions.py @@ -87,11 +87,11 @@ def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema: } ) -TOOL_PAGE_UP: ToolSchema = _load_tool( +TOOL_SCROLL_UP: ToolSchema = _load_tool( { "type": "function", "function": { - "name": "page_up", + "name": "scroll_up", "description": "Scrolls the entire browser viewport one page UP towards the beginning.", "parameters": { "type": "object", @@ -107,11 +107,11 @@ def _load_tool(tooldef: Dict[str, Any]) -> ToolSchema: } ) -TOOL_PAGE_DOWN: ToolSchema = _load_tool( +TOOL_SCROLL_DOWN: ToolSchema = _load_tool( { "type": "function", "function": { - "name": "page_down", + "name": "scroll_down", "description": "Scrolls the entire browser viewport one page DOWN towards the end.", "parameters": { "type": "object", diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/playwright_controller.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/playwright_controller.py index 412bc07dba74..b99fbc49afad 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/playwright_controller.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/playwright_controller.py @@ -122,7 +122,7 @@ async def get_visual_viewport(self, page: Page) -> VisualViewport: pass return 
visualviewport_from_dict(await page.evaluate("MultimodalWebSurfer.getVisualViewport();")) - async def get_focused_rect_id(self, page: Page) -> str: + async def get_focused_rect_id(self, page: Page) -> str | None: """ Retrieve the ID of the currently focused element. @@ -130,7 +130,7 @@ async def get_focused_rect_id(self, page: Page) -> str: page (Page): The Playwright page object. Returns: - str: The ID of the focused element. + str: The ID of the focused element or None if no control has focus. """ assert page is not None try: @@ -138,7 +138,7 @@ async def get_focused_rect_id(self, page: Page) -> str: except Exception: pass result = await page.evaluate("MultimodalWebSurfer.getFocusedElementId();") - return str(result) + return None if result is None else str(result) async def get_page_metadata(self, page: Page) -> Dict[str, Any]: """ From ca428914f59542e40421d00beb07ddc77ad429c9 Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 6 Feb 2025 13:53:24 -0500 Subject: [PATCH 08/15] Refactor grpc channel connection in servicer (#5402) --- .../grpc/_worker_runtime_host_servicer.py | 154 +++++++++++------- 1 file changed, 98 insertions(+), 56 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 74c1ef3b385b..84df493bd1e8 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import asyncio import logging -from _collections_abc import AsyncIterator +from abc import ABC, abstractmethod from asyncio import Future, Task -from typing import Any, Dict, Sequence, Set, Tuple +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar from autogen_core import TopicId from 
autogen_core._runtime_impl_helpers import SubscriptionManager @@ -38,11 +40,66 @@ async def get_client_id_or_abort(context: grpc.aio.ServicerContext[Any, Any]) -> return client_id # type: ignore +SendT = TypeVar("SendT") +ReceiveT = TypeVar("ReceiveT") + + +class ChannelConnection(ABC, Generic[SendT, ReceiveT]): + def __init__(self, request_iterator: AsyncIterator[ReceiveT], client_id: str) -> None: + self._request_iterator = request_iterator + self._client_id = client_id + self._send_queue: asyncio.Queue[SendT] = asyncio.Queue() + self._receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator)) + + async def _receive_messages(self, client_id: ClientConnectionId, request_iterator: AsyncIterator[ReceiveT]) -> None: + # Receive messages from the client and process them. + async for message in request_iterator: + logger.info(f"Received message from client {client_id}: {message}") + await self._handle_message(message) + + def __aiter__(self) -> AsyncIterator[SendT]: + return self + + async def __anext__(self) -> SendT: + try: + return await self._send_queue.get() + except StopAsyncIteration: + await self._receiving_task + raise + except Exception as e: + logger.error(f"Failed to get message from send queue: {e}", exc_info=True) + await self._receiving_task + raise + + @abstractmethod + async def _handle_message(self, message: ReceiveT) -> None: + pass + + async def send(self, message: SendT) -> None: + await self._send_queue.put(message) + + +class CallbackChannelConnection(ChannelConnection[SendT, ReceiveT]): + def __init__( + self, + request_iterator: AsyncIterator[ReceiveT], + client_id: str, + handle_callback: Callable[[ReceiveT], Awaitable[None]], + ) -> None: + self._handle_callback = handle_callback + super().__init__(request_iterator, client_id) + + async def _handle_message(self, message: ReceiveT) -> None: + await self._handle_callback(message) + + class GrpcWorkerAgentRuntimeHostServicer(agent_worker_pb2_grpc.AgentRpcServicer): 
"""A gRPC servicer that hosts message delivery service for agents.""" def __init__(self) -> None: - self._send_queues: Dict[ClientConnectionId, asyncio.Queue[agent_worker_pb2.Message]] = {} + self._data_connections: Dict[ + ClientConnectionId, ChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message] + ] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {} self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {} @@ -57,32 +114,21 @@ async def OpenChannel( # type: ignore ) -> AsyncIterator[agent_worker_pb2.Message]: client_id = await get_client_id_or_abort(context) - # Register the client with the server and create a send queue for the client. - send_queue: asyncio.Queue[agent_worker_pb2.Message] = asyncio.Queue() - self._send_queues[client_id] = send_queue + async def handle_callback(message: agent_worker_pb2.Message) -> None: + await self._receive_message(client_id, message) + + connection = CallbackChannelConnection[agent_worker_pb2.Message, agent_worker_pb2.Message]( + request_iterator, client_id, handle_callback=handle_callback + ) + self._data_connections[client_id] = connection logger.info(f"Client {client_id} connected.") try: - # Concurrently handle receiving messages from the client and sending messages to the client. - # This task will receive messages from the client. - receiving_task = asyncio.create_task(self._receive_messages(client_id, request_iterator)) - - # Return an async generator that will yield messages from the send queue to the client. - while True: - message = await send_queue.get() - # Yield the message to the client. - try: - yield message - except Exception as e: - logger.error(f"Failed to send message to client {client_id}: {e}", exc_info=True) - break - logger.info(f"Sent message to client {client_id}: {message}") - # Wait for the receiving task to finish. 
- await receiving_task - + async for message in connection: + yield message finally: # Clean up the client connection. - del self._send_queues[client_id] + del self._data_connections[client_id] # Cancel pending requests sent to this client. for future in self._pending_responses.pop(client_id, {}).values(): future.cancel() @@ -105,33 +151,29 @@ def _raise_on_exception(self, task: Task[Any]) -> None: if exception is not None: raise exception - async def _receive_messages( - self, client_id: ClientConnectionId, request_iterator: AsyncIterator[agent_worker_pb2.Message] - ) -> None: - # Receive messages from the client and process them. - async for message in request_iterator: - logger.info(f"Received message from client {client_id}: {message}") - oneofcase = message.WhichOneof("message") - match oneofcase: - case "request": - request: agent_worker_pb2.RpcRequest = message.request - task = asyncio.create_task(self._process_request(request, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "response": - response: agent_worker_pb2.RpcResponse = message.response - task = asyncio.create_task(self._process_response(response, client_id)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case "cloudEvent": - task = asyncio.create_task(self._process_event(message.cloudEvent)) - self._background_tasks.add(task) - task.add_done_callback(self._raise_on_exception) - task.add_done_callback(self._background_tasks.discard) - case None: - logger.warning("Received empty message") + async def _receive_message(self, client_id: ClientConnectionId, message: agent_worker_pb2.Message) -> None: + logger.info(f"Received message from client {client_id}: {message}") + oneofcase = message.WhichOneof("message") + match oneofcase: + case "request": + request: agent_worker_pb2.RpcRequest = 
message.request + task = asyncio.create_task(self._process_request(request, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "response": + response: agent_worker_pb2.RpcResponse = message.response + task = asyncio.create_task(self._process_response(response, client_id)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case "cloudEvent": + task = asyncio.create_task(self._process_event(message.cloudEvent)) + self._background_tasks.add(task) + task.add_done_callback(self._raise_on_exception) + task.add_done_callback(self._background_tasks.discard) + case None: + logger.warning("Received empty message") async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id: ClientConnectionId) -> None: # Deliver the message to a client given the target agent type. @@ -140,11 +182,11 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id if target_client_id is None: logger.error(f"Agent {request.target.type} not found, failed to deliver message.") return - target_send_queue = self._send_queues.get(target_client_id) + target_send_queue = self._data_connections.get(target_client_id) if target_send_queue is None: logger.error(f"Client {target_client_id} not found, failed to deliver message.") return - await target_send_queue.put(agent_worker_pb2.Message(request=request)) + await target_send_queue.send(agent_worker_pb2.Message(request=request)) # Create a future to wait for the response from the target. 
future = asyncio.get_event_loop().create_future() @@ -161,11 +203,11 @@ async def _wait_and_send_response( ) -> None: response = await future message = agent_worker_pb2.Message(response=response) - send_queue = self._send_queues.get(client_id) + send_queue = self._data_connections.get(client_id) if send_queue is None: logger.error(f"Client {client_id} not found, failed to send response message.") return - await send_queue.put(message) + await send_queue.send(message) async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None: # Setting the result of the future will send the response back to the original sender. @@ -186,7 +228,7 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: logger.error(f"Agent {recipient.type} and its client not found for topic {topic_id}.") # Deliver the event to clients. for client_id in client_ids: - await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event)) + await self._data_connections[client_id].send(agent_worker_pb2.Message(cloudEvent=event)) async def RegisterAgent( # type: ignore self, From da6f91870813354217c56b040f2e664924da973d Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 6 Feb 2025 14:30:15 -0500 Subject: [PATCH 09/15] feat: add dotnet code coverage (#5403) --- .github/workflows/dotnet-build.yml | 95 ++++++++----------- .../Microsoft.AutoGen.Core.Grpc.Tests.csproj | 4 + .../Microsoft.AutoGen.Core.Tests.csproj | 4 + .../GrpcGatewayServiceTests.cs | 2 +- ...icrosoft.AutoGen.Runtime.Grpc.Tests.csproj | 4 + 5 files changed, 51 insertions(+), 58 deletions(-) diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml index bf7570239d61..8539502b9e33 100644 --- a/.github/workflows/dotnet-build.yml +++ b/.github/workflows/dotnet-build.yml @@ -46,8 +46,9 @@ jobs: - name: workflows has changes run: echo "workflows has changes" if: steps.filter.outputs.workflows == 'true' + build: - name: Dotnet 
Build + name: Dotnet Build & Test needs: paths-filter if: needs.paths-filter.outputs.hasChanges == 'true' defaults: @@ -64,35 +65,12 @@ jobs: - uses: actions/checkout@v4 with: lfs: true - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install jupyter and ipykernel - run: | - python -m pip install --upgrade pip - python -m pip install jupyter - python -m pip install ipykernel - - name: list available kernels - run: | - python -m jupyter kernelspec list - - uses: astral-sh/setup-uv@v5 - with: - enable-cache: true - version: "0.5.18" - - run: uv sync --locked --all-extras - working-directory: ./python - - name: Prepare python venv - run: | - source ${{ github.workspace }}/python/.venv/bin/activate - name: Setup .NET 8.0 uses: actions/setup-dotnet@v4 with: dotnet-version: '8.0.x' - name: Restore dependencies - run: | - # dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config - dotnet restore -bl + run: dotnet restore -bl - name: Format check run: | echo "Format check" @@ -104,40 +82,43 @@ jobs: dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true - name: Unit Test V1 run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV1" - - name: Unit Test V2 - run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV2" - - grpc-unit-tests: - name: Dotnet Grpc unit tests - needs: paths-filter - if: needs.paths-filter.outputs.hasChanges == 'true' - defaults: - run: - working-directory: dotnet - strategy: - fail-fast: false - matrix: - os: [ ubuntu-latest ] - runs-on: ${{ matrix.os }} - timeout-minutes: 30 - steps: - - uses: actions/checkout@v4 + - name: Unit Test V2 (With Coverage) + run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV2" --collect:"XPlat Code Coverage" + - name: Install Dev 
Certs for GRPC + if: matrix.os == 'ubuntu-latest' + run: dotnet dev-certs https --trust + - name: GRPC Tests (With Coverage) + if: matrix.os == 'ubuntu-latest' + run: dotnet test --no-build -bl --configuration Release --filter "Category=GRPC" --collect:"XPlat Code Coverage" + - name: Generate & Merge Coverage Report + if: matrix.os == 'ubuntu-latest' + run: | + # Install reportgenerator + dotnet tool install -g dotnet-reportgenerator-globaltool || dotnet tool update -g dotnet-reportgenerator-globaltool + # Ensure output directory exists + mkdir -p ${{ github.workspace }}/dotnet/coverage-report + # Merge all coverage reports and generate HTML + XML + reportgenerator \ + -reports:${{ github.workspace }}/dotnet/**/TestResults/**/coverage.cobertura.xml \ + -targetdir:${{ github.workspace }}/dotnet/coverage-report \ + -reporttypes:"Cobertura;Html" + ls -R ${{ github.workspace }}/dotnet/coverage-report + - name: Upload Merged Coverage Report + if: matrix.os == 'ubuntu-latest' + uses: actions/upload-artifact@v4 with: - lfs: true - - name: Setup .NET 8.0 - uses: actions/setup-dotnet@v4 + name: CodeCoverageReport + path: ${{ github.workspace }}/dotnet/coverage-report/ + retention-days: 7 + - name: Upload Coverage to Codecov + if: matrix.os == 'ubuntu-latest' + uses: codecov/codecov-action@v5 with: - dotnet-version: '8.0.x' - - name: Install dev certs - run: dotnet --version && dotnet dev-certs https --trust - - name: Restore dependencies - run: | - # dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config - dotnet restore -bl - - name: Build - run: dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true - - name: GRPC tests - run: dotnet test --no-build -bl --configuration Release --filter "Category=GRPC" + files: ${{ github.workspace }}/dotnet/coverage-report/*.xml + flags: unittests + name: dotnet-codecov + fail_ci_if_error: true + token: ${{ 
secrets.CODECOV_TOKEN }} integration-test: strategy: diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj index e3573c93451a..4f67d9727829 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj @@ -8,6 +8,10 @@ + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/Microsoft.AutoGen.Core.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Core.Tests/Microsoft.AutoGen.Core.Tests.csproj index 29165739b635..eed47c6238b0 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/Microsoft.AutoGen.Core.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/Microsoft.AutoGen.Core.Tests.csproj @@ -8,6 +8,10 @@ + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs index 90b6b2dddcf2..fcae1ec3dcdb 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/GrpcGatewayServiceTests.cs @@ -12,7 +12,7 @@ namespace Microsoft.AutoGen.Runtime.Grpc.Tests; [Collection(ClusterCollection.Name)] -[Trait("Category", "UnitV2")] +[Trait("Category", "GRPC")] public class GrpcGatewayServiceTests { private readonly ClusterFixture _fixture; diff --git a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj index ab0899c0f169..c8b00ee268b0 100644 --- a/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj +++ 
b/dotnet/test/Microsoft.AutoGen.Runtime.Grpc.Tests/Microsoft.AutoGen.Runtime.Grpc.Tests.csproj @@ -8,6 +8,10 @@ + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + From 25f26a338bc70bcb406406433c7bda554588484e Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Thu, 6 Feb 2025 16:54:21 -0500 Subject: [PATCH 10/15] Updates to proto for state apis (#5407) --- .../GrpcAgentServiceFixture.cs | 2 - protos/agent_worker.proto | 57 ++++--- .../grpc/_worker_runtime_host_servicer.py | 21 +-- .../runtimes/grpc/protos/agent_worker_pb2.py | 26 ++-- .../runtimes/grpc/protos/agent_worker_pb2.pyi | 144 +++++++++++++----- .../grpc/protos/agent_worker_pb2_grpc.py | 73 ++------- .../grpc/protos/agent_worker_pb2_grpc.pyi | 35 ++--- 7 files changed, 186 insertions(+), 172 deletions(-) diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs index 98c47764269d..1ca37809a57e 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs @@ -25,8 +25,6 @@ public override async Task OpenChannel(IAsyncStreamReader requestStream throw; } } - public override async Task GetState(AgentId request, ServerCallContext context) => new GetStateResponse { AgentState = new AgentState { AgentId = request } }; - public override async Task SaveState(AgentState request, ServerCallContext context) => new SaveStateResponse { }; public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) => new AddSubscriptionResponse { }; public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) => new RemoveSubscriptionResponse { }; public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) => new GetSubscriptionsResponse { }; diff --git a/protos/agent_worker.proto 
b/protos/agent_worker.proto index 07375fe97250..52fe809a20c9 100644 --- a/protos/agent_worker.proto +++ b/protos/agent_worker.proto @@ -79,23 +79,6 @@ message GetSubscriptionsResponse { repeated Subscription subscriptions = 1; } -message AgentState { - AgentId agent_id = 1; - string eTag = 2; - oneof data { - bytes binary_data = 3; - string text_data = 4; - google.protobuf.Any proto_data = 5; - } -} - -message GetStateResponse { - AgentState agent_state = 1; -} - -message SaveStateResponse { -} - message Message { oneof message { RpcRequest request = 1; @@ -104,10 +87,46 @@ message Message { } } +message SaveStateRequest { + AgentId agentId = 1; +} + +message SaveStateResponse { + string state = 1; + optional string error = 2; +} + +message LoadStateRequest { + AgentId agentId = 1; + string state = 2; +} +message LoadStateResponse { + optional string error = 1; +} + +message ControlMessage { + // A response message should have the same id as the request message + string rpc_id = 1; + // This is either: + // agentid=AGENT_ID + // clientid=CLIENT_ID + string destination = 2; + // This is either: + // agentid=AGENT_ID + // clientid=CLIENT_ID + // Empty string means the message is a response + optional string respond_to = 3; + // One of: + // SaveStateRequest saveStateRequest = 2; + // SaveStateResponse saveStateResponse = 3; + // LoadStateRequest loadStateRequest = 4; + // LoadStateResponse loadStateResponse = 5; + google.protobuf.Any rpcMessage = 4; +} + service AgentRpc { rpc OpenChannel (stream Message) returns (stream Message); - rpc GetState(AgentId) returns (GetStateResponse); - rpc SaveState(AgentState) returns (SaveStateResponse); + rpc OpenControlChannel (stream ControlMessage) returns (stream ControlMessage); rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse); rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse); rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse); diff 
--git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 84df493bd1e8..daa4ad65101d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -135,6 +135,13 @@ async def handle_callback(message: agent_worker_pb2.Message) -> None: # Remove the client id from the agent type to client id mapping. await self._on_client_disconnect(client_id) + async def OpenControlChannel( # type: ignore + self, + request_iterator: AsyncIterator[agent_worker_pb2.ControlMessage], + context: grpc.aio.ServicerContext[agent_worker_pb2.ControlMessage, agent_worker_pb2.ControlMessage], + ) -> AsyncIterator[agent_worker_pb2.ControlMessage]: + raise NotImplementedError("Method not implemented.") + async def _on_client_disconnect(self, client_id: ClientConnectionId) -> None: async with self._agent_type_to_client_id_lock: agent_types = [agent_type for agent_type, id_ in self._agent_type_to_client_id.items() if id_ == client_id] @@ -288,17 +295,3 @@ async def GetSubscriptions( # type: ignore ) -> agent_worker_pb2.GetSubscriptionsResponse: _client_id = await get_client_id_or_abort(context) raise NotImplementedError("Method not implemented.") - - async def GetState( # type: ignore - self, - request: agent_worker_pb2.AgentId, - context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.GetStateResponse], - ) -> agent_worker_pb2.GetStateResponse: - raise NotImplementedError("Method not implemented!") - - async def SaveState( # type: ignore - self, - request: agent_worker_pb2.AgentState, - context: grpc.aio.ServicerContext[agent_worker_pb2.AgentId, agent_worker_pb2.SaveStateResponse], - ) -> agent_worker_pb2.SaveStateResponse: - raise NotImplementedError("Method not implemented!") diff --git 
a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py index b3e0af61f049..54209d2fb284 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.py @@ -26,7 +26,7 @@ from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 
\x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x9d\x01\n\nAgentState\x12!\n\x08\x61gent_id\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0c\n\x04\x65Tag\x18\x02 \x01(\t\x12\x15\n\x0b\x62inary_data\x18\x03 \x01(\x0cH\x00\x12\x13\n\ttext_data\x18\x04 \x01(\tH\x00\x12*\n\nproto_data\x18\x05 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00\x42\x06\n\x04\x64\x61ta\";\n\x10GetStateResponse\x12\'\n\x0b\x61gent_state\x18\x01 \x01(\x0b\x32\x12.agents.AgentState\"\x13\n\x11SaveStateResponse\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message2\x90\x04\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12\x35\n\x08GetState\x12\x0f.agents.AgentId\x1a\x18.agents.GetStateResponse\x12:\n\tSaveState\x12\x12.agents.AgentState\x1a\x19.agents.SaveStateResponse\x12T\n\rRegisterAgent\x12 
.agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x12\x61gent_worker.proto\x12\x06\x61gents\x1a\x10\x63loudevent.proto\x1a\x19google/protobuf/any.proto\"$\n\x07\x41gentId\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\x0b\n\x03key\x18\x02 \x01(\t\"E\n\x07Payload\x12\x11\n\tdata_type\x18\x01 \x01(\t\x12\x19\n\x11\x64\x61ta_content_type\x18\x02 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\"\x89\x02\n\nRpcRequest\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12$\n\x06source\x18\x02 \x01(\x0b\x32\x0f.agents.AgentIdH\x00\x88\x01\x01\x12\x1f\n\x06target\x18\x03 \x01(\x0b\x32\x0f.agents.AgentId\x12\x0e\n\x06method\x18\x04 \x01(\t\x12 \n\x07payload\x18\x05 \x01(\x0b\x32\x0f.agents.Payload\x12\x32\n\x08metadata\x18\x06 \x03(\x0b\x32 .agents.RpcRequest.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\t\n\x07_source\"\xb8\x01\n\x0bRpcResponse\x12\x12\n\nrequest_id\x18\x01 \x01(\t\x12 \n\x07payload\x18\x02 \x01(\x0b\x32\x0f.agents.Payload\x12\r\n\x05\x65rror\x18\x03 \x01(\t\x12\x33\n\x08metadata\x18\x04 \x03(\x0b\x32!.agents.RpcResponse.MetadataEntry\x1a/\n\rMetadataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"(\n\x18RegisterAgentTypeRequest\x12\x0c\n\x04type\x18\x01 \x01(\t\"\x1b\n\x19RegisterAgentTypeResponse\":\n\x10TypeSubscription\x12\x12\n\ntopic_type\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 \x01(\t\"G\n\x16TypePrefixSubscription\x12\x19\n\x11topic_type_prefix\x18\x01 \x01(\t\x12\x12\n\nagent_type\x18\x02 
\x01(\t\"\xa2\x01\n\x0cSubscription\x12\n\n\x02id\x18\x01 \x01(\t\x12\x34\n\x10typeSubscription\x18\x02 \x01(\x0b\x32\x18.agents.TypeSubscriptionH\x00\x12@\n\x16typePrefixSubscription\x18\x03 \x01(\x0b\x32\x1e.agents.TypePrefixSubscriptionH\x00\x42\x0e\n\x0csubscription\"D\n\x16\x41\x64\x64SubscriptionRequest\x12*\n\x0csubscription\x18\x01 \x01(\x0b\x32\x14.agents.Subscription\"\x19\n\x17\x41\x64\x64SubscriptionResponse\"\'\n\x19RemoveSubscriptionRequest\x12\n\n\x02id\x18\x01 \x01(\t\"\x1c\n\x1aRemoveSubscriptionResponse\"\x19\n\x17GetSubscriptionsRequest\"G\n\x18GetSubscriptionsResponse\x12+\n\rsubscriptions\x18\x01 \x03(\x0b\x32\x14.agents.Subscription\"\x99\x01\n\x07Message\x12%\n\x07request\x18\x01 \x01(\x0b\x32\x12.agents.RpcRequestH\x00\x12\'\n\x08response\x18\x02 \x01(\x0b\x32\x13.agents.RpcResponseH\x00\x12\x33\n\ncloudEvent\x18\x03 \x01(\x0b\x32\x1d.io.cloudevents.v1.CloudEventH\x00\x42\t\n\x07message\"4\n\x10SaveStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\"@\n\x11SaveStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"C\n\x10LoadStateRequest\x12 \n\x07\x61gentId\x18\x01 \x01(\x0b\x32\x0f.agents.AgentId\x12\r\n\x05state\x18\x02 \x01(\t\"1\n\x11LoadStateResponse\x12\x12\n\x05\x65rror\x18\x01 \x01(\tH\x00\x88\x01\x01\x42\x08\n\x06_error\"\x87\x01\n\x0e\x43ontrolMessage\x12\x0e\n\x06rpc_id\x18\x01 \x01(\t\x12\x13\n\x0b\x64\x65stination\x18\x02 \x01(\t\x12\x17\n\nrespond_to\x18\x03 \x01(\tH\x00\x88\x01\x01\x12(\n\nrpcMessage\x18\x04 \x01(\x0b\x32\x14.google.protobuf.AnyB\r\n\x0b_respond_to2\xe7\x03\n\x08\x41gentRpc\x12\x33\n\x0bOpenChannel\x12\x0f.agents.Message\x1a\x0f.agents.Message(\x01\x30\x01\x12H\n\x12OpenControlChannel\x12\x16.agents.ControlMessage\x1a\x16.agents.ControlMessage(\x01\x30\x01\x12T\n\rRegisterAgent\x12 
.agents.RegisterAgentTypeRequest\x1a!.agents.RegisterAgentTypeResponse\x12R\n\x0f\x41\x64\x64Subscription\x12\x1e.agents.AddSubscriptionRequest\x1a\x1f.agents.AddSubscriptionResponse\x12[\n\x12RemoveSubscription\x12!.agents.RemoveSubscriptionRequest\x1a\".agents.RemoveSubscriptionResponse\x12U\n\x10GetSubscriptions\x12\x1f.agents.GetSubscriptionsRequest\x1a .agents.GetSubscriptionsResponseB\x1d\xaa\x02\x1aMicrosoft.AutoGen.Protobufb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -72,14 +72,18 @@ _globals['_GETSUBSCRIPTIONSREQUEST']._serialized_end=1201 _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_start=1203 _globals['_GETSUBSCRIPTIONSRESPONSE']._serialized_end=1274 - _globals['_AGENTSTATE']._serialized_start=1277 - _globals['_AGENTSTATE']._serialized_end=1434 - _globals['_GETSTATERESPONSE']._serialized_start=1436 - _globals['_GETSTATERESPONSE']._serialized_end=1495 - _globals['_SAVESTATERESPONSE']._serialized_start=1497 - _globals['_SAVESTATERESPONSE']._serialized_end=1516 - _globals['_MESSAGE']._serialized_start=1519 - _globals['_MESSAGE']._serialized_end=1672 - _globals['_AGENTRPC']._serialized_start=1675 - _globals['_AGENTRPC']._serialized_end=2203 + _globals['_MESSAGE']._serialized_start=1277 + _globals['_MESSAGE']._serialized_end=1430 + _globals['_SAVESTATEREQUEST']._serialized_start=1432 + _globals['_SAVESTATEREQUEST']._serialized_end=1484 + _globals['_SAVESTATERESPONSE']._serialized_start=1486 + _globals['_SAVESTATERESPONSE']._serialized_end=1550 + _globals['_LOADSTATEREQUEST']._serialized_start=1552 + _globals['_LOADSTATEREQUEST']._serialized_end=1619 + _globals['_LOADSTATERESPONSE']._serialized_start=1621 + _globals['_LOADSTATERESPONSE']._serialized_end=1670 + _globals['_CONTROLMESSAGE']._serialized_start=1673 + _globals['_CONTROLMESSAGE']._serialized_end=1808 + _globals['_AGENTRPC']._serialized_start=1811 + _globals['_AGENTRPC']._serialized_end=2298 # @@protoc_insertion_point(module_scope) diff --git 
a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi index b37fb5ac2979..a12c53e73a7c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2.pyi @@ -313,85 +313,145 @@ class GetSubscriptionsResponse(google.protobuf.message.Message): global___GetSubscriptionsResponse = GetSubscriptionsResponse @typing.final -class AgentState(google.protobuf.message.Message): +class Message(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - AGENT_ID_FIELD_NUMBER: builtins.int - ETAG_FIELD_NUMBER: builtins.int - BINARY_DATA_FIELD_NUMBER: builtins.int - TEXT_DATA_FIELD_NUMBER: builtins.int - PROTO_DATA_FIELD_NUMBER: builtins.int - eTag: builtins.str - binary_data: builtins.bytes - text_data: builtins.str + REQUEST_FIELD_NUMBER: builtins.int + RESPONSE_FIELD_NUMBER: builtins.int + CLOUDEVENT_FIELD_NUMBER: builtins.int @property - def agent_id(self) -> global___AgentId: ... + def request(self) -> global___RpcRequest: ... + @property + def response(self) -> global___RpcResponse: ... @property - def proto_data(self) -> google.protobuf.any_pb2.Any: ... + def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... def __init__( self, *, - agent_id: global___AgentId | None = ..., - eTag: builtins.str = ..., - binary_data: builtins.bytes = ..., - text_data: builtins.str = ..., - proto_data: google.protobuf.any_pb2.Any | None = ..., + request: global___RpcRequest | None = ..., + response: global___RpcResponse | None = ..., + cloudEvent: cloudevent_pb2.CloudEvent | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "proto_data", b"proto_data", "text_data", b"text_data"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing.Literal["agent_id", b"agent_id", "binary_data", b"binary_data", "data", b"data", "eTag", b"eTag", "proto_data", b"proto_data", "text_data", b"text_data"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["data", b"data"]) -> typing.Literal["binary_data", "text_data", "proto_data"] | None: ... + def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ... -global___AgentState = AgentState +global___Message = Message @typing.final -class GetStateResponse(google.protobuf.message.Message): +class SaveStateRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - AGENT_STATE_FIELD_NUMBER: builtins.int + AGENTID_FIELD_NUMBER: builtins.int @property - def agent_state(self) -> global___AgentState: ... + def agentId(self) -> global___AgentId: ... def __init__( self, *, - agent_state: global___AgentState | None = ..., + agentId: global___AgentId | None = ..., ) -> None: ... - def HasField(self, field_name: typing.Literal["agent_state", b"agent_state"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["agent_state", b"agent_state"]) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId"]) -> None: ... 
-global___GetStateResponse = GetStateResponse +global___SaveStateRequest = SaveStateRequest @typing.final class SaveStateResponse(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor + STATE_FIELD_NUMBER: builtins.int + ERROR_FIELD_NUMBER: builtins.int + state: builtins.str + error: builtins.str def __init__( self, + *, + state: builtins.str = ..., + error: builtins.str | None = ..., ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error", "state", b"state"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... global___SaveStateResponse = SaveStateResponse @typing.final -class Message(google.protobuf.message.Message): +class LoadStateRequest(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor - REQUEST_FIELD_NUMBER: builtins.int - RESPONSE_FIELD_NUMBER: builtins.int - CLOUDEVENT_FIELD_NUMBER: builtins.int - @property - def request(self) -> global___RpcRequest: ... + AGENTID_FIELD_NUMBER: builtins.int + STATE_FIELD_NUMBER: builtins.int + state: builtins.str @property - def response(self) -> global___RpcResponse: ... + def agentId(self) -> global___AgentId: ... + def __init__( + self, + *, + agentId: global___AgentId | None = ..., + state: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["agentId", b"agentId"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["agentId", b"agentId", "state", b"state"]) -> None: ... 
+ +global___LoadStateRequest = LoadStateRequest + +@typing.final +class LoadStateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ERROR_FIELD_NUMBER: builtins.int + error: builtins.str + def __init__( + self, + *, + error: builtins.str | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_error", b"_error", "error", b"error"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_error", b"_error"]) -> typing.Literal["error"] | None: ... + +global___LoadStateResponse = LoadStateResponse + +@typing.final +class ControlMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + RPC_ID_FIELD_NUMBER: builtins.int + DESTINATION_FIELD_NUMBER: builtins.int + RESPOND_TO_FIELD_NUMBER: builtins.int + RPCMESSAGE_FIELD_NUMBER: builtins.int + rpc_id: builtins.str + """A response message should have the same id as the request message""" + destination: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + """ + respond_to: builtins.str + """This is either: + agentid=AGENT_ID + clientid=CLIENT_ID + Empty string means the message is a response + """ @property - def cloudEvent(self) -> cloudevent_pb2.CloudEvent: ... + def rpcMessage(self) -> google.protobuf.any_pb2.Any: + """One of: + SaveStateRequest saveStateRequest = 2; + SaveStateResponse saveStateResponse = 3; + LoadStateRequest loadStateRequest = 4; + LoadStateResponse loadStateResponse = 5; + """ + def __init__( self, *, - request: global___RpcRequest | None = ..., - response: global___RpcResponse | None = ..., - cloudEvent: cloudevent_pb2.CloudEvent | None = ..., + rpc_id: builtins.str = ..., + destination: builtins.str = ..., + respond_to: builtins.str | None = ..., + rpcMessage: google.protobuf.any_pb2.Any | None = ..., ) -> None: ... 
- def HasField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> builtins.bool: ... - def ClearField(self, field_name: typing.Literal["cloudEvent", b"cloudEvent", "message", b"message", "request", b"request", "response", b"response"]) -> None: ... - def WhichOneof(self, oneof_group: typing.Literal["message", b"message"]) -> typing.Literal["request", "response", "cloudEvent"] | None: ... + def HasField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["_respond_to", b"_respond_to", "destination", b"destination", "respond_to", b"respond_to", "rpcMessage", b"rpcMessage", "rpc_id", b"rpc_id"]) -> None: ... + def WhichOneof(self, oneof_group: typing.Literal["_respond_to", b"_respond_to"]) -> typing.Literal["respond_to"] | None: ... -global___Message = Message +global___ControlMessage = ControlMessage diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py index 85fd64f42ccb..4a86f17f04ae 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.py @@ -39,15 +39,10 @@ def __init__(self, channel): request_serializer=agent__worker__pb2.Message.SerializeToString, response_deserializer=agent__worker__pb2.Message.FromString, _registered_method=True) - self.GetState = channel.unary_unary( - '/agents.AgentRpc/GetState', - request_serializer=agent__worker__pb2.AgentId.SerializeToString, - response_deserializer=agent__worker__pb2.GetStateResponse.FromString, - _registered_method=True) - self.SaveState = channel.unary_unary( - '/agents.AgentRpc/SaveState', - 
request_serializer=agent__worker__pb2.AgentState.SerializeToString, - response_deserializer=agent__worker__pb2.SaveStateResponse.FromString, + self.OpenControlChannel = channel.stream_stream( + '/agents.AgentRpc/OpenControlChannel', + request_serializer=agent__worker__pb2.ControlMessage.SerializeToString, + response_deserializer=agent__worker__pb2.ControlMessage.FromString, _registered_method=True) self.RegisterAgent = channel.unary_unary( '/agents.AgentRpc/RegisterAgent', @@ -80,13 +75,7 @@ def OpenChannel(self, request_iterator, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') - def GetState(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - - def SaveState(self, request, context): + def OpenControlChannel(self, request_iterator, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details('Method not implemented!') @@ -124,15 +113,10 @@ def add_AgentRpcServicer_to_server(servicer, server): request_deserializer=agent__worker__pb2.Message.FromString, response_serializer=agent__worker__pb2.Message.SerializeToString, ), - 'GetState': grpc.unary_unary_rpc_method_handler( - servicer.GetState, - request_deserializer=agent__worker__pb2.AgentId.FromString, - response_serializer=agent__worker__pb2.GetStateResponse.SerializeToString, - ), - 'SaveState': grpc.unary_unary_rpc_method_handler( - servicer.SaveState, - request_deserializer=agent__worker__pb2.AgentState.FromString, - response_serializer=agent__worker__pb2.SaveStateResponse.SerializeToString, + 'OpenControlChannel': grpc.stream_stream_rpc_method_handler( + servicer.OpenControlChannel, + request_deserializer=agent__worker__pb2.ControlMessage.FromString, + 
response_serializer=agent__worker__pb2.ControlMessage.SerializeToString, ), 'RegisterAgent': grpc.unary_unary_rpc_method_handler( servicer.RegisterAgent, @@ -193,34 +177,7 @@ def OpenChannel(request_iterator, _registered_method=True) @staticmethod - def GetState(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/agents.AgentRpc/GetState', - agent__worker__pb2.AgentId.SerializeToString, - agent__worker__pb2.GetStateResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - - @staticmethod - def SaveState(request, + def OpenControlChannel(request_iterator, target, options=(), channel_credentials=None, @@ -230,12 +187,12 @@ def SaveState(request, wait_for_ready=None, timeout=None, metadata=None): - return grpc.experimental.unary_unary( - request, + return grpc.experimental.stream_stream( + request_iterator, target, - '/agents.AgentRpc/SaveState', - agent__worker__pb2.AgentState.SerializeToString, - agent__worker__pb2.SaveStateResponse.FromString, + '/agents.AgentRpc/OpenControlChannel', + agent__worker__pb2.ControlMessage.SerializeToString, + agent__worker__pb2.ControlMessage.FromString, options, channel_credentials, insecure, diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi index ce8a7c12ec69..cc4311825112 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/protos/agent_worker_pb2_grpc.pyi @@ -24,14 +24,9 @@ class AgentRpcStub: agent_worker_pb2.Message, ] - GetState: grpc.UnaryUnaryMultiCallable[ - 
agent_worker_pb2.AgentId, - agent_worker_pb2.GetStateResponse, - ] - - SaveState: grpc.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentState, - agent_worker_pb2.SaveStateResponse, + OpenControlChannel: grpc.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, ] RegisterAgent: grpc.UnaryUnaryMultiCallable[ @@ -60,14 +55,9 @@ class AgentRpcAsyncStub: agent_worker_pb2.Message, ] - GetState: grpc.aio.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentId, - agent_worker_pb2.GetStateResponse, - ] - - SaveState: grpc.aio.UnaryUnaryMultiCallable[ - agent_worker_pb2.AgentState, - agent_worker_pb2.SaveStateResponse, + OpenControlChannel: grpc.aio.StreamStreamMultiCallable[ + agent_worker_pb2.ControlMessage, + agent_worker_pb2.ControlMessage, ] RegisterAgent: grpc.aio.UnaryUnaryMultiCallable[ @@ -99,18 +89,11 @@ class AgentRpcServicer(metaclass=abc.ABCMeta): ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.Message], collections.abc.AsyncIterator[agent_worker_pb2.Message]]: ... @abc.abstractmethod - def GetState( - self, - request: agent_worker_pb2.AgentId, - context: _ServicerContext, - ) -> typing.Union[agent_worker_pb2.GetStateResponse, collections.abc.Awaitable[agent_worker_pb2.GetStateResponse]]: ... - - @abc.abstractmethod - def SaveState( + def OpenControlChannel( self, - request: agent_worker_pb2.AgentState, + request_iterator: _MaybeAsyncIterator[agent_worker_pb2.ControlMessage], context: _ServicerContext, - ) -> typing.Union[agent_worker_pb2.SaveStateResponse, collections.abc.Awaitable[agent_worker_pb2.SaveStateResponse]]: ... + ) -> typing.Union[collections.abc.Iterator[agent_worker_pb2.ControlMessage], collections.abc.AsyncIterator[agent_worker_pb2.ControlMessage]]: ... 
@abc.abstractmethod def RegisterAgent( From c8e4ad82423ff8f5d8560e16b495c79b0989b94d Mon Sep 17 00:00:00 2001 From: Griffin Bassman Date: Thu, 6 Feb 2025 17:09:26 -0500 Subject: [PATCH 11/15] feat: save/load test for dotnet agents (#5284) --- .../Microsoft.AutoGen/Contracts/AgentProxy.cs | 6 +- .../Contracts/IAgentRuntime.cs | 2 +- .../Microsoft.AutoGen/Contracts/ISaveState.cs | 2 +- .../Core.Grpc/GrpcAgentRuntime.cs | 39 ++--- .../src/Microsoft.AutoGen/Core/BaseAgent.cs | 7 +- .../Core/InProcessRuntime.cs | 26 ++-- .../Core/Properties/AssemblyInfo.cs | 6 + .../AgentRuntimeTests.cs | 83 ----------- .../AgentTests.cs | 23 +-- .../InProcessRuntimeTests.cs | 141 ++++++++++++++++++ .../Microsoft.AutoGen.Core.Tests/TestAgent.cs | 35 ++++- 11 files changed, 230 insertions(+), 140 deletions(-) create mode 100644 dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs delete mode 100644 dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs create mode 100644 dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs b/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs index 44ad9b0e10b2..d37d6284b7d0 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/AgentProxy.cs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentProxy.cs +using System.Text.Json; + namespace Microsoft.AutoGen.Contracts; /// @@ -55,7 +57,7 @@ private T ExecuteAndUnwrap(Func> delegate_) /// /// A dictionary representing the state of the agent. Must be JSON serializable. /// A task representing the asynchronous operation. - public ValueTask LoadStateAsync(IDictionary state) + public ValueTask LoadStateAsync(IDictionary state) { return this.runtime.LoadAgentStateAsync(this.Id, state); } @@ -64,7 +66,7 @@ public ValueTask LoadStateAsync(IDictionary state) /// Saves the state of the agent. The result must be JSON serializable. 
/// /// A task representing the asynchronous operation, returning a dictionary containing the saved state. - public ValueTask> SaveStateAsync() + public ValueTask> SaveStateAsync() { return this.runtime.SaveAgentStateAsync(this.Id); } diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs index 0d84fbe72d37..c4b2e998f1b0 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/IAgentRuntime.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // IAgentRuntime.cs -using StateDict = System.Collections.Generic.IDictionary; +using StateDict = System.Collections.Generic.IDictionary; namespace Microsoft.AutoGen.Contracts; diff --git a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs index ed6d15d1d8d6..4f98f1fc4842 100644 --- a/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs +++ b/dotnet/src/Microsoft.AutoGen/Contracts/ISaveState.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// ISaveState.cs -using StateDict = System.Collections.Generic.IDictionary; +using StateDict = System.Collections.Generic.IDictionary; namespace Microsoft.AutoGen.Contracts; diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs index 1ff1036016d1..46114884326b 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -2,6 +2,7 @@ // GrpcAgentRuntime.cs using System.Collections.Concurrent; +using System.Text.Json; using Grpc.Core; using Microsoft.AutoGen.Contracts; using Microsoft.AutoGen.Protobuf; @@ -319,13 +320,13 @@ public async ValueTask PublishMessageAsync(object message, TopicId topic, Contra public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) => this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy); - public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); return await agent.SaveStateAsync(); } - public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); await agent.LoadStateAsync(state); @@ -375,37 +376,41 @@ public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) return ValueTask.FromResult(new AgentProxy(agentId, this)); } - public async ValueTask> SaveStateAsync() - { - Dictionary state = new(); - foreach (var agent in this._agentsContainer.LiveAgents) - { - state[agent.Id.ToString()] = await agent.SaveStateAsync(); - } - - return state; - } - - public async ValueTask LoadStateAsync(IDictionary state) + public async ValueTask LoadStateAsync(IDictionary state) { HashSet registeredTypes 
= this._agentsContainer.RegisteredAgentTypes; foreach (var agentIdStr in state.Keys) { Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr); - if (state[agentIdStr] is not IDictionary agentStateDict) + + if (state[agentIdStr].ValueKind != JsonValueKind.Object) { - throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + throw new Exception($"Agent state for {agentId} is not a valid JSON object."); } + var agentState = JsonSerializer.Deserialize>(state[agentIdStr].GetRawText()) + ?? throw new Exception($"Failed to deserialize state for {agentId}."); + if (registeredTypes.Contains(agentId.Type)) { IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); - await agent.LoadStateAsync(agentStateDict); + await agent.LoadStateAsync(agentState); } } } + public async ValueTask> SaveStateAsync() + { + Dictionary state = new(); + foreach (var agent in this._agentsContainer.LiveAgents) + { + var agentState = await agent.SaveStateAsync(); + state[agent.Id.ToString()] = JsonSerializer.SerializeToElement(agentState); + } + return state; + } + public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default) { switch (message.MessageCase) diff --git a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs index 99ff001ba98a..a3899280fef4 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/BaseAgent.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -92,11 +93,11 @@ private Dictionary ReflectInvokers() return null; } - public virtual ValueTask> SaveStateAsync() + public virtual ValueTask> SaveStateAsync() { - return ValueTask.FromResult>(new Dictionary()); + return ValueTask.FromResult>(new Dictionary()); } - public virtual ValueTask 
LoadStateAsync(IDictionary state) + public virtual ValueTask LoadStateAsync(IDictionary state) { return ValueTask.CompletedTask; } diff --git a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs index 69b2d314e550..9acf96e648fc 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/InProcessRuntime.cs @@ -3,6 +3,7 @@ using System.Collections.Concurrent; using System.Diagnostics; +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Hosting; @@ -12,7 +13,7 @@ public sealed class InProcessRuntime : IAgentRuntime, IHostedService { public bool DeliverToSelf { get; set; } //= false; - Dictionary agentInstances = new(); + internal Dictionary agentInstances = new(); Dictionary subscriptions = new(); Dictionary>> agentFactories = new(); @@ -152,13 +153,13 @@ public async ValueTask GetAgentMetadataAsync(AgentId agentId) return agent.Metadata; } - public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary state) + public async ValueTask LoadAgentStateAsync(AgentId agentId, IDictionary state) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); await agent.LoadStateAsync(state); } - public async ValueTask> SaveAgentStateAsync(AgentId agentId) + public async ValueTask> SaveAgentStateAsync(AgentId agentId) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); return await agent.SaveStateAsync(); @@ -187,16 +188,21 @@ public ValueTask RemoveSubscriptionAsync(string subscriptionId) return ValueTask.CompletedTask; } - public async ValueTask LoadStateAsync(IDictionary state) + public async ValueTask LoadStateAsync(IDictionary state) { foreach (var agentIdStr in state.Keys) { AgentId agentId = AgentId.FromStr(agentIdStr); - if (state[agentIdStr] is not IDictionary agentState) + + if (state[agentIdStr].ValueKind != JsonValueKind.Object) { - throw new Exception($"Agent state for {agentId} is not a 
{typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + throw new Exception($"Agent state for {agentId} is not a valid JSON object."); } + // Deserialize before using + var agentState = JsonSerializer.Deserialize>(state[agentIdStr].GetRawText()) + ?? throw new Exception($"Failed to deserialize state for {agentId}."); + if (this.agentFactories.ContainsKey(agentId.Type)) { IHostableAgent agent = await this.EnsureAgentAsync(agentId); @@ -205,14 +211,14 @@ public async ValueTask LoadStateAsync(IDictionary state) } } - public async ValueTask> SaveStateAsync() + public async ValueTask> SaveStateAsync() { - Dictionary state = new(); + Dictionary state = new(); foreach (var agentId in this.agentInstances.Keys) { - state[agentId.ToString()] = await this.agentInstances[agentId].SaveStateAsync(); + var agentState = await this.agentInstances[agentId].SaveStateAsync(); + state[agentId.ToString()] = JsonSerializer.SerializeToElement(agentState); } - return state; } diff --git a/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs b/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs new file mode 100644 index 000000000000..8ff44481719e --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// AssemblyInfo.cs + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AutoGen.Core.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f1d038d0b85ae392ad72011df91e9343b0b5df1bb8080aa21b9424362d696919e0e9ac3a8bca24e283e10f7a569c6f443e1d4e3ebc84377c87ca5caa562e80f9932bf5ea91b7862b538e13b8ba91c7565cf0e8dfeccfea9c805ae3bda044170ecc7fc6f147aeeac422dd96aeb9eb1f5a5882aa650efe2958f2f8107d2038f2ab")] diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs deleted file mode 100644 index 812d47c2d207..000000000000 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentRuntimeTests.cs +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// AgentRuntimeTests.cs -using FluentAssertions; -using Microsoft.AutoGen.Contracts; -using Microsoft.Extensions.Logging; -using Xunit; - -namespace Microsoft.AutoGen.Core.Tests; - -[Trait("Category", "UnitV2")] -public class AgentRuntimeTests() -{ - // Agent will not deliver to self will success when runtime.DeliverToSelf is false (default) - [Fact] - public async Task RuntimeAgentPublishToSelfDefaultNoSendTest() - { - var runtime = new InProcessRuntime(); - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - SubscribedSelfPublishAgent agent = null!; - - await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => - { - agent = new SubscribedSelfPublishAgent(id, runtime, logger); - return ValueTask.FromResult(agent); - }); - - // Ensure the agent is actually created - AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); - - // Validate agent ID - agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); - - await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); - - var topicType = "TestTopic"; - - await runtime.PublishMessageAsync("SelfMessage", new 
TopicId(topicType)).ConfigureAwait(true); - - await runtime.RunUntilIdleAsync(); - - // Agent has default messages and could not publish to self - agent.Text.Source.Should().Be("DefaultTopic"); - agent.Text.Content.Should().Be("DefaultContent"); - } - - // Agent delivery to self will success when runtime.DeliverToSelf is true - [Fact] - public async Task RuntimeAgentPublishToSelfDeliverToSelfTrueTest() - { - var runtime = new InProcessRuntime(); - runtime.DeliverToSelf = true; - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - SubscribedSelfPublishAgent agent = null!; - - await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => - { - agent = new SubscribedSelfPublishAgent(id, runtime, logger); - return ValueTask.FromResult(agent); - }); - - // Ensure the agent is actually created - AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); - - // Validate agent ID - agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); - - await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); - - var topicType = "TestTopic"; - - await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); - - await runtime.RunUntilIdleAsync(); - - // Agent sucessfully published to self - agent.Text.Source.Should().Be("TestTopic"); - agent.Text.Content.Should().Be("SelfMessage"); - } -} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs index c091f9eb7478..805fbc87102b 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs @@ -54,7 +54,7 @@ await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => return ValueTask.FromResult(agent); }); - // Ensure the agent is actually created + // Ensure the agent id is registered AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); // Validate agent ID @@ -146,25 +146,4 @@ 
await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => Assert.True(agent.ReceivedItems.Count == 1); } - - [Fact] - public async Task AgentShouldSaveStateCorrectlyTest() - { - var runtime = new InProcessRuntime(); - await runtime.StartAsync(); - - Logger logger = new(new LoggerFactory()); - TestAgent agent = new TestAgent(new AgentId("TestType", "TestKey"), runtime, logger); - - var state = await agent.SaveStateAsync(); - - // Ensure state is a dictionary - state.Should().NotBeNull(); - state.Should().BeOfType>(); - state.Should().BeEmpty("Default SaveStateAsync should return an empty dictionary."); - - // Add a sample value and verify it updates correctly - state["testKey"] = "testValue"; - state.Should().ContainKey("testKey").WhoseValue.Should().Be("testValue"); - } } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs new file mode 100644 index 000000000000..174f8b7817c2 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/InProcessRuntimeTests.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// InProcessRuntimeTests.cs +using System.Text.Json; +using FluentAssertions; +using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AutoGen.Core.Tests; + +[Trait("Category", "UnitV2")] +public class InProcessRuntimeTests() +{ + // Agent will not deliver to self will success when runtime.DeliverToSelf is false (default) + [Fact] + public async Task RuntimeAgentPublishToSelfDefaultNoSendTest() + { + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync(); + + // Agent has default messages and could not publish to self + agent.Text.Source.Should().Be("DefaultTopic"); + agent.Text.Content.Should().Be("DefaultContent"); + } + + // Agent delivery to self will success when runtime.DeliverToSelf is true + [Fact] + public async Task RuntimeAgentPublishToSelfDeliverToSelfTrueTest() + { + var runtime = new InProcessRuntime(); + runtime.DeliverToSelf = true; + await runtime.StartAsync(); + + Logger logger = new(new LoggerFactory()); + SubscribedSelfPublishAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSelfPublishAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Ensure the agent is 
actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync("SelfMessage", new TopicId(topicType)).ConfigureAwait(true); + + await runtime.RunUntilIdleAsync(); + + // Agent sucessfully published to self + agent.Text.Source.Should().Be("TestTopic"); + agent.Text.Content.Should().Be("SelfMessage"); + } + + [Fact] + public async Task RuntimeShouldSaveLoadStateCorrectlyTest() + { + // Create a runtime and register an agent + var runtime = new InProcessRuntime(); + await runtime.StartAsync(); + Logger logger = new(new LoggerFactory()); + SubscribedSaveLoadAgent agent = null!; + await runtime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSaveLoadAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + + // Get agent ID and instantiate agent by publishing + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: true); + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + var topicType = "TestTopic"; + await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true); + await runtime.RunUntilIdleAsync(); + agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed."); + + // Save the state + var savedState = await runtime.SaveStateAsync(); + + // Ensure saved state contains the agent's state + savedState.Should().ContainKey(agentId.ToString()); + + // Ensure the agent's state is stored as a valid JSON object + savedState[agentId.ToString()].ValueKind.Should().Be(JsonValueKind.Object, "Agent state should be stored as a JSON object"); + + // Serialize and Deserialize the state to simulate persistence + string json = 
JsonSerializer.Serialize(savedState); + json.Should().NotBeNullOrEmpty("Serialized state should not be empty"); + var deserializedState = JsonSerializer.Deserialize>(json) + ?? throw new Exception("Deserialized state is unexpectedly null"); + deserializedState.Should().ContainKey(agentId.ToString()); + + // Start new runtime and restore the state + var newRuntime = new InProcessRuntime(); + await newRuntime.StartAsync(); + await newRuntime.RegisterAgentFactoryAsync("MyAgent", (id, runtime) => + { + agent = new SubscribedSaveLoadAgent(id, runtime, logger); + return ValueTask.FromResult(agent); + }); + await newRuntime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); + + // Show that no agent instances exist in the new runtime + newRuntime.agentInstances.Count.Should().Be(0, "Agent should be registered in the new runtime"); + + // Load the state into the new runtime and show that agent is now instantiated + await newRuntime.LoadStateAsync(deserializedState); + newRuntime.agentInstances.Count.Should().Be(1, "Agent should be registered in the new runtime"); + newRuntime.agentInstances.Should().ContainKey(agentId, "Agent should be loaded into the new runtime"); + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs index b6dadc833be2..ed87a71053af 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/TestAgent.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// TestAgent.cs +using System.Text.Json; using Microsoft.AutoGen.Contracts; using Microsoft.Extensions.Logging; @@ -59,7 +60,7 @@ public ValueTask HandleAsync(RpcTextMessage item, MessageContext message /// Key: source /// Value: message /// - private readonly Dictionary _receivedMessages = new(); + protected Dictionary _receivedMessages = new(); public Dictionary ReceivedMessages => _receivedMessages; } @@ -73,6 +74,38 @@ public SubscribedAgent(AgentId id, } } +[TypeSubscription("TestTopic")] +public class SubscribedSaveLoadAgent : TestAgent +{ + public SubscribedSaveLoadAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : base(id, runtime, logger) + { + } + + public override ValueTask> SaveStateAsync() + { + var jsonSafeDictionary = _receivedMessages.ToDictionary( + kvp => kvp.Key, + kvp => JsonSerializer.SerializeToElement(kvp.Value) // Convert each object to JsonElement + ); + + return ValueTask.FromResult>(jsonSafeDictionary); + } + + public override ValueTask LoadStateAsync(IDictionary state) + { + _receivedMessages.Clear(); + + foreach (var kvp in state) + { + _receivedMessages[kvp.Key] = kvp.Value.Deserialize() ?? throw new Exception($"Failed to deserialize key: {kvp.Key}"); + } + + return ValueTask.CompletedTask; + } +} + /// /// The test agent showing an agent that subscribes to itself. /// From 59e392cd0f5e6075d2f6f5b527a0914825049286 Mon Sep 17 00:00:00 2001 From: afourney Date: Thu, 6 Feb 2025 16:03:17 -0800 Subject: [PATCH 12/15] Get SelectorGroupChat working for Llama models. (#5409) Get's SelectorGroupChat working for llama by: 1. Using a UserMessage rather than a SystemMessage 2. Normalizing how roles are presented (one agent per line) 3. 
Normalizing how the transcript is constructed (a blank line between every message) --- .../teams/_group_chat/_selector_group_chat.py | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index bf4bc95946ad..de0ef3247c69 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Dict, List, Mapping, Sequence from autogen_core import Component, ComponentModel -from autogen_core.models import ChatCompletionClient, SystemMessage, UserMessage +from autogen_core.models import ChatCompletionClient, ModelFamily, SystemMessage, UserMessage from pydantic import BaseModel from typing_extensions import Self @@ -110,18 +110,17 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: message += " [Image]" else: raise ValueError(f"Unexpected message type in selector: {type(msg)}") - history_messages.append(message) + history_messages.append( + message.rstrip() + "\n\n" + ) # Create some consistency for how messages are separated in the transcript history = "\n".join(history_messages) # Construct agent roles, we are using the participant topic type as the agent name. - roles = "\n".join( - [ - f"{topic_type}: {description}".strip() - for topic_type, description in zip( - self._participant_topic_types, self._participant_descriptions, strict=True - ) - ] - ) + # Each agent sould appear on a single line. 
+ roles = "" + for topic_type, description in zip(self._participant_topic_types, self._participant_descriptions, strict=True): + roles += re.sub(r"\s+", " ", f"{topic_type}: {description}").strip() + "\n" + roles = roles.strip() # Construct agent list to be selected, skip the previous speaker if not allowed. if self._previous_speaker is not None and not self._allow_repeated_speaker: @@ -136,11 +135,20 @@ async def select_speaker(self, thread: List[AgentEvent | ChatMessage]) -> str: roles=roles, participants=str(participants), history=history ) select_speaker_messages: List[SystemMessage | UserMessage] - if self._model_client.model_info["family"].startswith("gemini"): - select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] - else: + if self._model_client.model_info["family"] in [ + ModelFamily.GPT_4, + ModelFamily.GPT_4O, + ModelFamily.GPT_35, + ModelFamily.O1, + ModelFamily.O3, + ]: select_speaker_messages = [SystemMessage(content=select_speaker_prompt)] + else: + # Many other models need a UserMessage to respond to + select_speaker_messages = [UserMessage(content=select_speaker_prompt, source="selector")] + response = await self._model_client.create(messages=select_speaker_messages) + assert isinstance(response.content, str) mentions = self._mentioned_agents(response.content, self._participant_topic_types) if len(mentions) != 1: From 3b2bf82d155a7272614fcb1c1feeb112dbbb41d7 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Thu, 6 Feb 2025 16:59:31 -0800 Subject: [PATCH 13/15] feat: add integration workflow for testing multiple packages (#5412) --- .github/workflows/integration.yml | 55 +++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 .github/workflows/integration.yml diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml new file mode 100644 index 000000000000..4dfdfa6f1e3a --- /dev/null +++ b/.github/workflows/integration.yml @@ -0,0 +1,55 @@ +name: Integration + +on: + 
workflow_dispatch: + inputs: + branch: + description: 'Branch to run tests' + required: true + type: string + +jobs: + test: + runs-on: ubuntu-latest + environment: integration + strategy: + matrix: + package: + [ + "./packages/autogen-core", + "./packages/autogen-ext", + "./packages/autogen-agentchat", + ] + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ github.event.inputs.branch }} + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + version: "0.5.18" + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run uv sync + run: | + uv sync --locked --all-extras + echo "PKG_NAME=$(basename '${{ matrix.package }}')" >> $GITHUB_ENV + + working-directory: ./python + - name: Run task + run: | + source ${{ github.workspace }}/python/.venv/bin/activate + poe --directory ${{ matrix.package }} test + working-directory: ./python + + - name: Move coverage file + run: | + mv ${{ matrix.package }}/coverage.xml coverage_${{ env.PKG_NAME }}.xml + working-directory: ./python + + - name: Upload coverage artifact + uses: actions/upload-artifact@v4 + with: + name: coverage-${{ env.PKG_NAME }} + path: ./python/coverage_${{ env.PKG_NAME }}.xml From 3c30d8961eee55abe478761e53c8943928caf5d8 Mon Sep 17 00:00:00 2001 From: afourney Date: Thu, 6 Feb 2025 17:47:55 -0800 Subject: [PATCH 14/15] Prompting changes to better support smaller models. (#5386) A series of changes to the `python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py` file have been made to better support smaller models. This includes changes to the prompts, state descriptions, and ordering of messages. Regression tasks with OpenAI models shows no change in GAIA scores, while scores for Llama are significantly improved. 
--- .../web_surfer/_multimodal_web_surfer.py | 152 ++++++++++-------- .../autogen_ext/agents/web_surfer/_prompts.py | 33 ++-- .../autogen-ext/tests/test_websurfer_agent.py | 2 +- 3 files changed, 99 insertions(+), 88 deletions(-) diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py index 73c0799fff38..2c62caf2cad0 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_multimodal_web_surfer.py @@ -10,12 +10,10 @@ from typing import ( Any, AsyncGenerator, - BinaryIO, Dict, List, Optional, Sequence, - cast, ) from urllib.parse import quote_plus @@ -31,6 +29,7 @@ AssistantMessage, ChatCompletionClient, LLMMessage, + ModelFamily, RequestUsage, SystemMessage, UserMessage, @@ -42,7 +41,6 @@ from ._events import WebSurferEvent from ._prompts import ( - WEB_SURFER_OCR_PROMPT, WEB_SURFER_QA_PROMPT, WEB_SURFER_QA_SYSTEM_MESSAGE, WEB_SURFER_TOOL_PROMPT_MM, @@ -444,6 +442,22 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo # Clone the messages, removing old screenshots history: List[LLMMessage] = remove_images(self._chat_history) + # Split the history, removing the last message + if len(history): + user_request = history.pop() + else: + user_request = UserMessage(content="Empty request.", source="user") + + # Truncate the history for smaller models + if self._model_client.model_info["family"] not in [ + ModelFamily.GPT_4O, + ModelFamily.O1, + ModelFamily.O3, + ModelFamily.GPT_4, + ModelFamily.GPT_35, + ]: + history = [] + # Ask the page for interactive elements, then prepare the state-of-mark screenshot rects = await self._playwright_controller.get_interactive_rects(self._page) viewport = await self._playwright_controller.get_visual_viewport(self._page) @@ -499,21 +513,31 @@ async def 
_generate_reply(self, cancellation_token: CancellationToken) -> UserCo other_targets.extend(self._format_target_list(rects_below, rects)) if len(other_targets) > 0: + if len(other_targets) > 30: + other_targets = other_targets[0:30] + other_targets.append("...") other_targets_str = ( - "Additional valid interaction targets (not shown) include:\n" + "\n".join(other_targets) + "\n\n" + "Additional valid interaction targets include (but are not limited to):\n" + + "\n".join(other_targets) + + "\n\n" ) else: other_targets_str = "" + state_description = "Your " + await self._get_state_description() tool_names = "\n".join([t["name"] for t in tools]) + page_title = await self._page.title() + prompt_message = None if self._model_client.model_info["vision"]: text_prompt = WEB_SURFER_TOOL_PROMPT_MM.format( - url=self._page.url, + state_description=state_description, visible_targets=visible_targets, other_targets_str=other_targets_str, focused_hint=focused_hint, tool_names=tool_names, + title=page_title, + url=self._page.url, ).strip() # Scale the screenshot for the MLM, and close the original @@ -522,26 +546,42 @@ async def _generate_reply(self, cancellation_token: CancellationToken) -> UserCo if self.to_save_screenshots: scaled_screenshot.save(os.path.join(self.debug_dir, "screenshot_scaled.png")) # type: ignore - # Add the message - history.append(UserMessage(content=[text_prompt, AGImage.from_pil(scaled_screenshot)], source=self.name)) + # Create the message + prompt_message = UserMessage( + content=[re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), AGImage.from_pil(scaled_screenshot)], + source=self.name, + ) else: - visible_text = await self._playwright_controller.get_visible_text(self._page) - text_prompt = WEB_SURFER_TOOL_PROMPT_TEXT.format( - url=self._page.url, + state_description=state_description, visible_targets=visible_targets, other_targets_str=other_targets_str, focused_hint=focused_hint, tool_names=tool_names, - visible_text=visible_text.strip(), + 
title=page_title, + url=self._page.url, ).strip() - # Add the message - history.append(UserMessage(content=text_prompt, source=self.name)) + # Create the message + prompt_message = UserMessage(content=re.sub(r"(\n\s*){3,}", "\n\n", text_prompt), source=self.name) + + history.append(prompt_message) + history.append(user_request) + + # {history[-2].content if isinstance(history[-2].content, str) else history[-2].content[0]} + # print(f""" + # ================={len(history)}================= + # {history[-2].content} + # ===== + # {history[-1].content} + # =================================================== + # """) + # Make the request response = await self._model_client.create( history, tools=tools, extra_create_args={"tool_choice": "auto"}, cancellation_token=cancellation_token ) # , "parallel_tool_calls": False}) + self.model_usage.append(response.usage) message = response.content self._last_download = None @@ -716,23 +756,12 @@ async def _execute_tool( metadata_hash = hashlib.md5(page_metadata.encode("utf-8")).hexdigest() if metadata_hash != self._prior_metadata_hash: page_metadata = ( - "\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n" + "\n\nThe following metadata was extracted from the webpage:\n\n" + page_metadata.strip() + "\n" ) else: page_metadata = "" self._prior_metadata_hash = metadata_hash - # Describe the viewport of the new page in words - viewport = await self._playwright_controller.get_visual_viewport(self._page) - percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"]) - percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"]) - if percent_scrolled < 1: # Allow some rounding error - position_text = "at the top of the page" - elif percent_scrolled + percent_visible >= 99: # Allow some rounding error - position_text = "at the bottom of the page" - else: - position_text = str(percent_scrolled) + "% down from the top of the page" - new_screenshot = await 
self._page.screenshot() if self.to_save_screenshots: current_timestamp = "_" + int(time.time()).__str__() @@ -748,25 +777,40 @@ async def _execute_tool( ) ) - ocr_text = ( - await self._get_ocr_text(new_screenshot, cancellation_token=cancellation_token) - if self.use_ocr is True - else await self._playwright_controller.get_visible_text(self._page) - ) - # Return the complete observation - page_title = await self._page.title() - message_content = f"{action_description}\n\n Here is a screenshot of the webpage: [{page_title}]({self._page.url}).\n The viewport shows {percent_visible}% of the webpage, and is positioned {position_text} {page_metadata}\n" - if self.use_ocr: - message_content += f"Automatic OCR of the page screenshot has detected the following text:\n\n{ocr_text}" - else: - message_content += f"The following text is visible in the viewport:\n\n{ocr_text}" + state_description = "The " + await self._get_state_description() + message_content = ( + f"{action_description}\n\n" + state_description + page_metadata + "\nHere is a screenshot of the page." 
+ ) return [ - message_content, + re.sub(r"(\n\s*){3,}", "\n\n", message_content), # Removing blank lines AGImage.from_pil(PIL.Image.open(io.BytesIO(new_screenshot))), ] + async def _get_state_description(self) -> str: + assert self._playwright_controller is not None + assert self._page is not None + + # Describe the viewport of the new page in words + viewport = await self._playwright_controller.get_visual_viewport(self._page) + percent_visible = int(viewport["height"] * 100 / viewport["scrollHeight"]) + percent_scrolled = int(viewport["pageTop"] * 100 / viewport["scrollHeight"]) + if percent_scrolled < 1: # Allow some rounding error + position_text = "at the top of the page" + elif percent_scrolled + percent_visible >= 99: # Allow some rounding error + position_text = "at the bottom of the page" + else: + position_text = str(percent_scrolled) + "% down from the top of the page" + + visible_text = await self._playwright_controller.get_visible_text(self._page) + + # Return the complete observation + page_title = await self._page.title() + message_content = f"web browser is open to the page [{page_title}]({self._page.url}).\nThe viewport shows {percent_visible}% of the webpage, and is positioned {position_text}\n" + message_content += f"The following text is visible in the viewport:\n\n{visible_text}" + return message_content + def _target_name(self, target: str, rects: Dict[str, InteractiveRegion]) -> str | None: try: return rects[target]["aria_name"].strip() @@ -798,38 +842,6 @@ def _format_target_list(self, ids: List[str], rects: Dict[str, InteractiveRegion return targets - async def _get_ocr_text( - self, image: bytes | io.BufferedIOBase | PIL.Image.Image, cancellation_token: Optional[CancellationToken] = None - ) -> str: - scaled_screenshot = None - if isinstance(image, PIL.Image.Image): - scaled_screenshot = image.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) - else: - pil_image = None - if not isinstance(image, io.BufferedIOBase): - pil_image = 
PIL.Image.open(io.BytesIO(image)) - else: - pil_image = PIL.Image.open(cast(BinaryIO, image)) - scaled_screenshot = pil_image.resize((self.MLM_WIDTH, self.MLM_HEIGHT)) - pil_image.close() - - # Add the multimodal message and make the request - messages: List[LLMMessage] = [] - messages.append( - UserMessage( - content=[ - WEB_SURFER_OCR_PROMPT, - AGImage.from_pil(scaled_screenshot), - ], - source=self.name, - ) - ) - response = await self._model_client.create(messages, cancellation_token=cancellation_token) - self.model_usage.append(response.usage) - scaled_screenshot.close() - assert isinstance(response.content, str) - return response.content - async def _summarize_page( self, question: str | None = None, diff --git a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py index 59a0a7c95d5e..d1f1885240e2 100644 --- a/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py +++ b/python/packages/autogen-ext/src/autogen_ext/agents/web_surfer/_prompts.py @@ -1,43 +1,42 @@ WEB_SURFER_TOOL_PROMPT_MM = """ -Consider the following screenshot of a web browser, which is open to the page '{url}'. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. Additional information about each visible label is listed below: +{state_description} + +Consider the following screenshot of the page. In this screenshot, interactive elements are outlined in bounding boxes of different colors. Each bounding box has a numeric ID label in the same color. 
Additional information about each visible label is listed below: {visible_targets}{other_targets_str}{focused_hint} -You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools: +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate) - - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - - on some other website entirely (in which case actions like performing a new web search might be the best option) + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) + +My request follows: """ WEB_SURFER_TOOL_PROMPT_TEXT = """ -Your web browser is open to the page '{url}'. 
The following text is visible in the viewport: - -``` -{visible_text} -``` +{state_description} You have also identified the following interactive components: {visible_targets}{other_targets_str}{focused_hint} -You are to respond to the most recent request by selecting an appropriate tool from the following set, or by answering the question directly if possible without tools: +You are to respond to my next request by selecting an appropriate tool from the following set, or by answering the question directly if possible: {tool_names} When deciding between tools, consider if the request can be best addressed by: - - the contents of the current viewport (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element might be most appropriate) - - contents found elsewhere on the full webpage (in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate) - - on some other website entirely (in which case actions like performing a new web search might be the best option) -""" + - the contents of the CURRENT VIEWPORT (in which case actions like clicking links, clicking buttons, inputting text, or hovering over an element, might be more appropriate) + - contents found elsewhere on the CURRENT WEBPAGE [{title}]({url}), in which case actions like scrolling, summarization, or full-page Q&A might be most appropriate + - on ANOTHER WEBSITE entirely (in which case actions like performing a new web search might be the best option) -WEB_SURFER_OCR_PROMPT = """ -Please transcribe all visible text on this page, including both main content and the labels of UI elements. +My request follows: """ + WEB_SURFER_QA_SYSTEM_MESSAGE = """ You are a helpful assistant that can summarize long documents to answer question. 
""" diff --git a/python/packages/autogen-ext/tests/test_websurfer_agent.py b/python/packages/autogen-ext/tests/test_websurfer_agent.py index a2aa33a10931..37423bfe6a50 100644 --- a/python/packages/autogen-ext/tests/test_websurfer_agent.py +++ b/python/packages/autogen-ext/tests/test_websurfer_agent.py @@ -140,7 +140,7 @@ async def test_run_websurfer(monkeypatch: pytest.MonkeyPatch) -> None: result.messages[2] # type: ignore .content[0] # type: ignore .startswith( # type: ignore - "I am waiting a short period of time before taking further action.\n\n Here is a screenshot of the webpage:" + "I am waiting a short period of time before taking further action." ) ) # type: ignore url_after_sleep = agent._page.url # type: ignore From 4c1c12d3506c5e6d63ca6bc773283fec4c2be466 Mon Sep 17 00:00:00 2001 From: afourney Date: Thu, 6 Feb 2025 22:20:06 -0800 Subject: [PATCH 15/15] Flush console output after every message. (#5415) --- .../src/autogen_agentchat/ui/_console.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py index 1f80166d32a1..0a95c842ea08 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/ui/_console.py @@ -75,8 +75,8 @@ def notify_event_received(self, request_id: str) -> None: self.input_events[request_id] = event -def aprint(output: str, end: str = "\n") -> Awaitable[None]: - return asyncio.to_thread(print, output, end=end) +def aprint(output: str, end: str = "\n", flush: bool = False) -> Awaitable[None]: + return asyncio.to_thread(print, output, end=end, flush=flush) async def Console( @@ -126,7 +126,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", 
flush=True) # mypy ignore last_processed = message # type: ignore @@ -141,7 +141,7 @@ async def Console( output += f"[Prompt tokens: {message.chat_message.models_usage.prompt_tokens}, Completion tokens: {message.chat_message.models_usage.completion_tokens}]\n" total_usage.completion_tokens += message.chat_message.models_usage.completion_tokens total_usage.prompt_tokens += message.chat_message.models_usage.prompt_tokens - await aprint(output, end="") + await aprint(output, end="", flush=True) # Print summary. if output_stats: @@ -156,7 +156,7 @@ async def Console( f"Total completion tokens: {total_usage.completion_tokens}\n" f"Duration: {duration:.2f} seconds\n" ) - await aprint(output, end="") + await aprint(output, end="", flush=True) # mypy ignore last_processed = message # type: ignore @@ -169,7 +169,7 @@ async def Console( message = cast(AgentEvent | ChatMessage, message) # type: ignore if not streaming_chunks: # Print message sender. - await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n") + await aprint(f"{'-' * 10} {message.source} {'-' * 10}", end="\n", flush=True) if isinstance(message, ModelClientStreamingChunkEvent): await aprint(message.content, end="") streaming_chunks.append(message.content) @@ -177,15 +177,16 @@ async def Console( if streaming_chunks: streaming_chunks.clear() # Chunked messages are already printed, so we just print a newline. - await aprint("", end="\n") + await aprint("", end="\n", flush=True) else: # Print message content. 
- await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n") + await aprint(_message_to_str(message, render_image_iterm=render_image_iterm), end="\n", flush=True) if message.models_usage: if output_stats: await aprint( f"[Prompt tokens: {message.models_usage.prompt_tokens}, Completion tokens: {message.models_usage.completion_tokens}]", end="\n", + flush=True, ) total_usage.completion_tokens += message.models_usage.completion_tokens total_usage.prompt_tokens += message.models_usage.prompt_tokens