@@ -146,51 +146,53 @@ def __init__(
146
146
'`cpu_offload` should be `None`, `bool`'
147
147
f'or `CPUOffload`, but has type { type (cpu_offload )} ' )
148
148
149
- if isinstance (auto_wrap_policy , str ):
150
- auto_wrap_policy = FUNCTIONS .get ( # type: ignore
151
- auto_wrap_policy )
152
- if auto_wrap_policy is None :
153
- raise ValueError ('`auto_wrap_policy` is not registered!' )
154
- elif isinstance (auto_wrap_policy , dict ):
155
- policy = auto_wrap_policy .pop ('type' )
156
- if isinstance (policy , str ):
157
- policy = FUNCTIONS .get (policy ) # type: ignore
158
- if policy is None :
159
- raise ValueError ('`auto_wrap_policy` is not registered!' )
160
- auto_wrap_policy = partial (policy , ** auto_wrap_policy )
161
-
162
- if not (auto_wrap_policy is None
163
- or callable (auto_wrap_policy )): # type: ignore
164
- raise TypeError ('`auto_wrap_policy` should be a str, a '
165
- 'callable, a dict or None, but has type '
166
- f'{ type (auto_wrap_policy )} ' )
167
-
168
- if isinstance (backward_prefetch , str ):
169
- backward_prefetch = BackwardPrefetch [backward_prefetch ]
170
- if not (isinstance (backward_prefetch , BackwardPrefetch )
171
- or backward_prefetch is None ):
172
- raise TypeError (
173
- '`backward_prefetch` should be `None`, string of '
174
- '"BACKWARD_PRE" and "BACKWARD_POST", or '
175
- f'`BackwardPrefetch`, but has type { type (backward_prefetch )} ' )
176
-
177
- if isinstance (param_init_fn , str ):
178
- param_init_fn = FUNCTIONS .get ( # type: ignore
179
- param_init_fn )
180
- if param_init_fn is None :
181
- raise ValueError ('`param_init_fn` is not registered!' )
182
- elif isinstance (param_init_fn , dict ):
183
- init_fn = param_init_fn .pop ('type' )
149
+ with FUNCTIONS .switch_scope_and_registry (None ):
150
+ if isinstance (auto_wrap_policy , str ):
151
+ auto_wrap_policy = FUNCTIONS .get ( # type: ignore
152
+ auto_wrap_policy )
153
+ if auto_wrap_policy is None :
154
+ raise ValueError ('`auto_wrap_policy` is not registered!' )
155
+ elif isinstance (auto_wrap_policy , dict ):
156
+ policy = auto_wrap_policy .pop ('type' )
157
+ if isinstance (policy , str ):
158
+ policy = FUNCTIONS .get (policy ) # type: ignore
159
+ if policy is None :
160
+ raise ValueError ('`auto_wrap_policy` is not registered!' )
161
+ auto_wrap_policy = partial (policy , ** auto_wrap_policy )
162
+
163
+ if not (auto_wrap_policy is None
164
+ or callable (auto_wrap_policy )): # type: ignore
165
+ raise TypeError ('`auto_wrap_policy` should be a str, a '
166
+ 'callable, a dict or None, but has type '
167
+ f'{ type (auto_wrap_policy )} ' )
168
+
169
+ if isinstance (backward_prefetch , str ):
170
+ backward_prefetch = BackwardPrefetch [backward_prefetch ]
171
+ if not (isinstance (backward_prefetch , BackwardPrefetch )
172
+ or backward_prefetch is None ):
173
+ raise TypeError (
174
+ '`backward_prefetch` should be `None`, string of '
175
+ '"BACKWARD_PRE" and "BACKWARD_POST", or '
176
+ f'`BackwardPrefetch`, but has type { type (backward_prefetch )} ' # noqa: E501
177
+ )
178
+
184
179
if isinstance (param_init_fn , str ):
185
- init_fn = FUNCTIONS .get (init_fn ) # type: ignore
186
- if init_fn is None :
187
- raise ValueError ('`param_init_fn` is not registered!' )
188
- param_init_fn = partial (init_fn , ** param_init_fn )
189
-
190
- if not (callable (param_init_fn ) or param_init_fn is None ):
191
- raise TypeError ('`param_init_fn` should be a str, a '
192
- 'callable, a dict or None, but has type '
193
- f'{ type (param_init_fn )} ' )
180
+ param_init_fn = FUNCTIONS .get ( # type: ignore
181
+ param_init_fn )
182
+ if param_init_fn is None :
183
+ raise ValueError ('`param_init_fn` is not registered!' )
184
+ elif isinstance (param_init_fn , dict ):
185
+ init_fn = param_init_fn .pop ('type' )
186
+ if isinstance (param_init_fn , str ):
187
+ init_fn = FUNCTIONS .get (init_fn ) # type: ignore
188
+ if init_fn is None :
189
+ raise ValueError ('`param_init_fn` is not registered!' )
190
+ param_init_fn = partial (init_fn , ** param_init_fn )
191
+
192
+ if not (callable (param_init_fn ) or param_init_fn is None ):
193
+ raise TypeError ('`param_init_fn` should be a str, a '
194
+ 'callable, a dict or None, but has type '
195
+ f'{ type (param_init_fn )} ' )
194
196
195
197
def parse_dtype (dtype ):
196
198
if dtype is None :
0 commit comments