 
 import numpy as np
 import torch as ch
+from pydantic import BaseModel
 from torch.optim import SGD, LBFGS, Adam, AdamW, lr_scheduler
 
 from delphi.utils.constants import (
     PythonFrameworks,
     SchedulerType,
 )
-from delphi.utils.defaults import (
-    check_and_fill_args,
-    DELPHI_DEFAULTS,
-    SGD_DEFAULTS,
-    LBFGS_DEFAULTS,
-    ADAM_DEFAULTS,
-    ADAMW_DEFAULTS,
-)
-from delphi.utils.helpers import AverageMeter, Parameters
+from delphi.utils.helpers import AverageMeter
 
 
 class delphi(ch.nn.Module):  # pylint: disable=invalid-name,too-many-instance-attributes,abstract-method
@@ -40,22 +33,23 @@ class delphi(ch.nn.Module):  # pylint: disable=invalid-name,too-many-instance-attributes,abstract-method
 
     _OPTIMIZER_REGISTRY: ClassVar[dict[str, Callable]] = {}
 
-    def __init__(self, args: Parameters):
+    def __init__(self, args: BaseModel):
         """Initialize the delphi model.
 
         Args:
-            args: Hyperparameter object; see DELPHI_DEFAULTS for supported keys.
+            args: Fully constructed Pydantic config. Concrete subclasses are
+                responsible for converting a user-supplied dict to their
+                specific config before calling super().__init__.
 
         Raises:
-            TypeError: If args is not a Parameters instance.
+            TypeError: If args is not a Pydantic BaseModel.
         """
         super().__init__()
-        if not isinstance(args, Parameters):
+        if not isinstance(args, BaseModel):
             raise TypeError(
-                f"args is type {type(args).__name__}; "
-                "expected delphi.utils.helpers.Parameters"
+                f"args is type {type(args).__name__}; expected a pydantic.BaseModel."
             )
-        self.args: Parameters = check_and_fill_args(args, DELPHI_DEFAULTS)
+        self.args = args
 
         self.optimizer: ch.optim.Optimizer | None = None
         self.schedule: lr_scheduler.LRScheduler | None = None
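
The rewritten docstring above pushes dict-to-config conversion down to concrete subclasses. A minimal sketch of that pattern, assuming a hypothetical `TruncatedRegressionConfig` and `TruncatedRegression` subclass (both names, and the import path, are illustrative, not part of this change):

```python
from pydantic import BaseModel

from delphi.delphi import delphi  # assumed import path for the base class in this diff


class TruncatedRegressionConfig(BaseModel):
    """Hypothetical per-estimator config; field defaults stand in for DELPHI_DEFAULTS."""

    lr: float = 1e-1
    momentum: float = 0.0


class TruncatedRegression(delphi):
    def __init__(self, args: dict | TruncatedRegressionConfig):
        # The subclass converts a user-supplied dict into its specific config
        # before calling super().__init__, as the new docstring requires.
        if isinstance(args, dict):
            args = TruncatedRegressionConfig(**args)
        super().__init__(args)
```

Passing a raw dict straight to `delphi.__init__` still fails the `isinstance(args, BaseModel)` check and raises the `TypeError` shown above.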
@@ -149,7 +143,6 @@ def _remove_none_config(self, config: dict) -> dict:
 
     def _create_sgd(self, params: list[dict]) -> SGD:
         """Create an SGD optimizer from args."""
-        check_and_fill_args(self.args, SGD_DEFAULTS)
         config = {
             "lr": self.args.lr,
             "momentum": getattr(self.args, "momentum", 0),
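
One consequence of dropping the `check_and_fill_args(self.args, SGD_DEFAULTS)` call: `_create_sgd` reads `self.args.lr` directly, so the config must declare `lr`, while `momentum` falls back through `getattr`. A minimal sketch of a config that satisfies this, assuming a hypothetical `SGDConfig` (name and field set are illustrative):

```python
from pydantic import BaseModel


class SGDConfig(BaseModel):
    """Hypothetical config; lr has no default because _create_sgd accesses it directly."""

    lr: float
    momentum: float = 0.0  # mirrors getattr(self.args, "momentum", 0)


args = SGDConfig(lr=0.01)
assert args.lr == 0.01
assert getattr(args, "momentum", 0) == 0.0
```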
@@ -165,7 +158,6 @@ def _create_sgd(self, params: list[dict]) -> SGD:
 
     def _create_lbfgs(self, params: list[dict]) -> LBFGS:
         """Create an L-BFGS optimizer from args."""
-        check_and_fill_args(self.args, LBFGS_DEFAULTS)
         config = {
             "lr": getattr(self.args, "lr", 1.0),
             "max_iter": getattr(self.args, "max_iter", 20),
@@ -179,7 +171,6 @@ def _create_lbfgs(self, params: list[dict]) -> LBFGS:
 
     def _create_adam(self, params: list[dict]) -> Adam:
         """Create an Adam optimizer from args."""
-        check_and_fill_args(self.args, ADAM_DEFAULTS)
         config = {
             "lr": getattr(self.args, "lr", 1e-1),
             "betas": (
@@ -199,7 +190,6 @@ def _create_adam(self, params: list[dict]) -> Adam:
 
     def _create_adamw(self, params: list[dict]) -> AdamW:
         """Create an AdamW optimizer from args."""
-        check_and_fill_args(self.args, ADAMW_DEFAULTS)
         config = {
             "lr": getattr(self.args, "lr", 1e-3),
             "betas": (
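
With the per-optimizer `*_DEFAULTS` fills gone, defaults live in two places: the `getattr(..., default)` fallbacks inside each factory, and whatever field defaults the Pydantic config declares. A small sketch of how the two interact, using a hypothetical `MinimalConfig` (the name and its `weight_decay` field are illustrative only):

```python
from pydantic import BaseModel


class MinimalConfig(BaseModel):
    """Hypothetical config that declares only a single field."""

    weight_decay: float = 1e-2


args = MinimalConfig()

# Undeclared fields raise AttributeError on a BaseModel, so the factory
# fallbacks apply, e.g. the "lr" lookup in _create_adamw:
assert getattr(args, "lr", 1e-3) == 1e-3

# Declared fields (even if left at their default) are found directly:
assert getattr(args, "weight_decay", 0.0) == 1e-2
```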