8
8
from collections .abc import Iterable
9
9
from copy import deepcopy
10
10
from pathlib import Path
11
- from typing import Any , Optional
11
+ from typing import Any , Optional , get_args , get_origin
12
12
13
- from pydantic . v1 import BaseModel
13
+ from pydantic import BaseModel
14
14
15
15
from tidy3d .log import log
16
16
@@ -61,7 +61,7 @@ def dict(self, *args, **kwargs): # type: ignore[override]
61
61
model = self ._manager ._get_model (self ._path )
62
62
if model is None :
63
63
return {}
64
- return model .dict (* args , ** kwargs )
64
+ return model .model_dump (* args , ** kwargs )
65
65
66
66
67
67
class PluginsAccessor :
@@ -380,7 +380,7 @@ def __setattr__(self, name: str, value: Any) -> None:
380
380
return
381
381
if name in self ._section_models :
382
382
if isinstance (value , BaseModel ):
383
- payload = value .dict (exclude_unset = False )
383
+ payload = value .model_dump (exclude_unset = False )
384
384
else :
385
385
payload = value
386
386
self .update_section (name , ** payload )
@@ -404,16 +404,33 @@ def _deep_get(tree: dict[str, Any], path: Iterable[str]) -> Optional[dict[str, A
404
404
return node if isinstance (node , dict ) else None
405
405
406
406
407
+ def _resolve_model_type (annotation : Any ) -> Optional [type [BaseModel ]]:
408
+ """Return the first BaseModel subclass found in an annotation (if any)."""
409
+
410
+ if isinstance (annotation , type ) and issubclass (annotation , BaseModel ):
411
+ return annotation
412
+
413
+ origin = get_origin (annotation )
414
+ if origin is None :
415
+ return None
416
+
417
+ for arg in get_args (annotation ):
418
+ nested = _resolve_model_type (arg )
419
+ if nested is not None :
420
+ return nested
421
+ return None
422
+
423
+
407
424
def _serialize_value (value : Any ) -> Any :
408
425
if isinstance (value , BaseModel ):
409
- return value .dict (exclude_unset = False )
426
+ return value .model_dump (exclude_unset = False )
410
427
if hasattr (value , "get_secret_value" ):
411
428
return value .get_secret_value ()
412
429
return value
413
430
414
431
415
432
def _model_dict (model : BaseModel ) -> dict [str , Any ]:
416
- data = model .dict (exclude_unset = False )
433
+ data = model .model_dump (exclude_unset = False )
417
434
for key , value in list (data .items ()):
418
435
if hasattr (value , "get_secret_value" ):
419
436
data [key ] = value .get_secret_value ()
@@ -422,10 +439,10 @@ def _model_dict(model: BaseModel) -> dict[str, Any]:
422
439
423
440
def _extract_persisted (schema : type [BaseModel ], data : dict [str , Any ]) -> dict [str , Any ]:
424
441
persisted : dict [str , Any ] = {}
425
- for field_name , field in schema .__fields__ .items ():
426
- extra = getattr ( field .field_info , "extra" , {}) or {}
427
- schema_extra = extra . get ( "json_schema_extra" , {})
428
- persist = schema_extra .get ("persist" ) if isinstance (schema_extra , dict ) else False
442
+ for field_name , field in schema .model_fields .items ():
443
+ schema_extra = field .json_schema_extra or {}
444
+ annotation = field . annotation
445
+ persist = bool ( schema_extra .get ("persist" ) ) if isinstance (schema_extra , dict ) else False
429
446
if not persist :
430
447
continue
431
448
if field_name not in data :
@@ -435,10 +452,10 @@ def _extract_persisted(schema: type[BaseModel], data: dict[str, Any]) -> dict[st
435
452
persisted [field_name ] = None
436
453
continue
437
454
438
- field_type = getattr ( field , "type_" , None )
439
- if isinstance ( field_type , type ) and issubclass ( field_type , BaseModel ) :
455
+ nested_type = _resolve_model_type ( annotation )
456
+ if nested_type is not None :
440
457
nested_source = value if isinstance (value , dict ) else {}
441
- nested_persisted = _extract_persisted (field_type , nested_source )
458
+ nested_persisted = _extract_persisted (nested_type , nested_source )
442
459
if nested_persisted :
443
460
persisted [field_name ] = nested_persisted
444
461
continue
0 commit comments