Skip to content

Commit 893fcf4

Browse files
authored
fix: preserve _meta during ACP serialization (#20)
Signed-off-by: Chojan Shang <[email protected]>
1 parent 4049995 commit 893fcf4

File tree

6 files changed

+91
-11
lines changed

6 files changed

+91
-11
lines changed

examples/echo_agent.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,15 @@ async def newSession(self, params: NewSessionRequest) -> NewSessionResponse:
2929

3030
async def prompt(self, params: PromptRequest) -> PromptResponse:
3131
for block in params.prompt:
32-
text = getattr(block, "text", "")
33-
await self._conn.sessionUpdate(
34-
session_notification(
35-
params.sessionId,
36-
update_agent_message(text_block(text)),
37-
)
38-
)
32+
text = block.get("text", "") if isinstance(block, dict) else getattr(block, "text", "")
33+
chunk = update_agent_message(text_block(text))
34+
chunk.field_meta = {"echo": True}
35+
chunk.content.field_meta = {"echo": True}
36+
37+
notification = session_notification(params.sessionId, chunk)
38+
notification.field_meta = {"source": "echo_agent"}
39+
40+
await self._conn.sessionUpdate(notification)
3941
return PromptResponse(stopReason="end_turn")
4042

4143

scripts/gen_schema.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def rename_types(output_path: Path) -> list[str]:
199199
content = _apply_field_overrides(content)
200200
content = _apply_default_overrides(content)
201201
content = _add_description_comments(content)
202+
content = _ensure_custom_base_model(content)
202203

203204
alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())]
204205
alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n"
@@ -220,6 +221,37 @@ def rename_types(output_path: Path) -> list[str]:
220221
return warnings
221222

222223

224+
def _ensure_custom_base_model(content: str) -> str:
225+
if "class BaseModel(_BaseModel):" in content:
226+
return content
227+
lines = content.splitlines()
228+
for idx, line in enumerate(lines):
229+
if not line.startswith("from pydantic import "):
230+
continue
231+
imports = [part.strip() for part in line[len("from pydantic import ") :].split(",")]
232+
has_alias = any(part == "BaseModel as _BaseModel" for part in imports)
233+
has_config = any(part == "ConfigDict" for part in imports)
234+
new_imports = []
235+
for part in imports:
236+
if part == "BaseModel":
237+
new_imports.append("BaseModel as _BaseModel")
238+
has_alias = True
239+
else:
240+
new_imports.append(part)
241+
if not has_alias:
242+
new_imports.append("BaseModel as _BaseModel")
243+
if not has_config:
244+
new_imports.append("ConfigDict")
245+
lines[idx] = "from pydantic import " + ", ".join(new_imports)
246+
insert_idx = idx + 1
247+
lines.insert(insert_idx, "")
248+
lines.insert(insert_idx + 1, "class BaseModel(_BaseModel):")
249+
lines.insert(insert_idx + 2, " model_config = ConfigDict(populate_by_name=True)")
250+
lines.insert(insert_idx + 3, "")
251+
break
252+
return "\n".join(lines) + "\n"
253+
254+
223255
def _apply_field_overrides(content: str) -> str:
224256
for class_name, field_name, new_type, optional in FIELD_TYPE_OVERRIDES:
225257
if optional:

src/acp/connection.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,12 @@ async def _run_request(self, message: dict[str, Any]) -> Any:
192192
try:
193193
result = await self._handler(method, message.get("params"), False)
194194
if isinstance(result, BaseModel):
195-
result = result.model_dump()
195+
result = result.model_dump(
196+
mode="json",
197+
by_alias=True,
198+
exclude_none=True,
199+
exclude_unset=True,
200+
)
196201
payload["result"] = result if result is not None else None
197202
await self._sender.send(payload)
198203
self._notify_observers(StreamDirection.OUTGOING, payload)

src/acp/schema.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from enum import Enum
77
from typing import Annotated, Any, List, Literal, Optional, Union
88

9-
from pydantic import BaseModel, Field, RootModel
10-
9+
from pydantic import BaseModel as _BaseModel, Field, RootModel, ConfigDict
1110

1211
PermissionOptionKind = Literal["allow_once", "allow_always", "reject_once", "reject_always"]
1312
PlanEntryPriority = Literal["high", "medium", "low"]
@@ -17,6 +16,10 @@
1716
ToolKind = Literal["read", "edit", "delete", "move", "search", "execute", "think", "fetch", "switch_mode", "other"]
1817

1918

19+
class BaseModel(_BaseModel):
20+
model_config = ConfigDict(populate_by_name=True)
21+
22+
2023
class Jsonrpc(Enum):
2124
field_2_0 = "2.0"
2225

src/acp/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
def serialize_params(params: BaseModel) -> dict[str, Any]:
2626
"""Return a JSON-serializable representation used for RPC calls."""
27-
return params.model_dump(exclude_none=True, exclude_defaults=True)
27+
return params.model_dump(by_alias=True, exclude_none=True, exclude_defaults=True)
2828

2929

3030
def normalize_result(payload: Any) -> dict[str, Any]:

tests/test_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from acp.schema import AgentMessageChunk, TextContentBlock
2+
from acp.utils import serialize_params
3+
4+
5+
def test_serialize_params_uses_meta_aliases() -> None:
6+
chunk = AgentMessageChunk(
7+
sessionUpdate="agent_message_chunk",
8+
content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}),
9+
field_meta={"outer": "value"},
10+
)
11+
12+
payload = serialize_params(chunk)
13+
14+
assert payload["_meta"] == {"outer": "value"}
15+
assert payload["content"]["_meta"] == {"inner": "value"}
16+
17+
18+
def test_serialize_params_omits_meta_when_absent() -> None:
19+
chunk = AgentMessageChunk(
20+
sessionUpdate="agent_message_chunk",
21+
content=TextContentBlock(type="text", text="demo"),
22+
)
23+
24+
payload = serialize_params(chunk)
25+
26+
assert "_meta" not in payload
27+
assert "_meta" not in payload["content"]
28+
29+
30+
def test_field_meta_can_be_set_by_name_on_models() -> None:
31+
chunk = AgentMessageChunk(
32+
sessionUpdate="agent_message_chunk",
33+
content=TextContentBlock(type="text", text="demo", field_meta={"inner": "value"}),
34+
field_meta={"outer": "value"},
35+
)
36+
37+
assert chunk.field_meta == {"outer": "value"}
38+
assert chunk.content.field_meta == {"inner": "value"}

0 commit comments

Comments
 (0)