|
24 | 24 | update_agent_message, |
25 | 25 | PROTOCOL_VERSION, |
26 | 26 | ) |
27 | | -from acp.schema import AgentCapabilities, McpCapabilities, PromptCapabilities |
| 27 | +from acp.schema import AgentCapabilities, AgentMessageChunk, Implementation |
28 | 28 |
|
29 | 29 |
|
30 | 30 | class ExampleAgent(Agent): |
31 | 31 | def __init__(self, conn: AgentSideConnection) -> None: |
32 | 32 | self._conn = conn |
33 | 33 | self._next_session_id = 0 |
| 34 | + self._sessions: set[str] = set() |
34 | 35 |
|
35 | | - async def _send_chunk(self, session_id: str, content: Any) -> None: |
36 | | - await self._conn.sessionUpdate( |
37 | | - session_notification( |
38 | | - session_id, |
39 | | - update_agent_message(content), |
40 | | - ) |
41 | | - ) |
| 36 | + async def _send_agent_message(self, session_id: str, content: Any) -> None: |
| 37 | + update = content if isinstance(content, AgentMessageChunk) else update_agent_message(content) |
| 38 | + await self._conn.sessionUpdate(session_notification(session_id, update)) |
42 | 39 |
|
43 | 40 | async def initialize(self, params: InitializeRequest) -> InitializeResponse: # noqa: ARG002 |
44 | 41 | logging.info("Received initialize request") |
45 | | - mcp_caps: McpCapabilities = McpCapabilities(http=False, sse=False) |
46 | | - prompt_caps: PromptCapabilities = PromptCapabilities(audio=False, embeddedContext=False, image=False) |
47 | | - agent_caps: AgentCapabilities = AgentCapabilities( |
48 | | - loadSession=False, |
49 | | - mcpCapabilities=mcp_caps, |
50 | | - promptCapabilities=prompt_caps, |
51 | | - ) |
52 | 42 | return InitializeResponse( |
53 | 43 | protocolVersion=PROTOCOL_VERSION, |
54 | | - agentCapabilities=agent_caps, |
| 44 | + agentCapabilities=AgentCapabilities(), |
| 45 | + agentInfo=Implementation(name="example-agent", title="Example Agent", version="0.1.0"), |
55 | 46 | ) |
56 | 47 |
|
57 | 48 | async def authenticate(self, params: AuthenticateRequest) -> AuthenticateResponse | None: # noqa: ARG002 |
58 | | - logging.info("Received authenticate request") |
| 49 | + logging.info("Received authenticate request %s", params.methodId) |
59 | 50 | return AuthenticateResponse() |
60 | 51 |
|
61 | 52 | async def newSession(self, params: NewSessionRequest) -> NewSessionResponse: # noqa: ARG002 |
62 | 53 | logging.info("Received new session request") |
63 | 54 | session_id = str(self._next_session_id) |
64 | 55 | self._next_session_id += 1 |
65 | | - return NewSessionResponse(sessionId=session_id) |
| 56 | + self._sessions.add(session_id) |
| 57 | + return NewSessionResponse(sessionId=session_id, modes=None) |
66 | 58 |
|
67 | 59 | async def loadSession(self, params: LoadSessionRequest) -> LoadSessionResponse | None: # noqa: ARG002 |
68 | | - logging.info("Received load session request") |
| 60 | + logging.info("Received load session request %s", params.sessionId) |
| 61 | + self._sessions.add(params.sessionId) |
69 | 62 | return LoadSessionResponse() |
70 | 63 |
|
71 | 64 | async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeResponse | None: # noqa: ARG002 |
72 | | - logging.info("Received set session mode request") |
| 65 | + logging.info("Received set session mode request %s -> %s", params.sessionId, params.modeId) |
73 | 66 | return SetSessionModeResponse() |
74 | 67 |
|
75 | 68 | async def prompt(self, params: PromptRequest) -> PromptResponse: |
76 | | - logging.info("Received prompt request") |
| 69 | + logging.info("Received prompt request for session %s", params.sessionId) |
| 70 | + if params.sessionId not in self._sessions: |
| 71 | + self._sessions.add(params.sessionId) |
77 | 72 |
|
78 | | - # Notify the client what it just sent and then echo each content block back. |
79 | | - await self._send_chunk( |
80 | | - params.sessionId, |
81 | | - text_block("Client sent:"), |
82 | | - ) |
| 73 | + await self._send_agent_message(params.sessionId, text_block("Client sent:")) |
83 | 74 | for block in params.prompt: |
84 | | - await self._send_chunk(params.sessionId, block) |
| 75 | + await self._send_agent_message(params.sessionId, block) |
85 | 76 |
|
86 | 77 | return PromptResponse(stopReason="end_turn") |
87 | 78 |
|
88 | 79 | async def cancel(self, params: CancelNotification) -> None: # noqa: ARG002 |
89 | | - logging.info("Received cancel notification") |
| 80 | + logging.info("Received cancel notification for session %s", params.sessionId) |
90 | 81 |
|
91 | 82 | async def extMethod(self, method: str, params: dict) -> dict: # noqa: ARG002 |
92 | 83 | logging.info("Received extension method call: %s", method) |
|
0 commit comments