Skip to content

Commit 037a070

Browse files
committed
switch to pydantic v2
1 parent d9e7927 commit 037a070

File tree

7 files changed

+121
-78
lines changed

7 files changed

+121
-78
lines changed

tests/config/test_loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
from pathlib import Path
44

5-
import pydantic.v1 as pd
65
from click.testing import CliRunner
6+
from pydantic import Field
77

88
from tidy3d.config import get_manager, reload_config
99
from tidy3d.config import registry as config_registry
@@ -49,7 +49,7 @@ def test_profile_preserves_comments(config_manager, mock_config_dir):
4949
class ProfileComment(ConfigSection):
5050
"""Profile comment plugin."""
5151

52-
knob: int = pd.Field(
52+
knob: int = Field(
5353
1,
5454
description="Profile knob description.",
5555
json_schema_extra={"persist": True},
@@ -83,7 +83,7 @@ def test_cli_reset_config(mock_config_dir):
8383
class CLIPlugin(ConfigSection):
8484
"""CLI plugin configuration."""
8585

86-
knob: int = pd.Field(
86+
knob: int = Field(
8787
3,
8888
description="CLI knob description.",
8989
json_schema_extra={"persist": True},
@@ -120,7 +120,7 @@ def test_plugin_descriptions(mock_config_dir):
120120
class CommentPlugin(ConfigSection):
121121
"""Comment plugin configuration."""
122122

123-
knob: int = pd.Field(
123+
knob: int = Field(
124124
3,
125125
description="Plugin knob description.",
126126
json_schema_extra={"persist": True},

tests/config/test_plugins.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
import pydantic.v1 as pd
43
import toml
4+
from pydantic import Field
55

66
from tidy3d.config.__init__ import get_manager, reload_config
77
from tidy3d.config.registry import get_sections, register_plugin
@@ -14,8 +14,8 @@ def ensure_dummy_plugin():
1414

1515
@register_plugin("dummy")
1616
class DummyPlugin(ConfigSection):
17-
enabled: bool = pd.Field(False, json_schema_extra={"persist": True})
18-
precision: int = pd.Field(1, json_schema_extra={"persist": True})
17+
enabled: bool = Field(False, json_schema_extra={"persist": True})
18+
precision: int = Field(1, json_schema_extra={"persist": True})
1919

2020

2121
def test_plugin_defaults_available(mock_config_dir):

tests/config/test_profiles.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,6 @@ def test_save_custom_profile(config_manager):
2020
config_manager.save()
2121

2222
profile_path = config_manager.config_dir / "profiles" / "customer.toml"
23-
assert not profile_path.exists()
23+
assert profile_path.exists()
24+
data = toml.load(profile_path)
25+
assert data["logging"]["level"] == "DEBUG"

tidy3d/config/manager.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from collections.abc import Iterable
99
from copy import deepcopy
1010
from pathlib import Path
11-
from typing import Any, Optional
11+
from typing import Any, Optional, get_args, get_origin
1212

13-
from pydantic.v1 import BaseModel
13+
from pydantic import BaseModel
1414

1515
from tidy3d.log import log
1616

@@ -61,7 +61,7 @@ def dict(self, *args, **kwargs): # type: ignore[override]
6161
model = self._manager._get_model(self._path)
6262
if model is None:
6363
return {}
64-
return model.dict(*args, **kwargs)
64+
return model.model_dump(*args, **kwargs)
6565

6666

6767
class PluginsAccessor:
@@ -380,7 +380,7 @@ def __setattr__(self, name: str, value: Any) -> None:
380380
return
381381
if name in self._section_models:
382382
if isinstance(value, BaseModel):
383-
payload = value.dict(exclude_unset=False)
383+
payload = value.model_dump(exclude_unset=False)
384384
else:
385385
payload = value
386386
self.update_section(name, **payload)
@@ -404,16 +404,33 @@ def _deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, A
404404
return node if isinstance(node, dict) else None
405405

406406

407+
def _resolve_model_type(annotation: Any) -> Optional[type[BaseModel]]:
408+
"""Return the first BaseModel subclass found in an annotation (if any)."""
409+
410+
if isinstance(annotation, type) and issubclass(annotation, BaseModel):
411+
return annotation
412+
413+
origin = get_origin(annotation)
414+
if origin is None:
415+
return None
416+
417+
for arg in get_args(annotation):
418+
nested = _resolve_model_type(arg)
419+
if nested is not None:
420+
return nested
421+
return None
422+
423+
407424
def _serialize_value(value: Any) -> Any:
408425
if isinstance(value, BaseModel):
409-
return value.dict(exclude_unset=False)
426+
return value.model_dump(exclude_unset=False)
410427
if hasattr(value, "get_secret_value"):
411428
return value.get_secret_value()
412429
return value
413430

414431

415432
def _model_dict(model: BaseModel) -> dict[str, Any]:
416-
data = model.dict(exclude_unset=False)
433+
data = model.model_dump(exclude_unset=False)
417434
for key, value in list(data.items()):
418435
if hasattr(value, "get_secret_value"):
419436
data[key] = value.get_secret_value()
@@ -422,10 +439,10 @@ def _model_dict(model: BaseModel) -> dict[str, Any]:
422439

423440
def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[str, Any]:
424441
persisted: dict[str, Any] = {}
425-
for field_name, field in schema.__fields__.items():
426-
extra = getattr(field.field_info, "extra", {}) or {}
427-
schema_extra = extra.get("json_schema_extra", {})
428-
persist = schema_extra.get("persist") if isinstance(schema_extra, dict) else False
442+
for field_name, field in schema.model_fields.items():
443+
schema_extra = field.json_schema_extra or {}
444+
annotation = field.annotation
445+
persist = bool(schema_extra.get("persist")) if isinstance(schema_extra, dict) else False
429446
if not persist:
430447
continue
431448
if field_name not in data:
@@ -435,10 +452,10 @@ def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[st
435452
persisted[field_name] = None
436453
continue
437454

438-
field_type = getattr(field, "type_", None)
439-
if isinstance(field_type, type) and issubclass(field_type, BaseModel):
455+
nested_type = _resolve_model_type(annotation)
456+
if nested_type is not None:
440457
nested_source = value if isinstance(value, dict) else {}
441-
nested_persisted = _extract_persisted(field_type, nested_source)
458+
nested_persisted = _extract_persisted(nested_type, nested_source)
442459
if nested_persisted:
443460
persisted[field_name] = nested_persisted
444461
continue

tidy3d/config/registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from typing import Callable, Optional, TypeVar
66

7-
from pydantic.v1 import BaseModel
7+
from pydantic import BaseModel
88

99
T = TypeVar("T", bound=BaseModel)
1010

0 commit comments

Comments
 (0)