diff --git a/.github/workflows/dotnet-build.yml b/.github/workflows/dotnet-build.yml index 1da80b327d26..bf7570239d61 100644 --- a/.github/workflows/dotnet-build.yml +++ b/.github/workflows/dotnet-build.yml @@ -107,6 +107,38 @@ jobs: - name: Unit Test V2 run: dotnet test --no-build -bl --configuration Release --filter "Category=UnitV2" + grpc-unit-tests: + name: Dotnet Grpc unit tests + needs: paths-filter + if: needs.paths-filter.outputs.hasChanges == 'true' + defaults: + run: + working-directory: dotnet + strategy: + fail-fast: false + matrix: + os: [ ubuntu-latest ] + runs-on: ${{ matrix.os }} + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + with: + lfs: true + - name: Setup .NET 8.0 + uses: actions/setup-dotnet@v4 + with: + dotnet-version: '8.0.x' + - name: Install dev certs + run: dotnet --version && dotnet dev-certs https --trust + - name: Restore dependencies + run: | + # dotnet nuget add source --name dotnet-tool https://pkgs.dev.azure.com/dnceng/public/_packaging/dotnet-tools/nuget/v3/index.json --configfile NuGet.config + dotnet restore -bl + - name: Build + run: dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true + - name: GRPC tests + run: dotnet test --no-build -bl --configuration Release --filter "Category=GRPC" + integration-test: strategy: fail-fast: true @@ -224,6 +256,8 @@ jobs: with: dotnet-version: '8.0.x' global-json-file: dotnet/global.json + - name: Install dev certs + run: dotnet --version && dotnet dev-certs https --trust - name: Restore dependencies run: | dotnet restore -bl diff --git a/.github/workflows/dotnet-release.yml b/.github/workflows/dotnet-release.yml index 23f4258a0e0c..fa114267136c 100644 --- a/.github/workflows/dotnet-release.yml +++ b/.github/workflows/dotnet-release.yml @@ -52,6 +52,7 @@ jobs: run: | echo "Build AutoGen" dotnet build --no-restore --configuration Release -bl /p:SignAssembly=true + - run: sudo dotnet dev-certs https --trust --no-password - name: Unit Test run: dotnet test --no-build -bl --configuration Release env: diff --git a/dotnet/AutoGen.sln b/dotnet/AutoGen.sln index ab7a07464c52..74b7ac965592 100644 --- a/dotnet/AutoGen.sln +++ b/dotnet/AutoGen.sln @@ -118,6 +118,10 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Hello", "Hello", "{F42F9C8E EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc", "src\Microsoft.AutoGen\Core.Grpc\Microsoft.AutoGen.Core.Grpc.csproj", "{3D83C6DB-ACEA-48F3-959F-145CCD2EE135}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "GettingStartedGrpc", "samples\GettingStartedGrpc\GettingStartedGrpc.csproj", "{C3740DF1-18B1-4607-81E4-302F0308C848}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AutoGen.Core.Grpc.Tests", "test\Microsoft.AutoGen.Core.Grpc.Tests\Microsoft.AutoGen.Core.Grpc.Tests.csproj", "{23A028D3-5EB1-4FA0-9CD1-A1340B830579}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -306,6 +310,14 @@ Global {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Debug|Any CPU.Build.0 = Debug|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.ActiveCfg = Release|Any CPU {AAD593FE-A49B-425E-A9FE-A0022CD25E3D}.Release|Any CPU.Build.0 = Release|Any CPU + {C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {C3740DF1-18B1-4607-81E4-302F0308C848}.Debug|Any CPU.Build.0 = Debug|Any CPU + {C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.ActiveCfg = Release|Any CPU + {C3740DF1-18B1-4607-81E4-302F0308C848}.Release|Any CPU.Build.0 = Release|Any CPU + {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Debug|Any CPU.Build.0 = Debug|Any CPU + {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.ActiveCfg = Release|Any CPU + {23A028D3-5EB1-4FA0-9CD1-A1340B830579}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -359,6 +371,8 @@ Global {3D83C6DB-ACEA-48F3-959F-145CCD2EE135} = {18BF8DD7-0585-48BF-8F97-AD333080CE06} {AAD593FE-A49B-425E-A9FE-A0022CD25E3D} = {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} {F42F9C8E-7BD9-4687-9B63-AFFA461AF5C1} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} + {C3740DF1-18B1-4607-81E4-302F0308C848} = {CE0AA8D5-12B8-4628-9589-DAD8CB0DDCF6} + {23A028D3-5EB1-4FA0-9CD1-A1340B830579} = {F823671B-3ECA-4AE6-86DA-25E920D3FE64} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {93384647-528D-46C8-922C-8DB36A382F0B} diff --git a/dotnet/samples/GettingStartedGrpc/Checker.cs b/dotnet/samples/GettingStartedGrpc/Checker.cs new file mode 100644 index 000000000000..7f75acbfafd6 --- /dev/null +++ b/dotnet/samples/GettingStartedGrpc/Checker.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Checker.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.Extensions.Hosting; +using TerminationF = System.Func; + +namespace GettingStartedGrpcSample; + +[TypeSubscription("default")] +public class Checker( + AgentId id, + IAgentRuntime runtime, + IHostApplicationLifetime hostApplicationLifetime, + TerminationF runUntilFunc + ) : + BaseAgent(id, runtime, "Modifier", null), + IHandle +{ + public async ValueTask HandleAsync(Events.CountUpdate item, MessageContext messageContext) + { + if (!runUntilFunc(item.NewCount)) + { + Console.WriteLine($"\nChecker:\n{item.NewCount} passed the check, continue."); + await this.PublishMessageAsync(new Events.CountMessage { Content = item.NewCount }, new TopicId("default")); + } + else + { + Console.WriteLine($"\nChecker:\n{item.NewCount} failed the check, stopping."); + hostApplicationLifetime.StopApplication(); + } + } +} diff --git a/dotnet/samples/GettingStartedGrpc/GettingStartedGrpc.csproj b/dotnet/samples/GettingStartedGrpc/GettingStartedGrpc.csproj new file mode 100644 index 000000000000..a419cd2fe906 --- /dev/null +++ b/dotnet/samples/GettingStartedGrpc/GettingStartedGrpc.csproj @@ -0,0 +1,26 @@ + + + + Exe + net8.0 + getting_started + enable + enable + + + + + + + + + + + + + + + + + + diff --git a/dotnet/samples/GettingStartedGrpc/Modifier.cs b/dotnet/samples/GettingStartedGrpc/Modifier.cs new file mode 100644 index 000000000000..ad3a9d8d97a6 --- /dev/null +++ b/dotnet/samples/GettingStartedGrpc/Modifier.cs @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Modifier.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; + +using ModifyF = System.Func; + +namespace GettingStartedGrpcSample; + +[TypeSubscription("default")] +public class Modifier( + AgentId id, + IAgentRuntime runtime, + ModifyF modifyFunc + ) : + BaseAgent(id, runtime, "Modifier", null), + IHandle +{ + + public async ValueTask HandleAsync(Events.CountMessage item, MessageContext messageContext) + { + int newValue = modifyFunc(item.Content); + Console.WriteLine($"\nModifier:\nModified {item.Content} to {newValue}"); + + var updateMessage = new Events.CountUpdate { NewCount = newValue }; + await this.PublishMessageAsync(updateMessage, topic: new TopicId("default")); + } +} diff --git a/dotnet/samples/GettingStartedGrpc/Program.cs b/dotnet/samples/GettingStartedGrpc/Program.cs new file mode 100644 index 000000000000..aa9cc5417082 --- /dev/null +++ b/dotnet/samples/GettingStartedGrpc/Program.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Program.cs +using GettingStartedGrpcSample; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core; +using Microsoft.AutoGen.Core.Grpc; +using Microsoft.Extensions.DependencyInjection.Extensions; +using ModifyF = System.Func; +using TerminationF = System.Func; + +ModifyF modifyFunc = (int x) => x - 1; +TerminationF runUntilFunc = (int x) => +{ + return x <= 1; +}; + +AgentsAppBuilder appBuilder = new AgentsAppBuilder(); +appBuilder.AddGrpcAgentWorker("http://localhost:50051"); + +appBuilder.Services.TryAddSingleton(modifyFunc); +appBuilder.Services.TryAddSingleton(runUntilFunc); + +appBuilder.AddAgent("Checker"); +appBuilder.AddAgent("Modifier"); + +var app = await appBuilder.BuildAsync(); +await app.StartAsync(); + +// Send the initial count to the agents app, running on the `local` runtime, and pass through the registered services via the application `builder` +await app.PublishMessageAsync(new GettingStartedGrpcSample.Events.CountMessage +{ + Content = 10 +}, new TopicId("default")); + +// Run until application shutdown +await app.WaitForShutdownAsync(); diff --git a/dotnet/samples/GettingStartedGrpc/message.proto b/dotnet/samples/GettingStartedGrpc/message.proto new file mode 100644 index 000000000000..d4acac2e5711 --- /dev/null +++ b/dotnet/samples/GettingStartedGrpc/message.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +option csharp_namespace = "GettingStartedGrpcSample.Events"; + +message CountMessage { + int32 content = 1; +} + +message CountUpdate { + int32 new_count = 1; +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs new file mode 100644 index 000000000000..7cf66d4cee9d --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/AgentsAppBuilderExtensions.cs @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// AgentsAppBuilderExtensions.cs + +using System.Diagnostics; +using Grpc.Core; +using Grpc.Net.Client.Configuration; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; +using Microsoft.Extensions.Logging; +namespace Microsoft.AutoGen.Core.Grpc; + +public static class AgentsAppBuilderExtensions +{ + private const string _defaultAgentServiceAddress = "http://localhost:53071"; + + // TODO: How do we ensure AddGrpcAgentWorker and UseInProcessRuntime are mutually exclusive? + public static AgentsAppBuilder AddGrpcAgentWorker(this AgentsAppBuilder builder, string? agentServiceAddress = null) + { + builder.Services.AddGrpcClient(options => + { + options.Address = new Uri(agentServiceAddress ?? builder.Configuration.GetValue("AGENT_HOST", _defaultAgentServiceAddress)); + options.ChannelOptionsActions.Add(channelOptions => + { + var loggerFactory = new LoggerFactory(); + if (Debugger.IsAttached) + { + channelOptions.HttpHandler = new SocketsHttpHandler + { + EnableMultipleHttp2Connections = false, + KeepAlivePingDelay = TimeSpan.FromSeconds(200), + KeepAlivePingTimeout = TimeSpan.FromSeconds(100), + KeepAlivePingPolicy = HttpKeepAlivePingPolicy.Always + }; + } + else + { + channelOptions.HttpHandler = new SocketsHttpHandler + { + EnableMultipleHttp2Connections = true, + KeepAlivePingDelay = TimeSpan.FromSeconds(20), + KeepAlivePingTimeout = TimeSpan.FromSeconds(10), + KeepAlivePingPolicy = HttpKeepAlivePingPolicy.WithActiveRequests + }; + } + + var methodConfig = new MethodConfig + { + Names = { MethodName.Default }, + RetryPolicy = new RetryPolicy + { + MaxAttempts = 5, + InitialBackoff = TimeSpan.FromSeconds(1), + MaxBackoff = TimeSpan.FromSeconds(5), + BackoffMultiplier = 1.5, + RetryableStatusCodes = { StatusCode.Unavailable } + } + }; + + channelOptions.ServiceConfig = new() { MethodConfigs = { methodConfig } }; + channelOptions.ThrowOperationCanceledOnCancellation = true; + }); + }); + + builder.Services.TryAddSingleton(DistributedContextPropagator.Current); + builder.Services.AddSingleton(); + builder.Services.AddHostedService(services => + { + return (services.GetRequiredService() as GrpcAgentRuntime)!; + }); + + return builder; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/CloudEventExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/CloudEventExtensions.cs new file mode 100644 index 000000000000..1ee46660ce8e --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/CloudEventExtensions.cs @@ -0,0 +1,43 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// CloudEventExtensions.cs + +using Microsoft.AutoGen.Contracts; + +namespace Microsoft.AutoGen.Core.Grpc; + +internal static class CloudEventExtensions +{ + // Convert an ISubscrptionDefinition to a Protobuf Subscription + internal static CloudEvent CreateCloudEvent(Google.Protobuf.WellKnownTypes.Any payload, TopicId topic, string dataType, AgentId? sender, string messageId) + { + var attributes = new Dictionary + { + { + Constants.DATA_CONTENT_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.DATA_CONTENT_TYPE_PROTOBUF_VALUE } + }, + { + Constants.DATA_SCHEMA_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = dataType } + }, + { + Constants.MESSAGE_KIND_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = Constants.MESSAGE_KIND_VALUE_PUBLISH } + } + }; + + if (sender != null) + { + var senderNonNull = (AgentId)sender; + attributes.Add(Constants.AGENT_SENDER_TYPE_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Type }); + attributes.Add(Constants.AGENT_SENDER_KEY_ATTR, new CloudEvent.Types.CloudEventAttributeValue { CeString = senderNonNull.Key }); + } + + return new CloudEvent + { + ProtoData = payload, + Type = topic.Type, + Source = topic.Source, + Id = messageId, + Attributes = { attributes } + }; + + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/Constants.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/Constants.cs new file mode 100644 index 000000000000..c3e9592c1dc2 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/Constants.cs @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Constants.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class Constants +{ + public const string DATA_CONTENT_TYPE_PROTOBUF_VALUE = "application/x-protobuf"; + public const string DATA_CONTENT_TYPE_JSON_VALUE = "application/json"; + public const string DATA_CONTENT_TYPE_TEXT_VALUE = "text/plain"; + + public const string DATA_CONTENT_TYPE_ATTR = "datacontenttype"; + public const string DATA_SCHEMA_ATTR = "dataschema"; + public const string AGENT_SENDER_TYPE_ATTR = "agagentsendertype"; + public const string AGENT_SENDER_KEY_ATTR = "agagentsenderkey"; + + public const string MESSAGE_KIND_ATTR = "agmsgkind"; + public const string MESSAGE_KIND_VALUE_PUBLISH = "publish"; + public const string MESSAGE_KIND_VALUE_RPC_REQUEST = "rpc_request"; + public const string MESSAGE_KIND_VALUE_RPC_RESPONSE = "rpc_response"; +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs new file mode 100644 index 000000000000..1ff1036016d1 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcAgentRuntime.cs @@ -0,0 +1,430 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentRuntime.cs + +using System.Collections.Concurrent; +using Grpc.Core; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.Core.Grpc; + +internal sealed class AgentsContainer(IAgentRuntime hostingRuntime) +{ + private readonly IAgentRuntime hostingRuntime = hostingRuntime; + + private Dictionary agentInstances = new(); + public Dictionary Subscriptions = new(); + private Dictionary>> agentFactories = new(); + + public async ValueTask EnsureAgentAsync(Contracts.AgentId agentId) + { + if (!this.agentInstances.TryGetValue(agentId, out IHostableAgent? agent)) + { + if (!this.agentFactories.TryGetValue(agentId.Type, out Func>? factoryFunc)) + { + throw new Exception($"Agent with name {agentId.Type} not found."); + } + + agent = await factoryFunc(agentId, this.hostingRuntime); + this.agentInstances.Add(agentId, agent); + } + + return this.agentInstances[agentId]; + } + + public async ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) + { + if (!lazy) + { + await this.EnsureAgentAsync(agentId); + } + + return agentId; + } + + public AgentType RegisterAgentFactory(AgentType type, Func> factoryFunc) + { + if (this.agentFactories.ContainsKey(type)) + { + throw new Exception($"Agent factory with type {type} already exists."); + } + + this.agentFactories.Add(type, factoryFunc); + return type; + } + + public void AddSubscription(ISubscriptionDefinition subscription) + { + if (this.Subscriptions.ContainsKey(subscription.Id)) + { + throw new Exception($"Subscription with id {subscription.Id} already exists."); + } + + this.Subscriptions.Add(subscription.Id, subscription); + } + + public bool RemoveSubscriptionAsync(string subscriptionId) + { + if (!this.Subscriptions.ContainsKey(subscriptionId)) + { + throw new Exception($"Subscription with id {subscriptionId} does not exist."); + } + + return this.Subscriptions.Remove(subscriptionId); + } + + public HashSet RegisteredAgentTypes => this.agentFactories.Keys.ToHashSet(); + public IEnumerable LiveAgents => this.agentInstances.Values; +} + +public sealed class GrpcAgentRuntime : IHostedService, IAgentRuntime, IMessageSink, IDisposable +{ + public GrpcAgentRuntime(AgentRpc.AgentRpcClient client, + IHostApplicationLifetime hostApplicationLifetime, + IServiceProvider serviceProvider, + ILogger logger) + { + this._client = client; + this._logger = logger; + this._shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(hostApplicationLifetime.ApplicationStopping); + + this._messageRouter = new GrpcMessageRouter(client, this, _clientId, logger, this._shutdownCts.Token); + this._agentsContainer = new AgentsContainer(this); + + this.ServiceProvider = serviceProvider; + } + + // Request ID -> ResultSink<...> + private readonly ConcurrentDictionary> _pendingRequests = new(); + + private readonly AgentRpc.AgentRpcClient _client; + private readonly GrpcMessageRouter _messageRouter; + + private readonly ILogger _logger; + private readonly CancellationTokenSource _shutdownCts; + + private readonly AgentsContainer _agentsContainer; + + public IServiceProvider ServiceProvider { get; } + + private Guid _clientId = Guid.NewGuid(); + private CallOptions CallOptions + { + get + { + var metadata = new Metadata + { + { "client-id", this._clientId.ToString() } + }; + return new CallOptions(headers: metadata); + } + } + + public IProtoSerializationRegistry SerializationRegistry { get; } = new ProtobufSerializationRegistry(); + + public void Dispose() + { + this._shutdownCts.Cancel(); + this._messageRouter.Dispose(); + } + + private async ValueTask HandleRequest(RpcRequest request, CancellationToken cancellationToken = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.Target is null) + { + throw new InvalidOperationException("Target is null."); + } + + var agentId = request.Target; + var agent = await this._agentsContainer.EnsureAgentAsync(agentId.FromProtobuf()); + + // Convert payload back to object + var payload = request.Payload; + var message = payload.ToObject(SerializationRegistry); + + var messageContext = new MessageContext(request.RequestId, cancellationToken) + { + Sender = request.Source?.FromProtobuf() ?? null, + Topic = null, + IsRpc = true + }; + + var result = await agent.OnMessageAsync(message, messageContext); + + if (result is not null) + { + var response = new RpcResponse + { + RequestId = request.RequestId, + Payload = result.ToPayload(SerializationRegistry) + }; + + var responseMessage = new Message + { + Response = response + }; + + await this._messageRouter.RouteMessageAsync(responseMessage, cancellationToken); + } + } + + private async ValueTask HandleResponse(RpcResponse request, CancellationToken _ = default) + { + if (request is null) + { + throw new InvalidOperationException("Request is null."); + } + if (request.Payload is null) + { + throw new InvalidOperationException("Payload is null."); + } + if (request.RequestId is null) + { + throw new InvalidOperationException("RequestId is null."); + } + + if (_pendingRequests.TryRemove(request.RequestId, out var resultSink)) + { + var payload = request.Payload; + var message = payload.ToObject(SerializationRegistry); + resultSink.SetResult(message); + } + } + + private async ValueTask HandlePublish(CloudEvent evt, CancellationToken cancellationToken = default) + { + if (evt is null) + { + throw new InvalidOperationException("CloudEvent is null."); + } + if (evt.ProtoData is null) + { + throw new InvalidOperationException("ProtoData is null."); + } + if (evt.Attributes is null) + { + throw new InvalidOperationException("Attributes is null."); + } + + var topic = new TopicId(evt.Type, evt.Source); + Contracts.AgentId? sender = null; + if (evt.Attributes.TryGetValue(Constants.AGENT_SENDER_TYPE_ATTR, out var typeValue) && evt.Attributes.TryGetValue(Constants.AGENT_SENDER_KEY_ATTR, out var keyValue)) + { + sender = new Contracts.AgentId + { + Type = typeValue.CeString, + Key = keyValue.CeString + }; + } + + var messageId = evt.Id; + var typeName = evt.Attributes[Constants.DATA_SCHEMA_ATTR].CeString; + var serializer = SerializationRegistry.GetSerializer(typeName) ?? throw new Exception(); + var message = serializer.Deserialize(evt.ProtoData); + + var messageContext = new MessageContext(messageId, cancellationToken) + { + Sender = sender, + Topic = topic, + IsRpc = false + }; + + // Iterate over subscriptions values to find receiving agents + foreach (var subscription in this._agentsContainer.Subscriptions.Values) + { + if (subscription.Matches(topic)) + { + var recipient = subscription.MapToAgent(topic); + var agent = await this._agentsContainer.EnsureAgentAsync(recipient); + await agent.OnMessageAsync(message, messageContext); + } + } + } + + public ValueTask StartAsync(CancellationToken cancellationToken) + { + return this._messageRouter.StartAsync(cancellationToken); + } + + Task IHostedService.StartAsync(CancellationToken cancellationToken) => this._messageRouter.StartAsync(cancellationToken).AsTask(); + + public Task StopAsync(CancellationToken cancellationToken) + { + return this._messageRouter.StopAsync(); + } + + public async ValueTask SendMessageAsync(object message, Contracts.AgentId recepient, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + + var payload = message.ToPayload(SerializationRegistry); + var request = new RpcRequest + { + RequestId = Guid.NewGuid().ToString(), + Source = sender?.ToProtobuf() ?? null, + Target = recepient.ToProtobuf(), + Payload = payload, + }; + + Message msg = new() + { + Request = request + }; + + // Create a future that will be completed when the response is received + var resultSink = new ResultSink(); + this._pendingRequests.TryAdd(request.RequestId, resultSink); + await this._messageRouter.RouteMessageAsync(msg, cancellationToken); + + return await resultSink.Future; + } + + public async ValueTask PublishMessageAsync(object message, TopicId topic, Contracts.AgentId? sender = null, string? messageId = null, CancellationToken cancellationToken = default) + { + if (!SerializationRegistry.Exists(message.GetType())) + { + SerializationRegistry.RegisterSerializer(message.GetType()); + } + var protoAny = (SerializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + var typeName = SerializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType()); + + var cloudEvent = CloudEventExtensions.CreateCloudEvent(protoAny, topic, typeName, sender, messageId ?? Guid.NewGuid().ToString()); + + Message msg = new() + { + CloudEvent = cloudEvent + }; + + await this._messageRouter.RouteMessageAsync(msg, cancellationToken); + } + + public ValueTask GetAgentAsync(Contracts.AgentId agentId, bool lazy = true) => this._agentsContainer.GetAgentAsync(agentId, lazy); + + public ValueTask GetAgentAsync(AgentType agentType, string key = "default", bool lazy = true) + => this.GetAgentAsync(new Contracts.AgentId(agentType, key), lazy); + + public ValueTask GetAgentAsync(string agent, string key = "default", bool lazy = true) + => this.GetAgentAsync(new Contracts.AgentId(agent, key), lazy); + + public async ValueTask> SaveAgentStateAsync(Contracts.AgentId agentId) + { + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + return await agent.SaveStateAsync(); + } + + public async ValueTask LoadAgentStateAsync(Contracts.AgentId agentId, IDictionary state) + { + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + await agent.LoadStateAsync(state); + } + + public async ValueTask GetAgentMetadataAsync(Contracts.AgentId agentId) + { + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + return agent.Metadata; + } + + public async ValueTask AddSubscriptionAsync(ISubscriptionDefinition subscription) + { + this._agentsContainer.AddSubscription(subscription); + + var _ = await this._client.AddSubscriptionAsync(new AddSubscriptionRequest + { + Subscription = subscription.ToProtobuf() + }, this.CallOptions); + } + + public async ValueTask RemoveSubscriptionAsync(string subscriptionId) + { + this._agentsContainer.RemoveSubscriptionAsync(subscriptionId); + + await this._client.RemoveSubscriptionAsync(new RemoveSubscriptionRequest + { + Id = subscriptionId + }, this.CallOptions); + } + + public async ValueTask RegisterAgentFactoryAsync(AgentType type, Func> factoryFunc) + { + this._agentsContainer.RegisterAgentFactory(type, factoryFunc); + + await this._client.RegisterAgentAsync(new RegisterAgentTypeRequest + { + Type = type, + }, this.CallOptions); + + return type; + } + + public ValueTask TryGetAgentProxyAsync(Contracts.AgentId agentId) + { + // TODO: Do we want to support getting remote agent proxies? + return ValueTask.FromResult(new AgentProxy(agentId, this)); + } + + public async ValueTask> SaveStateAsync() + { + Dictionary state = new(); + foreach (var agent in this._agentsContainer.LiveAgents) + { + state[agent.Id.ToString()] = await agent.SaveStateAsync(); + } + + return state; + } + + public async ValueTask LoadStateAsync(IDictionary state) + { + HashSet registeredTypes = this._agentsContainer.RegisteredAgentTypes; + + foreach (var agentIdStr in state.Keys) + { + Contracts.AgentId agentId = Contracts.AgentId.FromStr(agentIdStr); + if (state[agentIdStr] is not IDictionary agentStateDict) + { + throw new Exception($"Agent state for {agentId} is not a {typeof(IDictionary)}: {state[agentIdStr].GetType()}"); + } + + if (registeredTypes.Contains(agentId.Type)) + { + IHostableAgent agent = await this._agentsContainer.EnsureAgentAsync(agentId); + await agent.LoadStateAsync(agentStateDict); + } + } + } + + public async ValueTask OnMessageAsync(Message message, CancellationToken cancellation = default) + { + switch (message.MessageCase) + { + case Message.MessageOneofCase.Request: + var request = message.Request ?? throw new InvalidOperationException("Request is null."); + await HandleRequest(request); + break; + case Message.MessageOneofCase.Response: + var response = message.Response ?? throw new InvalidOperationException("Response is null."); + await HandleResponse(response); + break; + case Message.MessageOneofCase.CloudEvent: + var cloudEvent = message.CloudEvent ?? throw new InvalidOperationException("CloudEvent is null."); + await HandlePublish(cloudEvent); + break; + default: + throw new InvalidOperationException($"Unexpected message '{message}'."); + } + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs new file mode 100644 index 000000000000..e46b392c708f --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/GrpcMessageRouter.cs @@ -0,0 +1,296 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcMessageRouter.cs + +using System.Threading.Channels; +using Grpc.Core; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.Core.Grpc; + +// TODO: Consider whether we want to just reuse IHandle +internal interface IMessageSink +{ + public ValueTask OnMessageAsync(TMessage message, CancellationToken cancellation = default); +} + +internal sealed class AutoRestartChannel : IDisposable +{ + private readonly object _channelLock = new(); + private readonly AgentRpc.AgentRpcClient _client; + private readonly Guid _clientId; + private readonly ILogger _logger; + private readonly CancellationTokenSource _shutdownCts; + private AsyncDuplexStreamingCall? _channel; + + public AutoRestartChannel(AgentRpc.AgentRpcClient client, + Guid clientId, + ILogger logger, + CancellationToken shutdownCancellation = default) + { + _client = client; + _clientId = clientId; + _logger = logger; + _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + } + + public void EnsureConnected() + { + _logger.LogInformation("Connecting to gRPC endpoint " + Environment.GetEnvironmentVariable("AGENT_HOST")); + + if (this.RecreateChannel(null) == null) + { + throw new Exception("Failed to connect to gRPC endpoint."); + }; + } + + public AsyncDuplexStreamingCall StreamingCall + { + get + { + if (_channel is { } channel) + { + return channel; + } + + lock (_channelLock) + { + if (_channel is not null) + { + return _channel; + } + + return RecreateChannel(null); + } + } + } + + public AsyncDuplexStreamingCall RecreateChannel() => RecreateChannel(this._channel); + + private AsyncDuplexStreamingCall RecreateChannel(AsyncDuplexStreamingCall? ownedChannel) + { + // Make sure we are only re-creating the channel if it does not exit or we are the owner. + if (_channel is null || _channel == ownedChannel) + { + lock (_channelLock) + { + if (_channel is null || _channel == ownedChannel) + { + var metadata = new Metadata + { + { "client-id", _clientId.ToString() } + }; + _channel?.Dispose(); + _channel = _client.OpenChannel(cancellationToken: _shutdownCts.Token, headers: metadata); + } + } + } + + return _channel; + } + + public void Dispose() + { + IDisposable? channelDisposable = Interlocked.Exchange(ref this._channel, null); + channelDisposable?.Dispose(); + } +} + +internal sealed class GrpcMessageRouter(AgentRpc.AgentRpcClient client, + IMessageSink incomingMessageSink, + Guid clientId, + ILogger logger, + CancellationToken shutdownCancellation = default) : IDisposable +{ + private static readonly BoundedChannelOptions DefaultChannelOptions = new BoundedChannelOptions(1024) + { + AllowSynchronousContinuations = true, + SingleReader = true, + SingleWriter = false, + FullMode = BoundedChannelFullMode.Wait + }; + + private readonly ILogger _logger = logger; + + private readonly CancellationTokenSource _shutdownCts = CancellationTokenSource.CreateLinkedTokenSource(shutdownCancellation); + + private readonly IMessageSink _incomingMessageSink = incomingMessageSink; + private readonly Channel<(Message Message, TaskCompletionSource WriteCompletionSource)> _outboundMessagesChannel + // TODO: Enable a way to configure the channel options + = Channel.CreateBounded<(Message, TaskCompletionSource)>(DefaultChannelOptions); + + private readonly AutoRestartChannel _incomingMessageChannel = new AutoRestartChannel(client, clientId, logger, shutdownCancellation); + + private Task? _readTask; + private Task? _writeTask; + + private async Task RunReadPump() + { + var cachedChannel = _incomingMessageChannel.StreamingCall; + while (!_shutdownCts.Token.IsCancellationRequested) + { + try + { + await foreach (var message in cachedChannel.ResponseStream.ReadAllAsync(_shutdownCts.Token)) + { + // next if message is null + if (message == null) + { + continue; + } + + await _incomingMessageSink.OnMessageAsync(message, _shutdownCts.Token); + } + } + catch (OperationCanceledException) + { + // Time to shut down. + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + _logger.LogError(ex, "Error reading from channel."); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + } + catch + { + // Shutdown requested. + break; + } + } + } + + private async Task RunWritePump() + { + var cachedChannel = this._incomingMessageChannel.StreamingCall; + var outboundMessages = _outboundMessagesChannel.Reader; + while (!_shutdownCts.IsCancellationRequested) + { + (Message Message, TaskCompletionSource WriteCompletionSource) item = default; + try + { + await outboundMessages.WaitToReadAsync().ConfigureAwait(false); + + // Read the next message if we don't already have an unsent message + // waiting to be sent. + if (!outboundMessages.TryRead(out item)) + { + break; + } + + while (!_shutdownCts.IsCancellationRequested) + { + await cachedChannel.RequestStream.WriteAsync(item.Message, _shutdownCts.Token).ConfigureAwait(false); + item.WriteCompletionSource.TrySetResult(); + break; + } + } + catch (OperationCanceledException) + { + // Time to shut down. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.Unavailable) + { + // we could not connect to the endpoint - most likely we have the wrong port or failed ssl + // we need to let the user know what port we tried to connect to and then do backoff and retry + _logger.LogError(ex, "Error connecting to GRPC endpoint {Endpoint}.", Environment.GetEnvironmentVariable("AGENT_HOST")); + break; + } + catch (RpcException ex) when (ex.StatusCode == StatusCode.OK) + { + _logger.LogError(ex, "Error writing to channel, continuing (Status OK). {ex}", cachedChannel.ToString()); + break; + } + catch (Exception ex) when (!_shutdownCts.IsCancellationRequested) + { + item.WriteCompletionSource?.TrySetException(ex); + _logger.LogError(ex, $"Error writing to channel.{ex}"); + cachedChannel = this._incomingMessageChannel.RecreateChannel(); + continue; + } + catch + { + // Shutdown requested. + item.WriteCompletionSource?.TrySetCanceled(); + break; + } + } + + while (outboundMessages.TryRead(out var item)) + { + item.WriteCompletionSource.TrySetCanceled(); + } + } + + public ValueTask RouteMessageAsync(Message message, CancellationToken cancellation = default) + { + var tcs = new TaskCompletionSource(); + return _outboundMessagesChannel.Writer.WriteAsync((message, tcs), cancellation); + } + + public ValueTask StartAsync(CancellationToken cancellation) + { + // TODO: Should we error out on a noncancellable token? + + this._incomingMessageChannel.EnsureConnected(); + var didSuppress = false; + + // Make sure we do not mistakenly flow the ExecutionContext into the background pumping tasks. + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump, cancellation); + _writeTask = Task.Run(RunWritePump, cancellation); + + return ValueTask.CompletedTask; + } + catch (Exception ex) + { + return ValueTask.FromException(ex); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + } + + // No point in returning a ValueTask here, since we are awaiting the two tasks + public async Task StopAsync() + { + _shutdownCts.Cancel(); + + _outboundMessagesChannel.Writer.TryComplete(); + + List pendingTasks = new(); + if (_readTask is { } readTask) + { + pendingTasks.Add(readTask); + } + + if (_writeTask is { } writeTask) + { + pendingTasks.Add(writeTask); + } + + await Task.WhenAll(pendingTasks).ConfigureAwait(false); + + this._incomingMessageChannel.Dispose(); + } + + public void Dispose() + { + _outboundMessagesChannel.Writer.TryComplete(); + this._incomingMessageChannel.Dispose(); + } +} + diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs new file mode 100644 index 000000000000..c2ca53e33710 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentMessageSerializer.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; +/// +/// Interface for serializing and deserializing agent messages. +/// +public interface IAgentMessageSerializer +{ + /// + /// Serialize an agent message. + /// + /// The message to serialize. + /// The serialized message. + Google.Protobuf.WellKnownTypes.Any Serialize(object message); + + /// + /// Deserialize an agent message. + /// + /// The message to deserialize. + /// The deserialized message. + object Deserialize(Google.Protobuf.WellKnownTypes.Any message); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs new file mode 100644 index 000000000000..8179ff4b494b --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IAgentRuntimeExtensions.cs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IAgentRuntimeExtensions.cs + +using System.Diagnostics; +using Google.Protobuf.Collections; +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; +using Microsoft.Extensions.DependencyInjection; +using static Microsoft.AutoGen.Contracts.CloudEvent.Types; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class GrpcAgentRuntimeExtensions +{ + public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime runtime, IDictionary metadata) + { + var dcp = runtime.ServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static (string?, string?) GetTraceIdAndState(GrpcAgentRuntime worker, MapField metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.ExtractTraceIdAndState(metadata, + static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }, + out var traceParent, + out var traceState); + return (traceParent, traceState); + } + public static void Update(GrpcAgentRuntime worker, RpcRequest request, Activity? activity = null) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.Inject(activity, request.Metadata, static (carrier, key, value) => + { + var metadata = (IDictionary)carrier!; + if (metadata.TryGetValue(key, out _)) + { + metadata[key] = value; + } + else + { + metadata.Add(key, value); + } + }); + } + public static void Update(GrpcAgentRuntime worker, CloudEvent cloudEvent, Activity? activity = null) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + dcp.Inject(activity, cloudEvent.Attributes, static (carrier, key, value) => + { + var mapField = (MapField)carrier!; + if (mapField.TryGetValue(key, out var ceValue)) + { + mapField[key] = new CloudEventAttributeValue { CeString = value }; + } + else + { + mapField.Add(key, new CloudEventAttributeValue { CeString = value }); + } + }); + } + + public static IDictionary ExtractMetadata(GrpcAgentRuntime worker, IDictionary metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (IDictionary)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out fieldValue); + }); + + return baggage as IDictionary ?? new Dictionary(); + } + public static IDictionary ExtractMetadata(GrpcAgentRuntime worker, MapField metadata) + { + var dcp = worker.ServiceProvider.GetRequiredService(); + var baggage = dcp.ExtractBaggage(metadata, static (object? carrier, string fieldName, out string? fieldValue, out IEnumerable? fieldValues) => + { + var metadata = (MapField)carrier!; + fieldValues = null; + metadata.TryGetValue(fieldName, out var ceValue); + fieldValue = ceValue?.CeString; + }); + + return baggage as IDictionary ?? new Dictionary(); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtobufMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtobufMessageSerializer.cs new file mode 100644 index 000000000000..7d92614b7c3f --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/IProtobufMessageSerializer.cs @@ -0,0 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// IProtobufMessageSerializer.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtobufMessageSerializer +{ + Google.Protobuf.WellKnownTypes.Any Serialize(object input); + object Deserialize(Google.Protobuf.WellKnownTypes.Any input); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs new file mode 100644 index 000000000000..c736a1c38cde --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ISerializationRegistry.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ISerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface IProtoSerializationRegistry +{ + /// + /// Registers a serializer for the specified type. + /// + /// The type to register. + void RegisterSerializer(System.Type type) => RegisterSerializer(type, new ProtobufMessageSerializer(type)); + + void RegisterSerializer(System.Type type, IProtobufMessageSerializer serializer); + + /// + /// Gets the serializer for the specified type. + /// + /// The type to get the serializer for. + /// The serializer for the specified type. + IProtobufMessageSerializer? GetSerializer(System.Type type) => GetSerializer(TypeNameResolver.ResolveTypeName(type)); + IProtobufMessageSerializer? GetSerializer(string typeName); + + ITypeNameResolver TypeNameResolver { get; } + + bool Exists(System.Type type); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs new file mode 100644 index 000000000000..67ba1c577f4a --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ITypeNameResolver.cs @@ -0,0 +1,9 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ITypeNameResolver.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public interface ITypeNameResolver +{ + string ResolveTypeName(Type input); +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/Microsoft.AutoGen.Core.Grpc.csproj b/dotnet/src/Microsoft.AutoGen/Core.Grpc/Microsoft.AutoGen.Core.Grpc.csproj index c28a9b1c9087..6a68de1d8903 100644 --- a/dotnet/src/Microsoft.AutoGen/Core.Grpc/Microsoft.AutoGen.Core.Grpc.csproj +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/Microsoft.AutoGen.Core.Grpc.csproj @@ -14,7 +14,6 @@ - diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs new file mode 100644 index 000000000000..3175817c0eee --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufConversionExtensions.cs @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufConversionExtensions.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public static class ProtobufConversionExtensions +{ + // Convert an ISubscrptionDefinition to a Protobuf Subscription + public static Subscription? ToProtobuf(this ISubscriptionDefinition subscriptionDefinition) + { + // Check if is a TypeSubscription + if (subscriptionDefinition is Contracts.TypeSubscription typeSubscription) + { + return new Subscription + { + Id = typeSubscription.Id, + TypeSubscription = new Protobuf.TypeSubscription + { + TopicType = typeSubscription.TopicType, + AgentType = typeSubscription.AgentType + } + }; + } + + // Check if is a TypePrefixSubscription + if (subscriptionDefinition is Contracts.TypePrefixSubscription typePrefixSubscription) + { + return new Subscription + { + Id = typePrefixSubscription.Id, + TypePrefixSubscription = new Protobuf.TypePrefixSubscription + { + TopicTypePrefix = typePrefixSubscription.TopicTypePrefix, + AgentType = typePrefixSubscription.AgentType + } + }; + } + + return null; + } + + // Convert AgentId from Protobuf to AgentId + public static Contracts.AgentId FromProtobuf(this Protobuf.AgentId agentId) + { + return new Contracts.AgentId(agentId.Type, agentId.Key); + } + + // Convert AgentId from AgentId to Protobuf + public static Protobuf.AgentId ToProtobuf(this Contracts.AgentId agentId) + { + return new Protobuf.AgentId + { + Type = agentId.Type, + Key = agentId.Key + }; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs new file mode 100644 index 000000000000..09da49640ad0 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufMessageSerializer.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufMessageSerializer.cs + +using Google.Protobuf; +using Google.Protobuf.WellKnownTypes; + +namespace Microsoft.AutoGen.Core.Grpc; + +/// +/// Interface for serializing and deserializing agent messages. +/// +public class ProtobufMessageSerializer : IProtobufMessageSerializer +{ + private System.Type _concreteType; + + public ProtobufMessageSerializer(System.Type concreteType) + { + _concreteType = concreteType; + } + + public object Deserialize(Any message) + { + // Check if the concrete type is a proto IMessage + if (typeof(IMessage).IsAssignableFrom(_concreteType)) + { + var nameOfMethod = nameof(Any.Unpack); + var result = message.GetType().GetMethods().Where(m => m.Name == nameOfMethod && m.IsGenericMethod).First().MakeGenericMethod(_concreteType).Invoke(message, null); + return result as IMessage ?? throw new ArgumentException("Failed to deserialize", nameof(message)); + } + + // Raise an exception if the concrete type is not a proto IMessage + throw new ArgumentException("Concrete type must be a proto IMessage", nameof(_concreteType)); + } + + public Any Serialize(object message) + { + // Check if message is a proto IMessage + if (message is IMessage protoMessage) + { + return Any.Pack(protoMessage); + } + + // Raise an exception if the message is not a proto IMessage + throw new ArgumentException("Message must be a proto IMessage", nameof(message)); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufSerializationRegistry.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufSerializationRegistry.cs new file mode 100644 index 000000000000..1bc0449d5688 --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufSerializationRegistry.cs @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufSerializationRegistry.cs + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtobufSerializationRegistry : IProtoSerializationRegistry +{ + private readonly Dictionary _serializers + = new Dictionary(); + + public ITypeNameResolver TypeNameResolver => new ProtobufTypeNameResolver(); + + public bool Exists(Type type) + { + return _serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type)); + } + + public IProtobufMessageSerializer? GetSerializer(Type type) + { + return GetSerializer(TypeNameResolver.ResolveTypeName(type)); + } + + public IProtobufMessageSerializer? GetSerializer(string typeName) + { + _serializers.TryGetValue(typeName, out var serializer); + return serializer; + } + + public void RegisterSerializer(Type type, IProtobufMessageSerializer serializer) + { + if (_serializers.ContainsKey(TypeNameResolver.ResolveTypeName(type))) + { + throw new InvalidOperationException($"Serializer already registered for {type.FullName}"); + } + _serializers[TypeNameResolver.ResolveTypeName(type)] = serializer; + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufTypeNameResolver.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufTypeNameResolver.cs new file mode 100644 index 000000000000..e376f9a13daa --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/ProtobufTypeNameResolver.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ProtobufTypeNameResolver.cs + +using Google.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +public class ProtobufTypeNameResolver : ITypeNameResolver +{ + public string ResolveTypeName(Type input) + { + if (typeof(IMessage).IsAssignableFrom(input)) + { + // TODO: Consider changing this to avoid instantiation... + var protoMessage = (IMessage?)Activator.CreateInstance(input) ?? throw new InvalidOperationException($"Failed to create instance of {input.FullName}"); + return protoMessage.Descriptor.FullName; + } + else + { + throw new ArgumentException("Input must be a protobuf message."); + } + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core.Grpc/RpcExtensions.cs b/dotnet/src/Microsoft.AutoGen/Core.Grpc/RpcExtensions.cs new file mode 100644 index 000000000000..5c264887856c --- /dev/null +++ b/dotnet/src/Microsoft.AutoGen/Core.Grpc/RpcExtensions.cs @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// RpcExtensions.cs + +using Google.Protobuf; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc; + +internal static class RpcExtensions +{ + + public static Payload ToPayload(this object message, IProtoSerializationRegistry serializationRegistry) + { + if (!serializationRegistry.Exists(message.GetType())) + { + serializationRegistry.RegisterSerializer(message.GetType()); + } + var rpcMessage = (serializationRegistry.GetSerializer(message.GetType()) ?? throw new Exception()).Serialize(message); + + var typeName = serializationRegistry.TypeNameResolver.ResolveTypeName(message.GetType()); + const string PAYLOAD_DATA_CONTENT_TYPE = "application/x-protobuf"; + + // Protobuf any to byte array + Payload payload = new() + { + DataType = typeName, + DataContentType = PAYLOAD_DATA_CONTENT_TYPE, + Data = rpcMessage.ToByteString() + }; + + return payload; + } + + public static object ToObject(this Payload payload, IProtoSerializationRegistry serializationRegistry) + { + var typeName = payload.DataType; + var data = payload.Data; + var serializer = serializationRegistry.GetSerializer(typeName) ?? throw new Exception(); + var any = Google.Protobuf.WellKnownTypes.Any.Parser.ParseFrom(data); + return serializer.Deserialize(any); + } +} diff --git a/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs b/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs index bae09a9f1917..cecd8d9ec48d 100644 --- a/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs +++ b/dotnet/src/Microsoft.AutoGen/Core/AgentsApp.cs @@ -4,6 +4,7 @@ using System.Diagnostics; using System.Reflection; using Microsoft.AutoGen.Contracts; +using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; @@ -21,6 +22,7 @@ public AgentsAppBuilder(HostApplicationBuilder? baseBuilder = null) } public IServiceCollection Services => this.builder.Services; + public IConfiguration Configuration => this.builder.Configuration; public void AddAgentsFromAssemblies() { diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs index 7514609e145b..a2a970bebb41 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/AgentGrpcTests.cs @@ -1,263 +1,152 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // AgentGrpcTests.cs - -using System.Collections.Concurrent; -using System.Text.Json; using FluentAssertions; -using Google.Protobuf.Reflection; using Microsoft.AutoGen.Contracts; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; +// using Microsoft.AutoGen.Core.Tests; +using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf; using Microsoft.Extensions.Logging; using Xunit; -using static Microsoft.AutoGen.Core.Grpc.Tests.AgentGrpcTests; namespace Microsoft.AutoGen.Core.Grpc.Tests; -[Trait("Category", "UnitV2")] +[Trait("Category", "GRPC")] public class AgentGrpcTests { - /// - /// Verify that if the agent is not initialized via AgentWorker, it should throw the correct exception. - /// - /// void [Fact] - public async Task Agent_ShouldThrowException_WhenNotInitialized() + public async Task AgentShouldNotReceiveMessagesWhenNotSubscribedTest() { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(false); // Do not initialize + var fixture = new GrpcAgentRuntimeFixture(); + var runtime = (GrpcAgentRuntime)await fixture.Start(); - // Expect an exception when calling AddSubscriptionAsync because the agent is uninitialized - await Assert.ThrowsAsync( - async () => await agent.AddSubscriptionAsync("TestEvent") - ); - } + Logger logger = new(new LoggerFactory()); + TestProtobufAgent agent = null!; - /// - /// validate that the agent is initialized correctly with implicit subs - /// - /// void - [Fact] - public async Task Agent_ShouldInitializeCorrectly() - { - using var runtime = new GrpcRuntime(); - var (worker, agent) = runtime.Start(); - Assert.Equal(nameof(GrpcAgentRuntime), worker.GetType().Name); - await Task.Delay(5000); - var subscriptions = await agent.GetSubscriptionsAsync(); - Assert.Equal(2, subscriptions.Count); - } - /// - /// Test AddSubscriptionAsync method - /// - /// void - [Fact] - public async Task SubscribeAsync_UnsubscribeAsync_and_GetSubscriptionsTest() - { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); - await agent.AddSubscriptionAsync("TestEvent"); - await Task.Delay(100); - var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - var found = false; - foreach (var subscription in subscriptions) + await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => { - if (subscription.TypeSubscription.TopicType == "TestEvent") - { - found = true; - } - } - Assert.True(found); - await agent.RemoveSubscriptionAsync("TestEvent").ConfigureAwait(true); - await Task.Delay(1000); - subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - found = false; - foreach (var subscription in subscriptions) - { - if (subscription.TypeSubscription.TopicType == "TestEvent") - { - found = true; - } - } - Assert.False(found); + agent = new TestProtobufAgent(id, runtime, logger); + return await ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + var topicType = "TestTopic"; + + await runtime.PublishMessageAsync(new Protobuf.TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true); + + agent.ReceivedMessages.Any().Should().BeFalse("Agent should not receive messages when not subscribed."); + fixture.Dispose(); } - /// - /// Test StoreAsync and ReadAsync methods - /// - /// void [Fact] - public async Task StoreAsync_and_ReadAsyncTest() + public async Task AgentShouldReceiveMessagesWhenSubscribedTest() { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); - Dictionary state = new() - { - { "testdata", "Active" } - }; - await agent.StoreAsync(new AgentState + var fixture = new GrpcAgentRuntimeFixture(); + var runtime = (GrpcAgentRuntime)await fixture.Start(); + + Logger logger = new(new LoggerFactory()); + SubscribedProtobufAgent agent = null!; + + await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => { - AgentId = agent.AgentId, - TextData = JsonSerializer.Serialize(state) - }).ConfigureAwait(true); - var readState = await agent.ReadAsync(agent.AgentId).ConfigureAwait(true); - var read = JsonSerializer.Deserialize>(readState.TextData) ?? new Dictionary { { "data", "No state data found" } }; - read.TryGetValue("testdata", out var value); - Assert.Equal("Active", value); - } + agent = new SubscribedProtobufAgent(id, runtime, logger); + return await ValueTask.FromResult(agent); + }); + + // Ensure the agent is actually created + AgentId agentId = await runtime.GetAgentAsync("MyAgent", lazy: false); + + // Validate agent ID + agentId.Should().Be(agent.Id, "Agent ID should match the registered agent"); + + await runtime.RegisterImplicitAgentSubscriptionsAsync("MyAgent"); - /// - /// Test PublishMessageAsync method and ReceiveMessage method - /// - /// void - [Fact] - public async Task PublishMessageAsync_and_ReceiveMessageTest() - { - using var runtime = new GrpcRuntime(); - var (_, agent) = runtime.Start(); var topicType = "TestTopic"; - await agent.AddSubscriptionAsync(topicType).ConfigureAwait(true); - var subscriptions = await agent.GetSubscriptionsAsync().ConfigureAwait(true); - var found = false; - foreach (var subscription in subscriptions) - { - if (subscription.TypeSubscription.TopicType == topicType) - { - found = true; - } - } - Assert.True(found); - await agent.PublishMessageAsync(new TextMessage() - { - Source = topicType, - TextMessage_ = "buffer" - }, topicType).ConfigureAwait(true); - await Task.Delay(100); - Assert.True(TestAgent.ReceivedMessages.ContainsKey(topicType)); - runtime.Stop(); - } - [Fact] - public async Task InvokeCorrectHandler() - { - var agent = new TestAgent(new AgentsMetadata(TypeRegistry.Empty, new Dictionary(), new Dictionary>(), new Dictionary>()), new Logger(new LoggerFactory())); + await runtime.PublishMessageAsync(new TextMessage { Source = topicType, Content = "test" }, new TopicId(topicType)).ConfigureAwait(true); - await agent.HandleObjectAsync("hello world"); - await agent.HandleObjectAsync(42); + // Wait for the message to be processed + await Task.Delay(100); - agent.ReceivedItems.Should().HaveCount(2); - agent.ReceivedItems[0].Should().Be("hello world"); - agent.ReceivedItems[1].Should().Be(42); + agent.ReceivedMessages.Any().Should().BeTrue("Agent should receive messages when subscribed."); + fixture.Dispose(); } - /// - /// The test agent is a simple agent that is used for testing purposes. - /// - public class TestAgent( - [FromKeyedServices("AgentsMetadata")] AgentsMetadata eventTypes, - Logger? logger = null) : Agent(eventTypes, logger), IHandle + [Fact] + public async Task SendMessageAsyncShouldReturnResponseTest() { - public Task Handle(TextMessage item, CancellationToken cancellationToken = default) - { - ReceivedMessages[item.Source] = item.TextMessage_; - return Task.CompletedTask; - } - public Task Handle(string item) - { - ReceivedItems.Add(item); - return Task.CompletedTask; - } - public Task Handle(int item) + // Arrange + var fixture = new GrpcAgentRuntimeFixture(); + var runtime = (GrpcAgentRuntime)await fixture.Start(); + + Logger logger = new(new LoggerFactory()); + await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => await ValueTask.FromResult(new TestProtobufAgent(id, runtime, logger))); + var agentId = new AgentId("MyAgent", "default"); + var response = await runtime.SendMessageAsync(new RpcTextMessage { Source = "TestTopic", Content = "Request" }, agentId); + + // Assert + Assert.NotNull(response); + Assert.IsType(response); + if (response is RpcTextMessage responseString) { - ReceivedItems.Add(item); - return Task.CompletedTask; + Assert.Equal("Request", responseString.Content); } - public List ReceivedItems { get; private set; } = []; - - /// - /// Key: source - /// Value: message - /// - public static ConcurrentDictionary ReceivedMessages { get; private set; } = new(); + fixture.Dispose(); } -} - -/// -/// GrpcRuntimeFixture - provides a fixture for the agent runtime. -/// -/// -/// This fixture is used to provide a runtime for the agent tests. -/// However, it is shared between tests. So operations from one test can affect another. -/// -public sealed class GrpcRuntime : IDisposable -{ - public IHost Client { get; private set; } - public IHost? AppHost { get; private set; } - public GrpcRuntime() + public class ReceiverAgent(AgentId id, + IAgentRuntime runtime) : BaseAgent(id, runtime, "Receiver Agent", null), + IHandle { - Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development"); - AppHost = Host.CreateDefaultBuilder().Build(); - Client = Host.CreateDefaultBuilder().Build(); - } + public ValueTask HandleAsync(TextMessage item, MessageContext messageContext) + { + ReceivedItems.Add(item.Content); + return ValueTask.CompletedTask; + } - private static int GetAvailablePort() - { - using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0); - listener.Start(); - int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port; - listener.Stop(); - return port; + public List ReceivedItems { get; private set; } = []; } - private static async Task StartClientAsync() - { - return await AgentsApp.StartAsync().ConfigureAwait(false); - } - private static async Task StartAppHostAsync() + [Fact] + public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest() { - return await Microsoft.AutoGen.Runtime.Grpc.Host.StartAsync(local: false, useGrpc: true).ConfigureAwait(false); + var fixture = new GrpcAgentRuntimeFixture(); + var runtime = (GrpcAgentRuntime)await fixture.Start(); + ReceiverAgent? agent = null; + await runtime.RegisterAgentFactoryAsync("MyAgent", async (id, runtime) => + { + agent = new ReceiverAgent(id, runtime); + return await ValueTask.FromResult(agent); + }); - } + Assert.Null(agent); + await runtime.GetAgentAsync("MyAgent", lazy: false); + Assert.NotNull(agent); + Assert.True(agent.ReceivedItems.Count == 0); - /// - /// Start - gets a new port and starts fresh instances - /// - public (IAgentRuntime, TestAgent) Start(bool initialize = true) - { - int port = GetAvailablePort(); // Get a new port per test run + var topicTypeName = "TestTopic"; + await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName)); + await Task.Delay(100); - // Update environment variables so each test runs independently - Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString()); - Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}"); + Assert.True(agent.ReceivedItems.Count == 0); - AppHost = StartAppHostAsync().GetAwaiter().GetResult(); - Client = StartClientAsync().GetAwaiter().GetResult(); + var subscription = new TypeSubscription(topicTypeName, "MyAgent"); + await runtime.AddSubscriptionAsync(subscription); - var agent = ActivatorUtilities.CreateInstance(Client.Services); - var worker = Client.Services.GetRequiredService(); - if (initialize) - { - Agent.Initialize(worker, agent); - } + await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName)); + await Task.Delay(100); - return (worker, agent); - } + Assert.True(agent.ReceivedItems.Count == 1); + Assert.Equal("test", agent.ReceivedItems[0]); - /// - /// Stop - stops the agent and ensures cleanup - /// - public void Stop() - { - Client?.StopAsync().GetAwaiter().GetResult(); - AppHost?.StopAsync().GetAwaiter().GetResult(); - } + await runtime.RemoveSubscriptionAsync(subscription.Id); + await runtime.PublishMessageAsync(new TextMessage { Source = "topic", Content = "test" }, new TopicId(topicTypeName)); + await Task.Delay(100); - /// - /// Dispose - Ensures cleanup after each test - /// - public void Dispose() - { - Stop(); + Assert.True(agent.ReceivedItems.Count == 1); + fixture.Dispose(); } } diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentRuntimeFixture.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentRuntimeFixture.cs new file mode 100644 index 000000000000..bade7f785757 --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentRuntimeFixture.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentRuntimeFixture.cs +using Microsoft.AspNetCore.Builder; +using Microsoft.AutoGen.Contracts; +// using Microsoft.AutoGen.Core.Tests; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Microsoft.AutoGen.Core.Grpc.Tests; +/// +/// Fixture for setting up the gRPC agent runtime for testing. +/// +public sealed class GrpcAgentRuntimeFixture : IDisposable +{ + /// the gRPC agent runtime. + public AgentsApp? Client { get; private set; } + /// mock server for testing. + public WebApplication? Server { get; private set; } + + public GrpcAgentRuntimeFixture() + { + } + /// + /// Start - gets a new port and starts fresh instances + /// + public async Task Start(bool initialize = true) + { + int port = GetAvailablePort(); // Get a new port per test run + + // Update environment variables so each test runs independently + Environment.SetEnvironmentVariable("ASPNETCORE_HTTPS_PORTS", port.ToString()); + Environment.SetEnvironmentVariable("AGENT_HOST", $"https://localhost:{port}"); + Environment.SetEnvironmentVariable("ASPNETCORE_ENVIRONMENT", "Development"); + Server = ServerBuilder().Result; + await Server.StartAsync().ConfigureAwait(true); + Client = ClientBuilder().Result; + await Client.StartAsync().ConfigureAwait(true); + + var worker = Client.Services.GetRequiredService(); + + return (worker); + } + private static async Task ClientBuilder() + { + var appBuilder = new AgentsAppBuilder(); + appBuilder.AddGrpcAgentWorker(); + appBuilder.AddAgent("TestAgent"); + return await appBuilder.BuildAsync(); + } + private static async Task ServerBuilder() + { + var builder = WebApplication.CreateBuilder(); + builder.Services.AddGrpc(); + var app = builder.Build(); + app.MapGrpcService(); + return app; + } + private static int GetAvailablePort() + { + using var listener = new System.Net.Sockets.TcpListener(System.Net.IPAddress.Loopback, 0); + listener.Start(); + int port = ((System.Net.IPEndPoint)listener.LocalEndpoint).Port; + listener.Stop(); + return port; + } + /// + /// Stop - stops the agent and ensures cleanup + /// + public void Stop() + { + (Client as IHost)?.StopAsync(TimeSpan.FromSeconds(30)).GetAwaiter().GetResult(); + Server?.StopAsync().GetAwaiter().GetResult(); + } + + /// + /// Dispose - Ensures cleanup after each test + /// + public void Dispose() + { + Stop(); + } + +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs new file mode 100644 index 000000000000..98c47764269d --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/GrpcAgentServiceFixture.cs @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// GrpcAgentServiceFixture.cs +using Grpc.Core; +using Microsoft.AutoGen.Protobuf; +namespace Microsoft.AutoGen.Core.Grpc.Tests; + +/// +/// This fixture is largely just a loopback as we are testing the client side logic of the GrpcAgentRuntime in isolation from the rest of the system. +/// +public sealed class GrpcAgentServiceFixture() : AgentRpc.AgentRpcBase +{ + public override async Task OpenChannel(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + try + { + var workerProcess = new TestGrpcWorkerConnection(requestStream, responseStream, context); + await workerProcess.Connect().ConfigureAwait(true); + } + catch + { + if (context.CancellationToken.IsCancellationRequested) + { + return; + } + throw; + } + } + public override async Task GetState(AgentId request, ServerCallContext context) => new GetStateResponse { AgentState = new AgentState { AgentId = request } }; + public override async Task SaveState(AgentState request, ServerCallContext context) => new SaveStateResponse { }; + public override async Task AddSubscription(AddSubscriptionRequest request, ServerCallContext context) => new AddSubscriptionResponse { }; + public override async Task RemoveSubscription(RemoveSubscriptionRequest request, ServerCallContext context) => new RemoveSubscriptionResponse { }; + public override async Task GetSubscriptions(GetSubscriptionsRequest request, ServerCallContext context) => new GetSubscriptionsResponse { }; + public override async Task RegisterAgent(RegisterAgentTypeRequest request, ServerCallContext context) => new RegisterAgentTypeResponse { }; +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj index f14497e75fbc..e3573c93451a 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/Microsoft.AutoGen.Core.Grpc.Tests.csproj @@ -10,8 +10,17 @@ - + + + + + + + + + + diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestGrpcWorkerConnection.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestGrpcWorkerConnection.cs new file mode 100644 index 000000000000..20b8169db11f --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestGrpcWorkerConnection.cs @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TestGrpcWorkerConnection.cs + +using System.Threading.Channels; +using Grpc.Core; +using Microsoft.AutoGen.Protobuf; + +namespace Microsoft.AutoGen.Core.Grpc.Tests; + +internal sealed class TestGrpcWorkerConnection : IAsyncDisposable +{ + private static long s_nextConnectionId; + private Task _readTask = Task.CompletedTask; + private Task _writeTask = Task.CompletedTask; + private readonly string _connectionId = Interlocked.Increment(ref s_nextConnectionId).ToString(); + private readonly object _lock = new(); + private readonly HashSet _supportedTypes = []; + private readonly CancellationTokenSource _shutdownCancellationToken = new(); + public Task Completion { get; private set; } = Task.CompletedTask; + public IAsyncStreamReader RequestStream { get; } + public IServerStreamWriter ResponseStream { get; } + public ServerCallContext ServerCallContext { get; } + private readonly Channel _outboundMessages; + public TestGrpcWorkerConnection(IAsyncStreamReader requestStream, IServerStreamWriter responseStream, ServerCallContext context) + { + RequestStream = requestStream; + ResponseStream = responseStream; + ServerCallContext = context; + _outboundMessages = Channel.CreateUnbounded(new UnboundedChannelOptions { AllowSynchronousContinuations = true, SingleReader = true, SingleWriter = false }); + } + public Task Connect() + { + var didSuppress = false; + if (!ExecutionContext.IsFlowSuppressed()) + { + didSuppress = true; + ExecutionContext.SuppressFlow(); + } + + try + { + _readTask = Task.Run(RunReadPump); + _writeTask = Task.Run(RunWritePump); + } + finally + { + if (didSuppress) + { + ExecutionContext.RestoreFlow(); + } + } + + return Completion = Task.WhenAll(_readTask, _writeTask); + } + public void AddSupportedType(string type) + { + lock (_lock) + { + _supportedTypes.Add(type); + } + } + public HashSet GetSupportedTypes() + { + lock (_lock) + { + return new HashSet(_supportedTypes); + } + } + public async Task SendMessage(Message message) + { + await _outboundMessages.Writer.WriteAsync(message).ConfigureAwait(false); + } + public async Task RunReadPump() + { + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + try + { + await foreach (var message in RequestStream.ReadAllAsync(_shutdownCancellationToken.Token)) + { + //_gateway.OnReceivedMessageAsync(this, message, _shutdownCancellationToken.Token).Ignore(); + switch (message.MessageCase) + { + case Message.MessageOneofCase.Request: + await SendMessage(new Message { Request = message.Request }).ConfigureAwait(false); + break; + case Message.MessageOneofCase.Response: + await SendMessage(new Message { Response = message.Response }).ConfigureAwait(false); + break; + case Message.MessageOneofCase.CloudEvent: + await SendMessage(new Message { CloudEvent = message.CloudEvent }).ConfigureAwait(false); + break; + default: + // if it wasn't recognized return bad request + throw new RpcException(new Status(StatusCode.InvalidArgument, $"Unknown message type for message '{message}'")); + }; + } + } + catch (OperationCanceledException) + { + } + finally + { + _shutdownCancellationToken.Cancel(); + //_gateway.OnRemoveWorkerProcess(this); + } + } + + public async Task RunWritePump() + { + await Task.CompletedTask.ConfigureAwait(ConfigureAwaitOptions.ForceYielding); + try + { + await foreach (var message in _outboundMessages.Reader.ReadAllAsync(_shutdownCancellationToken.Token)) + { + await ResponseStream.WriteAsync(message); + } + } + catch (OperationCanceledException) + { + } + finally + { + _shutdownCancellationToken.Cancel(); + } + } + + public async ValueTask DisposeAsync() + { + _shutdownCancellationToken.Cancel(); + await Completion.ConfigureAwait(ConfigureAwaitOptions.SuppressThrowing); + } + + public override string ToString() => $"Connection-{_connectionId}"; +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestProtobufAgent.cs b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestProtobufAgent.cs new file mode 100644 index 000000000000..6f5ad4aa9e5b --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/TestProtobufAgent.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// TestProtobufAgent.cs + +using Microsoft.AutoGen.Contracts; +using Microsoft.AutoGen.Core.Grpc.Tests.Protobuf; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AutoGen.Core.Grpc.Tests; + +/// +/// The test agent is a simple agent that is used for testing purposes. +/// +public class TestProtobufAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : BaseAgent(id, runtime, "Test Agent", logger), + IHandle, + IHandle + +{ + public ValueTask HandleAsync(TextMessage item, MessageContext messageContext) + { + ReceivedMessages[item.Source] = item.Content; + return ValueTask.CompletedTask; + } + + public ValueTask HandleAsync(RpcTextMessage item, MessageContext messageContext) + { + ReceivedMessages[item.Source] = item.Content; + return ValueTask.FromResult(new RpcTextMessage { Source = item.Source, Content = item.Content }); + } + + public List ReceivedItems { get; private set; } = []; + + /// + /// Key: source + /// Value: message + /// + private readonly Dictionary _receivedMessages = new(); + public Dictionary ReceivedMessages => _receivedMessages; +} + +[TypeSubscription("TestTopic")] +public class SubscribedProtobufAgent : TestProtobufAgent +{ + public SubscribedProtobufAgent(AgentId id, + IAgentRuntime runtime, + Logger? logger = null) : base(id, runtime, logger) + { + } +} diff --git a/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/messages.proto b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/messages.proto new file mode 100644 index 000000000000..7f2c275e691f --- /dev/null +++ b/dotnet/test/Microsoft.AutoGen.Core.Grpc.Tests/messages.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +option csharp_namespace = "Microsoft.AutoGen.Core.Grpc.Tests.Protobuf"; + +message TextMessage { + string content = 1; + string source = 2; +} + +message RpcTextMessage { + string content = 1; + string source = 2; +} \ No newline at end of file diff --git a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs index cc39b3564c66..c091f9eb7478 100644 --- a/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs +++ b/dotnet/test/Microsoft.AutoGen.Core.Tests/AgentTests.cs @@ -109,7 +109,7 @@ public ValueTask HandleAsync(string item, MessageContext messageContext) } [Fact] - public async Task SubscribeAsyncRemoveSubscriptionAsyncAndGetSubscriptionsTest() + public async Task SubscribeAsyncRemoveSubscriptionAsyncTest() { var runtime = new InProcessRuntime(); await runtime.StartAsync(); diff --git a/protos/agent_events.proto b/protos/agent_events.proto deleted file mode 100644 index a97df6e5855f..000000000000 --- a/protos/agent_events.proto +++ /dev/null @@ -1,43 +0,0 @@ -syntax = "proto3"; - -package agents; - -option csharp_namespace = "Microsoft.AutoGen.Contracts"; -message TextMessage { - string textMessage = 1; - string source = 2; -} -message Input { - string message = 1; -} -message InputProcessed { - string route = 1; -} -message Output { - string message = 1; -} -message OutputWritten { - string route = 1; -} -message IOError { - string message = 1; -} -message NewMessageReceived { - string message = 1; -} -message ResponseGenerated { - string response = 1; -} -message GoodBye { - string message = 1; -} -message MessageStored { - string message = 1; -} -message ConversationClosed { - string user_id = 1; - string user_message = 2; -} -message Shutdown { - string message = 1; -} diff --git a/protos/agent_states.proto b/protos/agent_states.proto deleted file mode 100644 index 945772861cc8..000000000000 --- a/protos/agent_states.proto +++ /dev/null @@ -1,8 +0,0 @@ -syntax = "proto3"; -package agents; - -option csharp_namespace = "Microsoft.AutoGen.Contracts"; - -message AgentState { - string message = 1; -}