@@ -157,11 +157,7 @@ def __init__(
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ def __init__(
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,28 +182,21 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Add the generator state
         optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
 
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
@@ -317,15 +305,17 @@ def step(self, closure=None):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device)
+                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                        torch_rng = None
+                    V = torch.randn(
+                        debiased_momentum.shape,
+                        generator=torch_rng,
+                        dtype=precond_dtype,
+                        device=debiased_momentum.device,
+                    )
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
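Note: a minimal standalone sketch of the technique this diff adopts, assuming the intent stated in the removed comments (device-resident noise that stays reproducible across save/load). Instead of persisting `torch.Generator` state, each call seeds a fresh per-call generator on the target device from the picklable Python `random.Random`, so checkpointing only `rng.getstate()` is enough. The helper name `make_noise` is illustrative, not part of the PR:

    import random

    import torch


    def make_noise(rng: random.Random, shape, dtype, device, deterministic=True):
        """Draw Gaussian noise on `device`.

        If deterministic, seed a fresh torch.Generator from the Python RNG,
        mirroring the diff's `self.rng.randint(0, 2 ** 31)` pattern; otherwise
        fall back to torch's global RNG (generator=None).
        """
        if deterministic:
            gen = torch.Generator(device=device)
            gen.manual_seed(rng.randint(0, 2 ** 31))
        else:
            gen = None
        return torch.randn(shape, generator=gen, dtype=dtype, device=device)


    # Simulated save/load: only the Python RNG state is checkpointed,
    # yet the noise sequence repeats exactly after restore.
    rng = random.Random(1337)
    saved = rng.getstate()               # what state_dict() would store
    a = make_noise(rng, (2, 3), torch.float32, "cpu")

    rng.setstate(saved)                  # what load_state_dict() would restore
    b = make_noise(rng, (2, 3), torch.float32, "cpu")
    assert torch.equal(a, b)

This also explains why `torch_rng` no longer needs to appear in `__getstate__`, `state_dict`, or `load_state_dict`: the generator is transient and rebuilt every update step, and the existing `rng_state` entry alone carries all the randomness across resume.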