1
1
from collections import OrderedDict
2
2
from collections .abc import Mapping
3
- from typing import Tuple
3
+ from typing import List , Tuple
4
4
5
5
6
6
class Serializable :
7
- _state_dict_all_req_keys : Tuple = ()
8
- _state_dict_one_of_opt_keys : Tuple = ()
7
+ _state_dict_all_req_keys : Tuple [str , ...] = ()
8
+ _state_dict_one_of_opt_keys : Tuple [Tuple [str , ...], ...] = ((),)
9
+
10
+ def __init__ (self ) -> None :
11
+ self ._state_dict_user_keys : List [str ] = []
12
+
13
+ @property
14
+ def state_dict_user_keys (self ) -> List :
15
+ return self ._state_dict_user_keys
9
16
10
17
def state_dict (self ) -> OrderedDict :
11
18
raise NotImplementedError
@@ -19,6 +26,21 @@ def load_state_dict(self, state_dict: Mapping) -> None:
19
26
raise ValueError (
20
27
f"Required state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
21
28
)
22
- opts = [k in state_dict for k in self ._state_dict_one_of_opt_keys ]
23
- if len (opts ) > 0 and ((not any (opts )) or (all (opts ))):
24
- raise ValueError (f"state_dict should contain only one of '{ self ._state_dict_one_of_opt_keys } ' keys" )
29
+
30
+ # Handle groups of one-of optional keys
31
+ for one_of_opt_keys in self ._state_dict_one_of_opt_keys :
32
+ if len (one_of_opt_keys ) > 0 :
33
+ opts = [k in state_dict for k in one_of_opt_keys ]
34
+ num_present = sum (opts )
35
+ if num_present == 0 :
36
+ raise ValueError (f"state_dict should contain at least one of '{ one_of_opt_keys } ' keys" )
37
+ if num_present > 1 :
38
+ raise ValueError (f"state_dict should contain only one of '{ one_of_opt_keys } ' keys" )
39
+
40
+ # Check user keys
41
+ if hasattr (self , "_state_dict_user_keys" ) and isinstance (self ._state_dict_user_keys , list ):
42
+ for k in self ._state_dict_user_keys :
43
+ if k not in state_dict :
44
+ raise ValueError (
45
+ f"Required user state attribute '{ k } ' is absent in provided state_dict '{ state_dict .keys ()} '"
46
+ )
0 commit comments