Skip to content

Commit

Permalink
Remove channel based control plane APIs, cleanup proto (microsoft#5236)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits authored Jan 28, 2025
1 parent 91249c4 commit 7445e4b
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 375 deletions.
55 changes: 12 additions & 43 deletions protos/agent_worker.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@ syntax = "proto3";

package agents;

option csharp_namespace = "Microsoft.AutoGen.Contracts";
option csharp_namespace = "Microsoft.AutoGen.Protobuf";

import "cloudevent.proto";
import "google/protobuf/any.proto";

message TopicId {
string type = 1;
string source = 2;
}

message AgentId {
string type = 1;
Expand Down Expand Up @@ -39,23 +35,11 @@ message RpcResponse {
map<string, string> metadata = 4;
}

message Event {
string topic_type = 1;
string topic_source = 2;
optional AgentId source = 3;
Payload payload = 4;
map<string, string> metadata = 5;
}

message RegisterAgentTypeRequest {
string request_id = 1; // TODO: remove once message based requests are removed
string type = 2;
string type = 1;
}

message RegisterAgentTypeResponse {
string request_id = 1; // TODO: remove once message based requests are removed
bool success = 2;
optional string error = 3;
}

message TypeSubscription {
Expand All @@ -77,40 +61,24 @@ message Subscription {
}

message AddSubscriptionRequest {
string request_id = 1; // TODO: remove once message based requests are removed
Subscription subscription = 2;
Subscription subscription = 1;
}

message AddSubscriptionResponse {
string request_id = 1; // TODO: remove once message based requests are removed
bool success = 2;
optional string error = 3;
}

message RemoveSubscriptionRequest {
string id = 1;
}

message RemoveSubscriptionResponse {
bool success = 1;
optional string error = 2;
}

message GetSubscriptionsRequest {}
message GetSubscriptionsResponse {
repeated Subscription subscriptions = 1;
}

service AgentRpc {
rpc OpenChannel (stream Message) returns (stream Message);
rpc GetState(AgentId) returns (GetStateResponse);
rpc SaveState(AgentState) returns (SaveStateResponse);
rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse);
rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse);
rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse);
rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse);
}

message AgentState {
AgentId agent_id = 1;
string eTag = 2;
Expand All @@ -123,24 +91,25 @@ message AgentState {

message GetStateResponse {
AgentState agent_state = 1;
bool success = 2;
optional string error = 3;
}

message SaveStateResponse {
bool success = 1;
optional string error = 2;
}

message Message {
oneof message {
RpcRequest request = 1;
RpcResponse response = 2;
io.cloudevents.v1.CloudEvent cloudEvent = 3;
RegisterAgentTypeRequest registerAgentTypeRequest = 4;
RegisterAgentTypeResponse registerAgentTypeResponse = 5;
AddSubscriptionRequest addSubscriptionRequest = 6;
AddSubscriptionResponse addSubscriptionResponse = 7;
}
}

service AgentRpc {
rpc OpenChannel (stream Message) returns (stream Message);
rpc GetState(AgentId) returns (GetStateResponse);
rpc SaveState(AgentState) returns (SaveStateResponse);
rpc RegisterAgent(RegisterAgentTypeRequest) returns (RegisterAgentTypeResponse);
rpc AddSubscription(AddSubscriptionRequest) returns (AddSubscriptionResponse);
rpc RemoveSubscription(RemoveSubscriptionRequest) returns (RemoveSubscriptionResponse);
rpc GetSubscriptions(GetSubscriptionsRequest) returns (GetSubscriptionsResponse);
}
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,6 @@ async def _run_read_loop(self) -> None:
message = await self._host_connection.recv()
oneofcase = agent_worker_pb2.Message.WhichOneof(message, "message")
match oneofcase:
case "registerAgentTypeRequest" | "addSubscriptionRequest":
logger.warning(f"Cant handle {oneofcase}, skipping.")
case "request":
task = asyncio.create_task(self._process_request(message.request))
self._background_tasks.add(task)
Expand All @@ -292,20 +290,6 @@ async def _run_read_loop(self) -> None:
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeResponse":
task = asyncio.create_task(
self._process_register_agent_type_response(message.registerAgentTypeResponse)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscriptionResponse":
task = asyncio.create_task(
self._process_add_subscription_response(message.addSubscriptionResponse)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case None:
logger.warning("No message")
except Exception as e:
Expand Down Expand Up @@ -737,28 +721,13 @@ async def factory_wrapper() -> T:

self._agent_factories[type.type] = factory_wrapper

# Create a future for the registration response.
future = asyncio.get_event_loop().create_future()
request_id = await self._get_new_request_id()
self._pending_requests[request_id] = future

# Send the registration request message to the host.
message = agent_worker_pb2.RegisterAgentTypeRequest(request_id=request_id, type=type.type)
response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type)
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
# TODO: just use grpc error handling
if not response.success:
raise RuntimeError(response.error)
return type

async def _process_register_agent_type_response(self, response: agent_worker_pb2.RegisterAgentTypeResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error") and response.error != "":
future.set_exception(RuntimeError(response.error))
else:
future.set_result(None)

async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
Expand Down Expand Up @@ -812,28 +781,14 @@ async def add_subscription(self, subscription: Subscription) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")

# Create a future for the subscription response.
request_id = await self._get_new_request_id()

message = agent_worker_pb2.AddSubscriptionRequest(
request_id=request_id, subscription=subscription_to_proto(subscription)
)
response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription(
message = agent_worker_pb2.AddSubscriptionRequest(subscription=subscription_to_proto(subscription))
_response: agent_worker_pb2.AddSubscriptionResponse = await self._host_connection.stub.AddSubscription(
message, metadata=self._host_connection.metadata
)
if not response.success:
raise RuntimeError(response.error)

# Add to local subscription manager.
await self._subscription_manager.add_subscription(subscription)

async def _process_add_subscription_response(self, response: agent_worker_pb2.AddSubscriptionResponse) -> None:
future = self._pending_requests.pop(response.request_id)
if response.HasField("error") and response.error != "":
future.set_exception(RuntimeError(response.error))
else:
future.set_result(None)

async def remove_subscription(self, id: str) -> None:
raise NotImplementedError("Subscriptions cannot be removed while using distributed runtime currently.")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,6 @@ async def _receive_messages(
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeRequest":
register_agent_type: agent_worker_pb2.RegisterAgentTypeRequest = message.registerAgentTypeRequest
task = asyncio.create_task(
self._process_register_agent_type_request(register_agent_type, client_id)
)
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "addSubscriptionRequest":
add_subscription: agent_worker_pb2.AddSubscriptionRequest = message.addSubscriptionRequest
task = asyncio.create_task(self._process_add_subscription_request(add_subscription, client_id))
self._background_tasks.add(task)
task.add_done_callback(self._raise_on_exception)
task.add_done_callback(self._background_tasks.discard)
case "registerAgentTypeResponse" | "addSubscriptionResponse":
logger.warning(f"Received unexpected message type: {oneofcase}")
case None:
logger.warning("Received empty message")

Expand Down Expand Up @@ -204,53 +188,6 @@ async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
for client_id in client_ids:
await self._send_queues[client_id].put(agent_worker_pb2.Message(cloudEvent=event))

async def _process_register_agent_type_request(
self, register_agent_type_req: agent_worker_pb2.RegisterAgentTypeRequest, client_id: ClientConnectionId
) -> None:
# Register the agent type with the host runtime.
async with self._agent_type_to_client_id_lock:
if register_agent_type_req.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[register_agent_type_req.type]
logger.error(
f"Agent type {register_agent_type_req.type} already registered with client {existing_client_id}."
)
success = False
error = f"Agent type {register_agent_type_req.type} already registered."
else:
self._agent_type_to_client_id[register_agent_type_req.type] = client_id
success = True
error = None
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
registerAgentTypeResponse=agent_worker_pb2.RegisterAgentTypeResponse(
request_id=register_agent_type_req.request_id, success=success, error=error
)
)
)

async def _process_add_subscription_request(
self, add_subscription_req: agent_worker_pb2.AddSubscriptionRequest, client_id: ClientConnectionId
) -> None:
subscription = subscription_from_proto(add_subscription_req.subscription)
try:
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
await self._send_queues[client_id].put(
agent_worker_pb2.Message(
addSubscriptionResponse=agent_worker_pb2.AddSubscriptionResponse(
request_id=add_subscription_req.request_id, success=success, error=error
)
)
)

async def RegisterAgent( # type: ignore
self,
request: agent_worker_pb2.RegisterAgentTypeRequest,
Expand All @@ -263,14 +200,14 @@ async def RegisterAgent( # type: ignore
async with self._agent_type_to_client_id_lock:
if request.type in self._agent_type_to_client_id:
existing_client_id = self._agent_type_to_client_id[request.type]
logger.error(f"Agent type {request.type} already registered with client {existing_client_id}.")
success = False
error = f"Agent type {request.type} already registered."
await context.abort(
grpc.StatusCode.INVALID_ARGUMENT,
f"Agent type {request.type} already registered with client {existing_client_id}.",
)
else:
self._agent_type_to_client_id[request.type] = client_id
success = True
error = None
return agent_worker_pb2.RegisterAgentTypeResponse(request_id=request.request_id, success=success, error=error)

return agent_worker_pb2.RegisterAgentTypeResponse()

async def AddSubscription( # type: ignore
self,
Expand All @@ -286,13 +223,9 @@ async def AddSubscription( # type: ignore
await self._subscription_manager.add_subscription(subscription)
subscription_ids = self._client_id_to_subscription_id_mapping.setdefault(client_id, set())
subscription_ids.add(subscription.id)
success = True
error = None
except ValueError as e:
success = False
error = str(e)
# Send a response back to the client.
return agent_worker_pb2.AddSubscriptionResponse(request_id=request.request_id, success=success, error=error)
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
return agent_worker_pb2.AddSubscriptionResponse()

async def RemoveSubscription( # type: ignore
self,
Expand Down
Loading

0 comments on commit 7445e4b

Please sign in to comment.