Commit a49b020

Merge branch 'ClashLuke-patch-1'

2 parents 8b3c07a + 490d222

1 file changed: +12 −22 lines changed

timm/optim/kron.py  +12 −22
@@ -157,11 +157,7 @@ def __init__(
         self._param_exprs = {}  # cache for einsum expr
         self._tiny = torch.finfo(torch.bfloat16).tiny
         self.rng = random.Random(1337)
-        if deterministic:
-            # Use a Generator to try to be more deterministic across resume (save/load)
-            self.torch_rng = torch.Generator().manual_seed(1337)
-        else:
-            self.torch_rng = None
+        self.deterministic = deterministic
 
         # make compile optional (for bwd compat)
         if has_dynamo:
@@ -178,7 +174,6 @@ def __init__(
     def __getstate__(self):
         _dict = super().__getstate__()
         _dict["rng"] = self.rng
-        _dict["torch_rng"] = self.torch_rng
         return _dict
 
     def state_dict(self) -> Dict[str, Any]:
@@ -187,28 +182,21 @@ def state_dict(self) -> Dict[str, Any]:
 
         # Add the generator state
        optimizer_state['rng_state'] = self.rng.getstate()
-        if self.torch_rng is not None:
-            optimizer_state['torch_rng_state'] = self.torch_rng.get_state()
-
         return optimizer_state
 
     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         # Extract and remove the RNG state from the state dict
         rng_states = {}
         if 'rng_state' in state_dict:
             rng_states['rng_state'] = state_dict.pop('rng_state')
-        if 'torch_rng_state' in state_dict:
-            rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')
-
+
         # Load the optimizer state
         super().load_state_dict(state_dict)
         state_dict.update(rng_states)  # add back
 
         # Restore the RNG state if it exists
         if 'rng_state' in rng_states:
             self.rng.setstate(rng_states['rng_state'])
-        if 'torch_rng_state' in rng_states:
-            self.torch_rng.set_state(rng_states['torch_rng_state'])
 
     def __setstate__(self, state):
         super().__setstate__(state)
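
With torch_rng gone, a checkpoint only needs to carry the Python random.Random state; any torch Generator can be rebuilt from it on demand. A minimal standalone sketch of that resume-determinism idea (sample_noise and rng are illustrative names, not part of the optimizer's API):

import random
import torch

rng = random.Random(1337)

def sample_noise(shape, device='cpu'):
    # Fresh generator per call, seeded from the Python RNG, so the
    # sequence of draws is a pure function of rng's state.
    gen = torch.Generator(device=device).manual_seed(rng.randint(0, 2 ** 31))
    return torch.randn(shape, generator=gen, device=device)

a = sample_noise((2, 2))
saved = rng.getstate()        # the only RNG state a checkpoint must carry
b = sample_noise((2, 2))

rng.setstate(saved)           # "resume" from the checkpoint
b2 = sample_noise((2, 2))
assert torch.equal(b, b2)     # identical draws after restore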
@@ -317,15 +305,17 @@ def step(self, closure=None):
                 if do_update:
                     exprA, exprGs, _ = exprs
                     Q = state["Q"]
-                    if self.torch_rng is None:
-                        V = torch.randn_like(debiased_momentum, dtype=precond_dtype)
+                    if self.deterministic:
+                        torch_rng = torch.Generator(device=debiased_momentum.device)
+                        torch_rng.manual_seed(self.rng.randint(0, 2 ** 31))
                     else:
-                        # Restoring generator state to device is messy. For now,
-                        # we keep RNG on CPU, but this slows the optimizer down quite a bit.
-                        # FIXME Need a better approach
-                        V = torch.randn(
-                            debiased_momentum.shape, generator=self.torch_rng, dtype=precond_dtype, device='cpu')
-                        V = V.to(debiased_momentum.device)
+                        torch_rng = None
+                    V = torch.randn(
+                        debiased_momentum.shape,
+                        generator=torch_rng,
+                        dtype=precond_dtype,
+                        device=debiased_momentum.device,
+                    )
                     G = debiased_momentum if momentum_into_precond_update else grad
 
                     A, conjB = self._calc_A_and_conjB(exprA, G, Q, V)
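
This hunk removes the old CPU-generator workaround entirely: a fresh torch.Generator is created directly on the parameter's device each update and seeded from self.rng, so no generator state has to be serialized or moved across devices, and the host-to-device copy of V disappears. When deterministic is False, passing generator=None to torch.randn simply falls back to PyTorch's global RNG, so both paths share a single sampling call. A rough standalone sketch of the pattern, with momentum standing in for debiased_momentum:

import random
import torch

rng = random.Random(1337)          # checkpointable Python-side seed source
momentum = torch.zeros(4, 4)       # stand-in for debiased_momentum
deterministic = True

if deterministic:
    # Device-local generator, re-seeded every step from the Python RNG.
    torch_rng = torch.Generator(device=momentum.device)
    torch_rng.manual_seed(rng.randint(0, 2 ** 31))
else:
    torch_rng = None               # torch.randn uses the global RNG

V = torch.randn(momentum.shape, generator=torch_rng, device=momentum.device)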
