11# Author: pstefanou12@
22"""Pydantic configuration models for delphi.ai algorithms."""
33
4- from __future__ import annotations
4+ import pydantic
55
6- from pydantic import BaseModel , ConfigDict , Field , model_validator
76
8-
9- def make_config (args : dict | BaseModel , config_class : type [BaseModel ]) -> BaseModel :
7+ def make_config (
8+ args : dict | pydantic .BaseModel , config_class : type [pydantic .BaseModel ]
9+ ) -> pydantic .BaseModel :
1010 """Construct a Pydantic config from a dict or return an existing config.
1111
1212 Args:
@@ -28,7 +28,7 @@ def make_config(args: dict | BaseModel, config_class: type[BaseModel]) -> BaseMo
2828 )
2929
3030
31- class TrainerConfig (BaseModel ):
31+ class TrainerConfig (pydantic . BaseModel ):
3232 """Configuration for the Trainer.
3333
3434 Attributes:
@@ -57,33 +57,33 @@ class TrainerConfig(BaseModel):
5757 checkpoint_every: Epoch frequency for periodic checkpoints; 0 disables.
5858 """
5959
60- model_config = ConfigDict (extra = "ignore" )
60+ model_config = pydantic . ConfigDict (extra = "ignore" )
6161
62- epochs : int | None = Field (default = None , ge = 1 )
63- iterations : int | None = Field (default = None , ge = 1 )
64- trials : int = Field (default = 1 , ge = 1 )
65- ema_decay : float = Field (default = 0.99 , ge = 0.0 , le = 1.0 )
66- tol : float = Field (default = 1e-3 , ge = 0.0 )
62+ epochs : int | None = pydantic . Field (default = None , ge = 1 )
63+ iterations : int | None = pydantic . Field (default = None , ge = 1 )
64+ trials : int = pydantic . Field (default = 1 , ge = 1 )
65+ ema_decay : float = pydantic . Field (default = 0.99 , ge = 0.0 , le = 1.0 )
66+ tol : float = pydantic . Field (default = 1e-3 , ge = 0.0 )
6767 early_stopping : bool = False
6868 verbose : bool = False
6969 disable_no_grad : bool = False
70- val_interval : int | None = Field (default = None , ge = 1 )
71- patience : int | None = Field (default = None , ge = 1 )
72- grad_tol : float = Field (default = 0.0 , ge = 0.0 )
73- grad_tol_window : int = Field (default = 1 , ge = 1 )
70+ val_interval : int | None = pydantic . Field (default = None , ge = 1 )
71+ patience : int | None = pydantic . Field (default = None , ge = 1 )
72+ grad_tol : float = pydantic . Field (default = 0.0 , ge = 0.0 )
73+ grad_tol_window : int = pydantic . Field (default = 1 , ge = 1 )
7474 loss_tol : float | None = None
75- log_every : int = Field (default = 50 , ge = 1 )
75+ log_every : int = pydantic . Field (default = 50 , ge = 1 )
7676 max_grad_norm : float | None = None
7777 tqdm : bool = False
7878 device : str = "cpu"
7979 use_amp : bool = False
80- accumulate_grad_batches : int = Field (default = 1 , ge = 1 )
81- record_params_every : int = Field (default = 0 , ge = 0 )
80+ accumulate_grad_batches : int = pydantic . Field (default = 1 , ge = 1 )
81+ record_params_every : int = pydantic . Field (default = 0 , ge = 0 )
8282 checkpoint_dir : str | None = None
83- checkpoint_every : int = Field (default = 0 , ge = 0 )
83+ checkpoint_every : int = pydantic . Field (default = 0 , ge = 0 )
8484
8585
86- class OptimizerConfig (BaseModel ):
86+ class OptimizerConfig (pydantic . BaseModel ):
8787 """Configuration for optimizers.
8888
8989 Attributes:
@@ -100,13 +100,13 @@ class OptimizerConfig(BaseModel):
100100 scheduler: Learning-rate scheduler type; None disables scheduling.
101101 """
102102
103- model_config = ConfigDict (extra = "ignore" )
103+ model_config = pydantic . ConfigDict (extra = "ignore" )
104104
105105 optimizer : str = "sgd"
106- lr : float = Field (default = 0.1 , gt = 0.0 )
107- momentum : float = Field (default = 0.0 , ge = 0.0 )
108- dampening : float = Field (default = 0.0 , ge = 0.0 )
109- weight_decay : float = Field (default = 0.0 , ge = 0.0 )
106+ lr : float = pydantic . Field (default = 0.1 , gt = 0.0 )
107+ momentum : float = pydantic . Field (default = 0.0 , ge = 0.0 )
108+ dampening : float = pydantic . Field (default = 0.0 , ge = 0.0 )
109+ weight_decay : float = pydantic . Field (default = 0.0 , ge = 0.0 )
110110 nesterov : bool = False
111111 maximize : bool = False
112112 foreach : bool | None = None
@@ -135,29 +135,29 @@ class TruncatedExponentialFamilyDistributionConfig(TrainerConfig, OptimizerConfi
135135 project: Enable per-step sublevel-set projection.
136136 """
137137
138- model_config = ConfigDict (extra = "ignore" )
138+ model_config = pydantic . ConfigDict (extra = "ignore" )
139139
140140 # Override parent defaults for distribution training.
141- tol : float = Field (default = 1e-1 , ge = 0.0 )
142- record_params_every : int = Field (default = 1 , ge = 1 )
143- epochs : int | None = Field (default = 1 , ge = 1 )
141+ tol : float = pydantic . Field (default = 1e-1 , ge = 0.0 )
142+ record_params_every : int = pydantic . Field (default = 1 , ge = 1 )
143+ epochs : int | None = pydantic . Field (default = 1 , ge = 1 )
144144
145145 # Distribution-specific fields.
146- val : float = Field (default = 0.2 , ge = 0.0 , le = 1.0 )
147- eps : float = Field (default = 1e-5 , gt = 0.0 )
148- min_radius : float = Field (default = 3.0 , ge = 0.0 )
149- max_radius : float = Field (default = 10.0 , ge = 0.0 )
150- rate : float = Field (default = 1.1 , gt = 1.0 )
151- batch_size : int = Field (default = 10 , ge = 1 )
152- num_samples : int = Field (default = 10000 , ge = 1 )
153- max_phases : int = Field (default = 1 , ge = 1 )
154- loss_convergence_tol : float = Field (default = 1e-3 , ge = 0.0 )
155- relative_loss_tol : float = Field (default = float ("inf" ), ge = 0.0 )
156- loss_increase_tol : float = Field (default = float ("inf" ), ge = 0.0 )
146+ val : float = pydantic . Field (default = 0.2 , ge = 0.0 , le = 1.0 )
147+ eps : float = pydantic . Field (default = 1e-5 , gt = 0.0 )
148+ min_radius : float = pydantic . Field (default = 3.0 , ge = 0.0 )
149+ max_radius : float = pydantic . Field (default = 10.0 , ge = 0.0 )
150+ rate : float = pydantic . Field (default = 1.1 , gt = 1.0 )
151+ batch_size : int = pydantic . Field (default = 10 , ge = 1 )
152+ num_samples : int = pydantic . Field (default = 10000 , ge = 1 )
153+ max_phases : int = pydantic . Field (default = 1 , ge = 1 )
154+ loss_convergence_tol : float = pydantic . Field (default = 1e-3 , ge = 0.0 )
155+ relative_loss_tol : float = pydantic . Field (default = float ("inf" ), ge = 0.0 )
156+ loss_increase_tol : float = pydantic . Field (default = float ("inf" ), ge = 0.0 )
157157 project : bool = True
158158
159- @model_validator (mode = "after" )
160- def check_radius (self ) -> TruncatedExponentialFamilyDistributionConfig :
159+ @pydantic . model_validator (mode = "after" )
160+ def check_radius (self ) -> " TruncatedExponentialFamilyDistributionConfig" :
161161 """Validate that min_radius does not exceed max_radius."""
162162 if self .min_radius > self .max_radius :
163163 raise ValueError (
@@ -166,10 +166,10 @@ def check_radius(self) -> TruncatedExponentialFamilyDistributionConfig:
166166 )
167167 return self
168168
169- @model_validator (mode = "after" )
169+ @pydantic . model_validator (mode = "after" )
170170 def resolve_epochs_iterations (
171171 self ,
172- ) -> TruncatedExponentialFamilyDistributionConfig :
172+ ) -> " TruncatedExponentialFamilyDistributionConfig" :
173173 """Clear the default epochs when iterations is explicitly provided.
174174
175175 The trainer uses exactly one stopping criterion. When the user
@@ -194,5 +194,5 @@ class TruncatedMultivariateNormalConfig(TruncatedExponentialFamilyDistributionCo
194194 covariance matrix parameter; falls back to lr when None.
195195 """
196196
197- eigenvalue_lower_bound : float = Field (default = 1e-2 , gt = 0.0 )
198- covariance_matrix_lr : float | None = Field (default = None , gt = 0.0 )
197+ eigenvalue_lower_bound : float = pydantic . Field (default = 1e-2 , gt = 0.0 )
198+ covariance_matrix_lr : float | None = pydantic . Field (default = None , gt = 0.0 )
0 commit comments