|
3 | 3 |
|
4 | 4 | import argparse |
5 | 5 | import ast |
| 6 | +import json |
6 | 7 | import re |
7 | 8 | import shutil |
8 | 9 | import subprocess |
9 | 10 | import sys |
10 | 11 | from collections.abc import Callable |
| 12 | +from dataclasses import dataclass |
11 | 13 | from pathlib import Path |
12 | 14 |
|
13 | 15 | ROOT = Path(__file__).resolve().parents[1] |
|
25 | 27 | re.DOTALL, |
26 | 28 | ) |
27 | 29 |
|
# Placeholder literal that datamodel-code-generator emits for the stdio
# transport's `type` discriminator field.  NOTE(review): reconstructed from a
# whitespace-collapsed diff; the pattern's leading indent is assumed to be the
# standard 4-space model-field indent — confirm against the generated output.
STDIO_TYPE_LITERAL = 'Literal["2#-datamodel-code-generator-#-object-#-special-#"]'
# Matches the generated `type: Literal[...]` field line, with or without a
# string default, so it can be rewritten to the canonical stdio literal.
STDIO_TYPE_PATTERN = re.compile(
    r"^    type:\s*Literal\[['\"]2#-datamodel-code-generator-#-object-#-special-#['\"]\]"
    r"(?:\s*=\s*['\"][^'\"]+['\"])?\s*$",
    re.MULTILINE,
)
| 36 | + |
28 | 37 | # Map of numbered classes produced by datamodel-code-generator to descriptive names. |
29 | 38 | # Keep this in sync with the Rust/TypeScript SDK nomenclature. |
30 | 39 | RENAME_MAP: dict[str, str] = { |
|
109 | 118 | ) |
110 | 119 |
|
111 | 120 |
|
@dataclass(frozen=True)
class _ProcessingStep:
    """A named transformation applied to the generated schema content."""

    # Human-readable label identifying the step (useful when debugging the pipeline).
    name: str
    # Pure transformation: takes the current file content, returns the new content.
    apply: Callable[[str], str]
| 128 | + |
112 | 129 | def parse_args() -> argparse.Namespace: |
113 | 130 | parser = argparse.ArgumentParser(description="Generate src/acp/schema.py from the ACP JSON schema.") |
114 | 131 | parser.add_argument( |
@@ -159,68 +176,158 @@ def generate_schema(*, format_output: bool = True) -> None: |
159 | 176 | ] |
160 | 177 |
|
161 | 178 | subprocess.check_call(cmd) # noqa: S603 |
162 | | - warnings = rename_types(SCHEMA_OUT) |
| 179 | + warnings = postprocess_generated_schema(SCHEMA_OUT) |
163 | 180 | for warning in warnings: |
164 | 181 | print(f"Warning: {warning}", file=sys.stderr) |
165 | 182 |
|
166 | 183 | if format_output: |
167 | 184 | format_with_ruff(SCHEMA_OUT) |
168 | 185 |
|
169 | 186 |
|
170 | | -def rename_types(output_path: Path) -> list[str]: |
def postprocess_generated_schema(output_path: Path) -> list[str]:
    """Rewrite the freshly generated schema module in place.

    Strips the stale header, drops the previous back-compat alias block,
    renames numbered models, runs the fixed pipeline of content
    transformations, then re-assembles the file with a fresh header and alias
    block.  Returns human-readable warning strings for anything that looks out
    of sync (unrenamed models, missing rename targets, enum drift); the caller
    decides how to surface them.

    Raises:
        RuntimeError: if *output_path* does not exist.
    """
    if not output_path.exists():
        raise RuntimeError(f"Generated schema not found at {output_path}")  # noqa: TRY003

    text = _strip_existing_header(output_path.read_text(encoding="utf-8"))
    text = _remove_backcompat_block(text)
    text, leftover_classes = _rename_numbered_models(text)

    # The pipeline order is significant and preserved verbatim.
    pipeline: tuple[_ProcessingStep, ...] = (
        _ProcessingStep("apply field overrides", _apply_field_overrides),
        _ProcessingStep("apply default overrides", _apply_default_overrides),
        _ProcessingStep("normalize stdio literal", _normalize_stdio_model),
        _ProcessingStep("attach description comments", _add_description_comments),
        _ProcessingStep("ensure custom BaseModel", _ensure_custom_base_model),
    )
    for step in pipeline:
        text = step.apply(text)

    # Rename targets are checked before enum-alias injection, matching the
    # original ordering.
    missing_targets = _find_missing_targets(text)
    text = _inject_enum_aliases(text)

    assembled = _build_header_block() + text.rstrip() + "\n\n" + _build_alias_block()
    if not assembled.endswith("\n"):
        assembled += "\n"
    output_path.write_text(assembled, encoding="utf-8")

    issues: list[str] = []
    if leftover_classes:
        joined = ", ".join(leftover_classes)
        issues.append(f"Unrenamed schema models detected: {joined}. Update RENAME_MAP in scripts/gen_schema.py.")
    if missing_targets:
        joined = ", ".join(sorted(missing_targets))
        issues.append(
            f"Renamed schema targets not found after generation: {joined}. Check RENAME_MAP or upstream schema changes."
        )
    issues.extend(_validate_schema_alignment())
    return issues
| 234 | + |
175 | 235 |
|
def _build_header_block() -> str:
    """Return the do-not-edit banner, plus the schema ref line when available."""
    lines = ["# Generated from schema/schema.json. Do not edit by hand."]
    ref = VERSION_FILE.read_text(encoding="utf-8").strip() if VERSION_FILE.exists() else ""
    if ref:
        lines.append(f"# Schema ref: {ref}")
    return "\n".join(lines) + "\n\n"
| 243 | + |
| 244 | + |
def _build_alias_block() -> str:
    """Render the trailing back-compat alias section (``OldName = NewName`` lines)."""
    body = "\n".join(f"{legacy} = {current}" for legacy, current in sorted(RENAME_MAP.items()))
    return f"{BACKCOMPAT_MARKER}\n{body}\n"
181 | 248 |
|
| 249 | + |
| 250 | +def _strip_existing_header(content: str) -> str: |
182 | 251 | existing_header = re.match(r"(#.*\n)+", content) |
183 | 252 | if existing_header: |
184 | | - content = content[existing_header.end() :] |
185 | | - content = content.lstrip("\n") |
| 253 | + return content[existing_header.end() :].lstrip("\n") |
| 254 | + return content.lstrip("\n") |
| 255 | + |
186 | 256 |
|
def _remove_backcompat_block(content: str) -> str:
    """Cut everything from the first back-compat marker onward, trimming trailing whitespace."""
    head, marker, _tail = content.partition(BACKCOMPAT_MARKER)
    return head.rstrip() if marker else content
| 262 | + |
190 | 263 |
|
def _rename_numbered_models(content: str) -> tuple[str, list[str]]:
    """Substitute each numbered model name with its descriptive RENAME_MAP name.

    Names are processed longest-first (preserving the original deterministic
    ordering).  Returns the rewritten content together with a sorted list of
    class names that still end in a digit, i.e. were not covered by the map.
    """
    result = content
    longest_first = sorted(RENAME_MAP.items(), key=lambda kv: len(kv[0]), reverse=True)
    for numbered, descriptive in longest_first:
        result = re.sub(rf"\b{re.escape(numbered)}\b", descriptive, result)

    leftovers = re.findall(r"^class (\w+\d+)\(", result, flags=re.MULTILINE)
    return result, sorted(set(leftovers))
197 | 273 |
|
198 | | - header_block = "\n".join(header_lines) + "\n\n" |
199 | | - content = _apply_field_overrides(content) |
200 | | - content = _apply_default_overrides(content) |
201 | | - content = _add_description_comments(content) |
202 | | - content = _ensure_custom_base_model(content) |
203 | 274 |
|
204 | | - alias_lines = [f"{old} = {new}" for old, new in sorted(RENAME_MAP.items())] |
205 | | - alias_block = BACKCOMPAT_MARKER + "\n" + "\n".join(alias_lines) + "\n" |
def _find_missing_targets(content: str) -> list[str]:
    """List RENAME_MAP target names that have no class definition in *content*."""
    return [
        target
        for target in RENAME_MAP.values()
        if re.search(rf"^class {re.escape(target)}\(", content, re.MULTILINE) is None
    ]
206 | 282 |
|
207 | | - content = _inject_enum_aliases(content) |
208 | | - content = header_block + content.rstrip() + "\n\n" + alias_block |
209 | | - if not content.endswith("\n"): |
210 | | - content += "\n" |
211 | | - output_path.write_text(content, encoding="utf-8") |
212 | 283 |
|
def _validate_schema_alignment() -> list[str]:
    """Cross-check ENUM_LITERAL_MAP against the enums declared in schema.json.

    Returns warning strings for a missing schema file, an unparseable schema,
    enums absent from the schema, or value mismatches.  An empty list means
    everything lines up.
    """
    if not SCHEMA_JSON.exists():
        return ["schema/schema.json missing; unable to validate enum aliases."]

    try:
        schema_enums = _load_schema_enum_literals()
    except json.JSONDecodeError as exc:
        return [f"Failed to parse schema/schema.json: {exc}"]

    problems: list[str] = []
    for enum_name, expected_values in ENUM_LITERAL_MAP.items():
        actual = schema_enums.get(enum_name)
        if actual is None:
            problems.append(
                f"Enum '{enum_name}' not found in schema.json; update ENUM_LITERAL_MAP or investigate schema changes."
            )
        elif tuple(actual) != expected_values:
            problems.append(
                f"Enum mismatch for '{enum_name}': schema.json -> {actual}, generated aliases -> {expected_values}"
            )
    return problems
222 | 308 |
|
223 | 309 |
|
def _load_schema_enum_literals() -> dict[str, tuple[str, ...]]:
    """Extract enum-like value tuples from the ``$defs`` section of schema.json.

    Handles both plain ``enum`` arrays and ``oneOf`` lists of ``const``
    options.  Definitions with neither form are skipped, as are non-mapping
    definitions (JSON Schema permits boolean schemas such as ``"Foo": true``,
    which would otherwise crash the ``in`` membership test).

    Returns:
        Mapping of definition name to its ordered tuple of stringified values.

    Raises:
        json.JSONDecodeError: if schema.json is not valid JSON.
    """
    schema_data = json.loads(SCHEMA_JSON.read_text(encoding="utf-8"))
    defs = schema_data.get("$defs", {})
    enum_literals: dict[str, tuple[str, ...]] = {}

    for name, definition in defs.items():
        if not isinstance(definition, dict):
            # Boolean/non-object schemas carry no enum values; skip safely.
            continue
        values: list[str] = []
        if "enum" in definition:
            values = [str(item) for item in definition["enum"]]
        elif "oneOf" in definition:
            # Membership already checked above, so index directly.
            values = [
                str(option["const"])
                for option in definition["oneOf"]
                if isinstance(option, dict) and "const" in option
            ]
        if values:
            enum_literals[name] = tuple(values)

    return enum_literals
| 329 | + |
| 330 | + |
224 | 331 | def _ensure_custom_base_model(content: str) -> str: |
225 | 332 | if "class BaseModel(_BaseModel):" in content: |
226 | 333 | return content |
@@ -323,6 +430,19 @@ def replace_block( |
323 | 430 | return content |
324 | 431 |
|
325 | 432 |
|
def _normalize_stdio_model(content: str) -> str:
    """Rewrite the generator's stdio placeholder field to ``Literal["stdio"]``.

    Warns on stderr when more than one placeholder line matches, since only a
    single stdio model is expected in the generated output.
    """
    normalized, hits = STDIO_TYPE_PATTERN.subn('    type: Literal["stdio"] = "stdio"', content)
    if hits == 0:
        return content
    if hits > 1:
        print(
            "Warning: multiple stdio type placeholders detected; manual review required.",
            file=sys.stderr,
        )
    return normalized
| 444 | + |
| 445 | + |
326 | 446 | def _add_description_comments(content: str) -> str: |
327 | 447 | lines = content.splitlines() |
328 | 448 | new_lines: list[str] = [] |
|
0 commit comments