-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkfac.py
853 lines (709 loc) · 35.5 KB
/
kfac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Linear
from torch.nn import Conv1d, Conv2d, Conv3d
from torch.optim.optimizer import Optimizer
import torch.nn.functional as F
from contextlib import contextmanager
from typing import Dict, Iterable, List, Optional, Tuple, Union
class Lock(object):
    """Boolean flag that is truthy only while inside ``with lock():``.

    Calling the instance returns a context manager. Entering it while it is
    already entered trips an assertion.
    """

    def __init__(self) -> None:
        self._entered: bool = False

    def __bool__(self) -> bool:
        return self._entered

    @contextmanager
    def __call__(self) -> None:
        # Nested entry is treated as a programming error.
        assert not self._entered
        self._entered = True
        try:
            yield
        finally:
            # Always release the flag, even if the body raised.
            self._entered = False
class FisherBlock(object):
def __init__(self, in_features: int, out_features: int, dtype: torch.dtype, device: torch.device, **kwargs):
self._in_features = in_features
self._out_features = out_features
self._dtype = dtype
self._device = device
self._activations_cov = MovingAverageVariable((in_features, in_features), dtype=dtype, device=device)
self._sensitivities_cov = MovingAverageVariable((out_features, out_features), dtype=dtype, device=device)
self._forward_lock = False
self._backward_lock = False
def setup(self, forward_lock: Lock, backward_lock: Lock, **kwargs) -> None:
self._forward_lock = forward_lock
self._backward_lock = backward_lock
def update_cov(self, cov_ema_decay: float = 1.0) -> None:
raise NotImplementedError()
def compute_damping(self, damping: torch.Tensor, normalization: float = None) -> Tuple[torch.Tensor, torch.Tensor]:
if normalization is not None:
maybe_normalized_damping = normalize_damping(damping, normalization)
else:
maybe_normalized_damping = damping
return compute_pi_adjusted_damping(
self.activation_covariance,
self.sensitivity_covariance,
maybe_normalized_damping ** 0.5
)
def full_fisher_block(self):
left_factor = self.activation_covariance
right_factor = self.sensitivity_covariance
return self._renorm_coeff * kronecker_product(left_factor, right_factor)
def reset(self) -> None:
self._activations_cov.reset()
self._sensitivities_cov.reset()
def grads_to_mat(self, grads: Iterable[torch.Tensor]) -> torch.Tensor:
raise NotImplementedError()
def mat_to_grads(self, mat_grads: torch.Tensor) -> torch.Tensor:
raise NotImplementedError()
def multiply(self, grads: Iterable[torch.Tensor], damping: torch.Tensor) -> Iterable[torch.Tensor]:
act_cov, sen_cov = self.activation_covariance, self.sensitivity_covariance
a_damp, s_damp = self.compute_damping(damping, self.renorm_coeff)
act_cov += torch.eye(act_cov.shape[0], device=a_damp.device) * a_damp
sen_cov += torch.eye(sen_cov.shape[0], device=a_damp.device) * s_damp
mat_grads = self.grads_to_mat(grads)
nat_grads = sen_cov @ mat_grads @ act_cov / self.renorm_coeff
return self.mat_to_grads(nat_grads)
def multiply_preconditioner(self, grads: Iterable[torch.Tensor], damping: torch.Tensor) -> Iterable[torch.Tensor]:
act_cov, sen_cov = self.activation_covariance, self.sensitivity_covariance
a_damp, s_damp = self.compute_damping(damping, self.renorm_coeff)
act_cov_inverse = inverse_by_cholesky(act_cov, a_damp)
sen_cov_inverse = inverse_by_cholesky(sen_cov, s_damp)
mat_grads = self.grads_to_mat(grads)
nat_grads = sen_cov_inverse @ mat_grads @ act_cov_inverse / self.renorm_coeff
return self.mat_to_grads(nat_grads)
@property
def is_static(self) -> bool:
return (self._in_features == 0) or (self._out_features == 0)
@property
def activation_covariance(self) -> torch.Tensor:
return self._activations_cov.value
@property
def sensitivity_covariance(self) -> torch.Tensor:
return self._sensitivities_cov.value
@property
def vars(self) -> Iterable[torch.Tensor]:
raise NotImplementedError()
@property
def grads(self) -> Iterable[torch.Tensor]:
return [tensor.grad for tensor in self.vars]
def set_gradients(self, new_grads):
for var, grad in zip(self.vars, new_grads):
var.grad.data = grad
@property
def renorm_coeff(self) -> float:
return 1.
class ExtensionFisherBlock(FisherBlock):
    """Fisher block that gathers its statistics through module hooks.

    A forward and a backward hook are registered on ``module`` at
    construction. Each forwards to ``forward_hook`` / ``backward_hook`` only
    while the corresponding lock installed by ``setup`` is truthy.
    """

    def __init__(self, module: torch.nn.Module, **kwargs) -> None:
        super().__init__(**kwargs)
        self.module = module
        self.forward_hook_handle = self.backward_hook_handle = None
        registered_forward = module.register_forward_hook(self._forward_hook_wrapper)
        self.forward_hook_handle = registered_forward
        registered_backward = module.register_backward_hook(self._backward_hook_wrapper)
        self.backward_hook_handle = registered_backward

    def _forward_hook_wrapper(self, *args):
        # Tracking disabled: behave as if no hook were installed.
        if not self._forward_lock:
            return None
        return self.forward_hook(*args)

    def _backward_hook_wrapper(self, *args):
        # Tracking disabled: behave as if no hook were installed.
        if not self._backward_lock:
            return None
        return self.backward_hook(*args)

    def forward_hook(self, module: torch.nn.Module, *args):
        """Subclass hook, called with the forward-hook arguments."""
        raise NotImplementedError()

    def backward_hook(self, module: torch.nn.Module, *args):
        """Subclass hook, called with the backward-hook arguments."""
        raise NotImplementedError()
class Identity(FisherBlock):
    """Static block for modules without a dedicated Fisher approximation.

    Both factors have size zero, it owns no parameters, and gradients pass
    through unchanged.
    """

    def __init__(self, module: torch.nn.Module, **kwargs) -> None:
        self.module = module
        static_config = dict(
            in_features=0,
            out_features=0,
            dtype=torch.get_default_dtype(),
            device='cpu',
        )
        super().__init__(**static_config, **kwargs)

    def update_cov(self, cov_ema_decay: float = 1.0):
        """Nothing to accumulate."""
        return None

    def multiply(self, grads: Iterable[torch.Tensor], damping: torch.Tensor) -> Iterable[torch.Tensor]:
        """Return ``grads`` untouched."""
        return grads

    def multiply_preconditioner(self, grads: Iterable[torch.Tensor], damping: torch.Tensor) -> Iterable[torch.Tensor]:
        """Return ``grads`` untouched."""
        return grads

    @property
    def normalization_factor(self):
        return 1.

    @property
    def vars(self) -> Iterable[torch.Tensor]:
        """This block owns no parameters."""
        return ()
class ConvFisherBlock(ExtensionFisherBlock):
    """Kronecker-factored Fisher block for ``Conv1d`` / ``Conv2d`` / ``Conv3d``.

    The activation factor is built from the unfolded input patches (plus a
    homogeneous coordinate when the module has a bias). The sensitivity factor
    is built from the output gradients at every spatial location.
    """

    def __init__(
        self,
        module: Union[Conv1d, Conv2d, Conv3d],
        **kwargs
    ) -> None:
        # One input feature per kernel element and input channel, plus one for the bias.
        # Cast to a plain int so the stored feature count is not a NumPy scalar.
        in_features = int(np.prod(module.kernel_size)) * module.in_channels + int(module.bias is not None)
        out_features = module.out_channels
        super().__init__(
            module=module,
            in_features=in_features,
            out_features=out_features,
            dtype=module.weight.dtype,
            device=module.weight.device,
            **kwargs
        )
        self.n_dim = len(module.kernel_size)
        self._activations = None
        self._sensitivities = None
        self._center = False

    @torch.no_grad()
    def forward_hook(self, module: Union[Conv1d, Conv2d, Conv3d], inp: torch.Tensor, out: torch.Tensor) -> None:
        """Store the input patches of the first positional input."""
        self._activations = self.extract_patches(inp[0])

    @torch.no_grad()
    def backward_hook(self, module: Union[Conv1d, Conv2d, Conv3d], grad_inp: torch.Tensor, grad_out: torch.Tensor) -> None:
        """Store the output gradients with channels as the last dimension."""
        self._sensitivities = grad_out[0].transpose(1, -1).contiguous()
        # Reshape to (batch_size, n_spatial_locations, n_out_features)
        self._sensitivities = self._sensitivities.view(
            self._sensitivities.shape[0],
            -1,
            self._sensitivities.shape[-1]
        )
        # NOTE: unlike FullyConnectedFisherBlock, the sensitivities are not
        # scaled by the batch size here. The original author observed that
        # such scaling sometimes helps and sometimes hurts.

    def setup(self, center: bool = False, **kwargs) -> None:
        """Store the centering flag and forward the locks to the base class."""
        self._center = center
        super().setup(**kwargs)

    def update_cov(self, cov_ema_decay: float = 1.0) -> None:
        """Fold the last captured activations/sensitivities into the averages.

        Does nothing until both a forward and a backward pass were captured.
        """
        if self._activations is None or self._sensitivities is None:
            return
        act, sen = self._activations, self._sensitivities
        # Treat every spatial location of every sample as one row.
        act = act.reshape(-1, act.shape[-1])
        sen = sen.reshape(-1, sen.shape[-1])
        if self._center:
            act = center(act)
            sen = center(sen)
        if self.has_bias:
            act = append_homog(act)
        activation_cov = compute_cov(act)
        sensitivity_cov = compute_cov(sen)
        self._activations_cov.add_to_average(activation_cov, decay=cov_ema_decay)
        self._sensitivities_cov.add_to_average(sensitivity_cov, decay=cov_ema_decay)

    def grads_to_mat(self, grads: Iterable[torch.Tensor]) -> torch.Tensor:
        """Pack ``(weight_grad[, bias_grad])`` into an ``(out_features, in_features)`` matrix."""
        if self.has_bias:
            weights, bias = grads
            # reshape to (out_features, in_features)
            weights = weights.view(weights.shape[0], -1)
            mat_grads = torch.cat([weights, bias[:, None]], -1)
        else:
            # reshape to (out_features, in_features)
            # Fixed: the shape must be read from the weight gradient itself,
            # not from the enclosing iterable (which has no ``shape``).
            weight_grad = grads[0]
            mat_grads = weight_grad.view(weight_grad.shape[0], -1)
        return mat_grads

    def mat_to_grads(self, mat_grads: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Unpack a matrix from ``grads_to_mat`` into a tuple shaped like the parameters."""
        if self.has_bias:
            return mat_grads[:, :-1].view_as(self.module.weight), mat_grads[:, -1]
        else:
            return mat_grads.view_as(self.module.weight),

    @property
    def renorm_coeff(self) -> float:
        """Number of spatial locations of the last captured activations.

        Requires that ``forward_hook`` has stored activations beforehand.
        """
        return self._activations.shape[1]

    @property
    def has_bias(self) -> bool:
        """Whether the wrapped module has a bias parameter."""
        return self.module.bias is not None

    @property
    def vars(self) -> Iterable[torch.Tensor]:
        """``(weight, bias)`` or ``(weight,)`` of the wrapped module."""
        if self.has_bias:
            return (self.module.weight, self.module.bias)
        else:
            return (self.module.weight,)

    def extract_patches(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold ``x`` into the patches seen by the convolution kernel.

        Args:
            x: Input of shape (batch_size, in_channels, spatial_dim1, ...).

        Returns:
            Tensor of shape (batch_size, n_spatial_locations, n_in_features).
        """
        # Add padding
        if sum(self.module.padding) > 0:
            padding_mode = self.module.padding_mode
            if padding_mode == 'zeros':
                # F.pad calls zero padding 'constant'.
                padding_mode = 'constant'
            # F.pad expects (left, right) pairs starting from the last dimension.
            x = F.pad(x, tuple(pad for pad in self.module.padding[::-1] for _ in range(2)), mode=padding_mode, value=0.)
        # Unfold the convolution
        for i, (size, stride) in enumerate(zip(self.module.kernel_size, self.module.stride)):
            x = x.unfold(i+2, size, stride)
        # Move in_channels to the end
        # https://github.com/pytorch/pytorch/issues/36048
        x = x.unsqueeze(2+self.n_dim).transpose(1, 2+self.n_dim).squeeze(1)
        # Make the memory contiguous
        x = x.contiguous()
        # Return the shape (batch_size, n_spatial_locations, n_in_features)
        x = x.view(
            x.shape[0],
            np.prod([x.shape[1+i] for i in range(self.n_dim)]),
            -1
        )
        return x
class FullyConnectedFisherBlock(ExtensionFisherBlock):
    """Kronecker-factored Fisher block for ``torch.nn.Linear``.

    The activation factor is built from the layer inputs (plus a homogeneous
    coordinate when the layer has a bias). The sensitivity factor is built
    from the output gradients.
    """

    def __init__(self, module: Linear, **kwargs) -> None:
        super().__init__(
            module=module,
            # One extra input feature stands in for the bias.
            in_features=module.in_features + int(module.bias is not None),
            out_features=module.out_features,
            dtype=module.weight.dtype,
            device=module.weight.device,
            **kwargs)
        self._activations = None
        self._sensitivities = None
        self._center = False

    @torch.no_grad()
    def forward_hook(self, module: Linear, inp: torch.Tensor, out: torch.Tensor) -> None:
        """Store a detached copy of the first input, flattened to 2D."""
        n_inputs = self._in_features - int(self.has_bias)
        self._activations = inp[0].detach().clone().reshape(-1, n_inputs)

    @torch.no_grad()
    def backward_hook(self, module: Linear, grad_inp: torch.Tensor, grad_out: torch.Tensor) -> None:
        """Store a detached copy of the output gradient, flattened to 2D.

        The gradient is multiplied by the size of its first dimension.
        """
        # presumably undoes a mean-over-batch loss reduction -- TODO confirm
        scale = grad_out[0].shape[0]
        self._sensitivities = grad_out[0].detach().clone().reshape(-1, self._out_features) * scale

    def setup(self, center: bool = False, **kwargs):
        """Forward the locks to the base class and store the centering flag."""
        super().setup(**kwargs)
        self._center = center

    def update_cov(self, cov_ema_decay: float = 1.0) -> None:
        """Fold the last captured activations/sensitivities into the averages.

        Does nothing until both a forward and a backward pass were captured.
        """
        if self._activations is None or self._sensitivities is None:
            return
        act, sen = self._activations, self._sensitivities
        if self._center:
            act = center(act)
            sen = center(sen)
        if self.has_bias:
            act = append_homog(act)
        activation_cov = compute_cov(act)
        sensitivity_cov = compute_cov(sen)
        self._activations_cov.add_to_average(activation_cov, cov_ema_decay)
        self._sensitivities_cov.add_to_average(sensitivity_cov, cov_ema_decay)

    def grads_to_mat(self, grads: Iterable[torch.Tensor]) -> torch.Tensor:
        """Pack ``(weight_grad[, bias_grad])`` into one matrix, bias as the last column."""
        if self.has_bias:
            weights, bias = grads
            mat_grads = torch.cat([weights, bias[:, None]], -1)
        else:
            mat_grads = grads[0]
        return mat_grads

    def mat_to_grads(self, mat_grads: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Unpack a matrix from ``grads_to_mat`` into ``(weight_grad[, bias_grad])``."""
        if self.has_bias:
            return mat_grads[:, :-1], mat_grads[:, -1]
        else:
            return mat_grads,

    @property
    def has_bias(self) -> bool:
        """Whether the wrapped module has a bias parameter."""
        return self.module.bias is not None

    @property
    def vars(self) -> Iterable[torch.Tensor]:
        """``[weight, bias]`` or ``[weight]`` of the wrapped module."""
        if self.has_bias:
            return [self.module.weight, self.module.bias]
        else:
            return [self.module.weight]
class KFAC(object):
    """K-FAC optimizer for a ``torch.nn.Module``.

    One Fisher block is created per sub-module of ``model``. ``step``
    preconditions the gradients with those blocks and then applies a plain
    gradient step with the configured learning rate.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 learning_rate: float,
                 damping: torch.Tensor,
                 adapt_damping: bool = False,
                 damping_adaptation_decay: float = 0.99,
                 damping_adaptation_interval: int = 5,
                 include_damping_in_qmodel_change: bool = False,
                 min_damping=1e-8,
                 cov_ema_decay: float = 0.95,
                 momentum: float = 0.9,
                 momentum_type: str = 'regular',
                 norm_constraint: Optional[float] = None,
                 weight_decay: Optional[float] = None,
                 l2_reg: float = 0.,
                 update_cov_manually: bool = False,
                 center: bool = False) -> None:
        """Creates the KFAC Optimizer object.

        Args:
            model (torch.nn.Module): A `torch.nn.Module` to optimize.
            learning_rate (float): The initial learning rate
            damping (torch.Tensor): This quantity times the identity matrix is (approximately) added
                to the matrix being estimated. - This relates to the "trust" in the second order approximation.
            adapt_damping (bool, optional): If True we adapt the damping according to the Levenberg-Marquardt
                rule described in Section 6.5 of the original K-FAC paper. The details of this scheme are controlled by
                various additional arguments below. Defaults to False.
            damping_adaptation_decay (float, optional): The `damping` parameter is multiplied by the `damping_adaptation_decay`
                every `damping_adaptation_interval` number of iterations. Defaults to 0.99.
            damping_adaptation_interval (int, optional): Number of steps in between updating the `damping` parameter. Defaults to 5.
            include_damping_in_qmodel_change (bool, optional): If True the damping contribution is included in the quadratic model
                for the purposes of computing qmodel_change in rho. Defaults to False.
            min_damping (float, optional): Minimum value the damping parameter can take. The default is quite arbitrary. Defaults to 1e-8.
            cov_ema_decay (float): The decay factor used when calculating the covariance estimate moving averages. Defaults to 0.95.
            momentum (float, optional): The momentum decay constant to use. Defaults to 0.9.
            momentum_type (str, optional): The type of momentum to use. Options: [`regular`, `adam`]. Defaults to 'regular'.
            norm_constraint (Optional[float], optional): If specified, the update is scaled down so that its approximate squared
                Fisher norm v^T F v is at most the specified value. Implemented for both `regular` and `adam` momentum.
                Defaults to None.
            weight_decay (Optional[float], optional): The coefficient to use for weight decay. If set to `None` there is no
                weight decay. Defaults to None.
            l2_reg (float, optional): L2 regularization coefficient. Defaults to 0.
            update_cov_manually (bool, optional): If set to `True`, the covariance matrices are not updated automatically at every
                `.step()` call. You will have to call it manually using `.update_cov()`. This is useful in distributed settings
                or when you want your covariances w.r.t. the model distribution rather than the loss function. Defaults to False.
            center (bool, optional): If set to True the activations and sensitivities are centered. This is useful when dealing with
                unnormalized distributions. Defaults to False.
        """
        legal_momentum_types = ['regular', 'adam']
        momentum_type = momentum_type.lower()
        assert momentum_type in legal_momentum_types, f'{momentum_type} type momentum is not supported.'
        # Fixed: the list literal lacked a comma ('regular' 'adam' == 'regularadam'),
        # so the original check could never fire. Both momentum branches in
        # `_get_raw_updates` implement the clipping, so both types are accepted.
        assert momentum_type in ['regular', 'adam'] or norm_constraint is None, \
            'Norm constraint may only be used with regular or adam momentum.'
        self.model = model
        self.blocks: List[FisherBlock] = []
        self.learning_rate = learning_rate
        # Number of completed `step` calls.
        self.counter = 0
        device = next(model.parameters()).device
        dtype = next(model.parameters()).dtype
        self._damping = torch.tensor(damping, device=device, dtype=dtype)
        self._adapt_damping = adapt_damping
        self._damping_adaptation_decay = damping_adaptation_decay
        self._damping_adaptation_interval = damping_adaptation_interval
        # Multiplicative change applied to the damping at each adaptation.
        self._omega = damping_adaptation_decay ** damping_adaptation_interval
        self._include_damping_in_qmodel_change = include_damping_in_qmodel_change
        # NaN until the first adaptation step / first recorded loss.
        self._qmodel_change = torch.tensor(np.nan, device=device, dtype=dtype)
        self._prev_loss = torch.tensor(np.nan, device=device, dtype=dtype)
        self._rho = torch.tensor(np.nan, device=device, dtype=dtype)
        self._min_damping = min_damping
        self._weight_decay = weight_decay
        self._l2_reg = l2_reg
        self._norm_constraint = norm_constraint
        self._cov_ema_decay = cov_ema_decay
        self._momentum = momentum
        self._momentum_type = momentum_type
        # Hooks only record statistics while these locks are entered.
        self.track_forward = Lock()
        self.track_backward = Lock()
        for module in model.modules():
            self.blocks.append(
                init_fisher_block(
                    module,
                    center=center,
                    forward_lock=self.track_forward,
                    backward_lock=self.track_backward
                )
            )
        self._velocities: Dict[FisherBlock, Iterable[torch.Tensor]] = {}
        self.update_cov_manually = update_cov_manually

    def reset_cov(self) -> None:
        """Clear the covariance accumulators of every block."""
        for block in self.blocks:
            # Fixed: FisherBlock exposes `reset`, not `reset_cov`.
            block.reset()

    def _add_weight_decay(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]:
        """Applies weight decay.

        Adds ``weight_decay * parameter`` to every gradient.
        """
        return tuple(
            (tuple(grad + self._weight_decay*var for grad,
                   var in zip(grads, layer.vars)), layer)
            for grads, layer in grads_and_layers
        )

    def _squared_fisher_norm(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]], precon_grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> float:
        """Computes the squared (approximate) Fisher norm of the updates.

        This is defined as v^T F v, where F is the approximate Fisher matrix
        as computed by the estimator, and v = F^{-1} g, where g is the gradient.
        This is computed efficiently as v^T g.
        """
        return inner_product_pairs(grads_and_layers, precon_grads_and_layers)

    def _update_clip_coeff(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]], precon_grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> float:
        """Computes the scale factor for the update to satisfy the norm constraint.

        Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint,
        F is the approximate Fisher matrix, and r is the update vector, i.e.
        -alpha * v, where alpha is the learning rate, and v is the preconditioned
        gradient.

        This is based on Section 5 of Ba et al., Distributed Second-Order
        Optimization using Kronecker-Factored Approximations. Note that they
        absorb the learning rate alpha (which they denote eta_max) into the formula
        for the coefficient, while in our implementation, the rescaling is done
        before multiplying by alpha. Hence, our formula differs from theirs by a
        factor of alpha.
        """
        sq_norm_grad = self._squared_fisher_norm(
            grads_and_layers, precon_grads_and_layers)
        sq_norm_up = sq_norm_grad * self.learning_rate**2
        return torch.min(
            torch.ones((), dtype=sq_norm_up.dtype, device=sq_norm_up.device),
            torch.sqrt(self._norm_constraint / sq_norm_up)
        )

    def _clip_updates(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]], precon_grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]:
        """Rescales the preconditioned gradients to satisfy the norm constraint.

        Rescales the preconditioned gradients such that the resulting update r
        (after multiplying by the learning rate) will satisfy the norm constraint.
        This constraint is that r^T F r <= C, where F is the approximate Fisher
        matrix, and C is the norm_constraint attribute. See Section 5 of
        Ba et al., Distributed Second-Order Optimization using Kronecker-Factored
        Approximations.
        """
        coeff = self._update_clip_coeff(
            grads_and_layers, precon_grads_and_layers)
        return scalar_product_pairs(coeff, precon_grads_and_layers)

    def _multiply_preconditioner(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]:
        """Multiply each group of tensors by its block's damped inverse Fisher."""
        return tuple((layer.multiply_preconditioner(grads, self.damping), layer) for (grads, layer) in grads_and_layers)

    def _update_velocities(self, grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]], decay: float, vec_coeff=1.0) -> Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]:
        """Update the per-block velocities as ``decay * velocity + vec_coeff * grad``.

        Velocities are created as zeros on first use and updated in place.
        The stored velocity tensors themselves are returned.
        """
        def _update_velocity(grads, layer):
            if layer not in self._velocities:
                self._velocities[layer] = tuple(
                    torch.zeros_like(grad) for grad in grads)
            velocities = self._velocities[layer]
            for velocity, grad in zip(velocities, grads):
                new_velocity = decay * velocity + vec_coeff * grad
                velocity.data = new_velocity
            return velocities
        return tuple((_update_velocity(grads, layer), layer) for grads, layer in grads_and_layers)

    def _compute_approx_qmodel_change(self, updates_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]], grads_and_layers: Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]) -> torch.Tensor:
        """Approximate change of the quadratic model for the given updates.

        Sum of a quadratic term (half the inner product of the updates with
        the damped Fisher applied to the gradients) and a linear term (inner
        product of updates and gradients). Unless
        `include_damping_in_qmodel_change` is set, a damping-proportional
        share of the linear term is subtracted from the quadratic term.
        """
        quad_term = 0.5 * inner_product_pairs(updates_and_layers, tuple(
            (layer.multiply(grads, self.damping), layer) for grads, layer in grads_and_layers))
        linear_term = inner_product_pairs(updates_and_layers, grads_and_layers)
        if not self._include_damping_in_qmodel_change:
            quad_term -= 0.5 * self._sub_damping_out_qmodel_change_coeff * \
                self.damping * linear_term
        return quad_term + linear_term

    def _get_raw_updates(self) -> Iterable[Tuple[Iterable[torch.Tensor], FisherBlock]]:
        """Compute the preconditioned (and optionally clipped / momentum) updates.

        Returns ``(tensors, block)`` pairs for every block that has at least
        one gradient. On damping-adaptation steps `self._qmodel_change` is
        refreshed as a side effect.
        """
        # Get grads
        grads_and_layers = tuple((layer.grads, layer) for layer in self.blocks if any(
            grad is not None for grad in layer.grads))
        if self._momentum_type == 'regular':
            # Multiply preconditioner
            raw_updates_and_layers = self._multiply_preconditioner(
                grads_and_layers)
            # Apply "KL clipping"
            if self.use_norm_constraint:
                raw_updates_and_layers = self._clip_updates(
                    grads_and_layers, raw_updates_and_layers)
            # Update velocities
            if self.use_momentum:
                raw_updates_and_layers = self._update_velocities(
                    raw_updates_and_layers, self._momentum)
            # Do adaptive damping
            if self._adapt_damping and self.is_damping_adaption_time:
                updates_and_layers = scalar_product_pairs(
                    -self.learning_rate,
                    raw_updates_and_layers
                )
                self._qmodel_change = self._compute_approx_qmodel_change(
                    updates_and_layers, grads_and_layers)
        elif self._momentum_type == 'adam':
            # For adam like momentum we first compute the velocities and use the velocities also for KL clipping instead
            # of computing the velocities at the very end.
            # Update velocities
            if self.use_momentum:
                velocities_and_layers = self._update_velocities(
                    grads_and_layers, self._momentum)
            else:
                velocities_and_layers = grads_and_layers
            # Multiply preconditioner
            # Fixed: precondition the velocities. Preconditioning the raw
            # gradients discarded the momentum entirely, and the clipping and
            # qmodel formulas below assume the update is F^{-1} velocity.
            raw_updates_and_layers = self._multiply_preconditioner(
                velocities_and_layers)
            # Apply "KL clipping"
            if self.use_norm_constraint:
                raw_updates_and_layers = self._clip_updates(
                    velocities_and_layers, raw_updates_and_layers)
            # Do adaptive damping
            if self._adapt_damping and self.is_damping_adaption_time:
                # See https://github.com/tensorflow/kfac/blob/cf6265590944b5b937ff0ceaf4695a72c95a02b9/kfac/python/ops/optimizer.py#L1009
                self._qmodel_change = 0.5 * self.learning_rate**2 * inner_product_pairs(raw_updates_and_layers, velocities_and_layers)\
                    - self.learning_rate * \
                    inner_product_pairs(
                        raw_updates_and_layers, grads_and_layers)
        else:
            raise NotImplementedError(
                f'Momentum {self._momentum_type} is not supported yet.')
        return raw_updates_and_layers

    def _update_damping(self, loss):
        """Adapt the damping from the observed vs. predicted loss change.

        Only acts when damping adaptation is enabled and the current step is
        an adaptation step. The damping is multiplied by omega when the
        reduction ratio is good, divided by omega when it is poor, and finally
        clamped from below at ``min_damping + l2_reg``.
        """
        # Adapts the damping parameter. KFAC Section 6.5
        if not self._adapt_damping or not self.is_damping_adaption_time:
            return
        # NOTE(review): `_qmodel_change` is only refreshed in `_get_raw_updates`
        # on adaptation steps, and that runs after this method within `step`.
        # For intervals > 1 the value used here therefore appears to stem from
        # the previous adaptation step, whereas `loss_change` spans only the
        # most recent step -- confirm whether this is intended.
        loss_change = loss - self._prev_loss
        rho = loss_change / self._qmodel_change
        should_decrease = (
            loss_change < 0 and self._qmodel_change > 0) or rho > 0.75
        should_increase = rho < 0.25
        if should_decrease:
            new_damping = self.damping * self._omega
        elif should_increase:
            new_damping = self.damping / self._omega
        else:
            new_damping = self.damping
        new_damping = torch.clamp(
            new_damping, min=self._min_damping + self._l2_reg)
        self._damping = new_damping
        self._rho = rho

    @torch.no_grad()
    def step(self, loss: Optional[torch.Tensor] = None) -> None:
        """Perform one optimization step using the current ``.grad`` values.

        Args:
            loss: Loss of the current iteration. Required when adaptive
                damping is enabled.

        Raises:
            ValueError: If adaptive damping is enabled and ``loss`` is None.
        """
        if self._adapt_damping and loss is None:
            raise ValueError(
                'The loss must be passed if adaptive damping is used.')
        # We update the damping before the optimization step
        # This allows us to avoid multiple passes through the model
        # and can leave the loss computation out of the optimizer.
        # We shouldn't do this at the very first iteration! (Some variables will be nan)
        self._update_damping(loss)
        # Update covariance matrices
        # We allow for manual updates in case we need more control over the optimization
        # routine, e.g., when distributing KFAC
        if not self.update_cov_manually:
            self.update_cov()
        raw_updates_and_layers = self._get_raw_updates()
        # Apply weight decay
        if self.use_weight_decay:
            raw_updates_and_layers = self._add_weight_decay(
                raw_updates_and_layers)
        # Apply the new gradients
        for precon_grad, layer in raw_updates_and_layers:
            layer.set_gradients(precon_grad)
        # Do gradient step - if any parameter gradient was not updated by its natural gradient
        # this will fall back to the normal gradient.
        for param in self.model.parameters():
            if param.grad is not None:
                param.add_(param.grad, alpha=-self.learning_rate)
        # Cache previous loss
        if loss is not None:
            self._prev_loss = loss.clone()
        self.counter += 1

    def update_cov(self) -> None:
        """Let every block fold its latest statistics into its covariances."""
        for layer in self.blocks:
            layer.update_cov(cov_ema_decay=self._cov_ema_decay)

    @property
    def covariances(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Raw (unnormalized) covariance accumulators of all non-static blocks."""
        return [
            (
                block._activations_cov._var,
                block._sensitivities_cov._var
            )
            for block in self.blocks
            if not block.is_static
        ]

    @covariances.setter
    def covariances(self, new_covariances: List[Tuple[torch.Tensor, torch.Tensor]]) -> None:
        # Pairs are matched to the non-static blocks in order; each tensor is
        # first converted to the dtype/device of the existing accumulator.
        for block, (a_cov, s_cov) in zip(filter(lambda a: not a.is_static, self.blocks), new_covariances):
            block._activations_cov.value = a_cov.to(
                block._activations_cov._var, non_blocking=True)
            block._sensitivities_cov.value = s_cov.to(
                block._sensitivities_cov._var, non_blocking=True)

    @property
    def damping(self) -> torch.Tensor:
        """A copy of the current damping value."""
        return self._damping.clone()

    @property
    def use_weight_decay(self) -> bool:
        """True when a non-zero weight decay coefficient is configured."""
        return self._weight_decay is not None and self._weight_decay != 0.

    @property
    def use_norm_constraint(self) -> bool:
        """True when a norm constraint is configured."""
        return self._norm_constraint is not None

    @property
    def use_momentum(self) -> bool:
        """True when a non-zero momentum is configured."""
        return self._momentum_type in ['regular', 'adam'] and self._momentum != 0

    @property
    def is_damping_adaption_time(self) -> bool:
        """True on every `damping_adaptation_interval`-th step."""
        # We do *not* want to update at the first iteration as the previous loss is unknown!
        return ((self.counter+1) % self._damping_adaptation_interval) == 0

    @property
    def _sub_damping_out_qmodel_change_coeff(self) -> float:
        return 1.0 - self._l2_reg / self.damping
def init_fisher_block(module: nn.Module, **kwargs) -> FisherBlock:
    """Create (or reuse) the Fisher block for ``module`` and call its ``setup``.

    Args:
        module: The module to wrap.
        **kwargs: Forwarded to the block's ``setup``.

    Returns:
        A ``FullyConnectedFisherBlock`` for an exact ``nn.Linear``, a
        ``ConvFisherBlock`` for an exact ``nn.Conv1d/2d/3d``, the module
        itself if it already is a ``FisherBlock``, otherwise an ``Identity``.
    """
    # Exact type checks are kept for Linear/Conv so that subclasses with
    # different forward semantics are not wrapped by accident.
    if type(module) is nn.Linear:
        layer = FullyConnectedFisherBlock(module)
    elif type(module) in [nn.Conv1d, nn.Conv2d, nn.Conv3d]:
        layer = ConvFisherBlock(module)
    elif isinstance(module, FisherBlock):
        # Fixed: `type(module) is FisherBlock` never matched subclass
        # instances, which made this branch unreachable in practice.
        layer = module
    else:
        layer = Identity(module)
    layer.setup(**kwargs)
    return layer
def inner_product(list1: Iterable[torch.Tensor], list2: Iterable[torch.Tensor]) -> torch.Tensor:
    """Sum of the element-wise products of the pairwise-zipped tensors.

    Returns the integer ``0`` when there is nothing to pair up.
    """
    total = 0
    for left, right in zip(list1, list2):
        total = total + (left * right).sum()
    return total
def inner_product_pairs(list1: Iterable[Tuple[Iterable[torch.Tensor], object]], list2: Iterable[Tuple[Iterable[torch.Tensor], object]]):
    """Inner product of two ``(tensors, tag)`` pair lists, ignoring the tags.

    The tensor groups of each list are flattened in order and handed to
    ``inner_product``.
    """
    flat_left = []
    for tensors, _ in list1:
        flat_left.extend(tensors)
    flat_right = []
    for tensors, _ in list2:
        flat_right.extend(tensors)
    return inner_product(tuple(flat_left), tuple(flat_right))
def scalar_product_pairs(scalar, list_: Iterable[Tuple[Iterable[torch.Tensor], object]]) -> Iterable[Tuple[Iterable[torch.Tensor], object]]:
    """Scale every item of every ``(items, tag)`` pair by ``scalar``.

    Returns a tuple of ``(scaled_items_tuple, tag)`` pairs in the input order.
    """
    scaled_pairs = []
    for items, tag in list_:
        scaled_items = tuple(scalar * item for item in items)
        scaled_pairs.append((scaled_items, tag))
    return tuple(scaled_pairs)
def compute_pi_tracenorm(left_cov: torch.Tensor, right_cov: torch.Tensor) -> torch.Tensor:
    """Trace-norm based pi factor used to split damping between two factors.

    Returns ``sqrt((tr(left) * rows(right)) / (tr(right) * rows(left)))``.
    """
    n_left, n_right = left_cov.shape[0], right_cov.shape[0]
    numerator = torch.trace(left_cov) * n_right
    denominator = torch.trace(right_cov) * n_left
    assert torch.all(denominator > 0), "Pi computation, trace of right cov matrix should be positive!"
    return torch.sqrt(numerator / denominator)
def compute_pi_adjusted_damping(left_cov: torch.Tensor, right_cov: torch.Tensor, damping: torch.Tensor):
    """Split ``damping`` between the two factors using the pi heuristic.

    Returns ``(damping * pi, damping / pi)`` for the left and right factor.
    """
    pi_factor = compute_pi_tracenorm(left_cov, right_cov)
    left_damping = damping * pi_factor
    right_damping = damping / pi_factor
    return left_damping, right_damping
def inverse_by_cholesky(tensor: torch.Tensor, damping: torch.Tensor) -> torch.Tensor:
    """Invert ``tensor + damping * I`` via a Cholesky factorization.

    Args:
        tensor: Square symmetric matrix.
        damping: Scalar added to the diagonal before inversion.

    Returns:
        The inverse of the damped matrix.
    """
    identity = torch.eye(tensor.shape[-1], device=tensor.device, dtype=tensor.dtype)
    damped = tensor + identity * damping
    # `torch.cholesky` is deprecated; prefer `torch.linalg.cholesky` when the
    # installed PyTorch provides it and fall back to the legacy name otherwise.
    # Both return the lower-triangular factor that `cholesky_inverse` expects.
    linalg = getattr(torch, 'linalg', None)
    cholesky_fn = getattr(linalg, 'cholesky', None) or torch.cholesky
    return torch.cholesky_inverse(cholesky_fn(damped))
def kronecker_product(mat1: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor:
    """Kronecker product of two 2D tensors.

    Args:
        mat1: Tensor of shape (m1, n1).
        mat2: Tensor of shape (m2, n2).

    Returns:
        Tensor of shape (m1*m2, n1*n2) whose block (i, j) is
        ``mat1[i, j] * mat2``.
    """
    m1, n1 = mat1.shape
    mat1_rsh = mat1.reshape(m1, 1, n1, 1)
    m2, n2 = mat2.shape
    # Fixed: the second factor must be reshaped with its own dimensions; using
    # (m1, n1) failed whenever the two factors differed in shape.
    mat2_rsh = mat2.reshape(1, m2, 1, n2)
    return (mat1_rsh * mat2_rsh).reshape(m1*m2, n1*n2)
def normalize_damping(damping: torch.Tensor, num_replications: float, normalize_damping_power: float = 1.) -> torch.Tensor:
    """Divide ``damping`` by ``num_replications ** normalize_damping_power``.

    A falsy power (``0`` or ``None``) disables the normalization and returns
    ``damping`` unchanged.
    """
    if not normalize_damping_power:
        return damping
    divisor = num_replications ** normalize_damping_power
    return damping / divisor
class MovingAverageVariable(object):
    """Exponentially decayed running sum together with its total weight.

    With ``normalize_value`` enabled, ``value`` is the accumulated sum divided
    by the accumulated weight. Otherwise it is a copy of the raw sum.
    """

    def __init__(
            self,
            shape: torch.Size,
            dtype: torch.dtype = None,
            device: torch.device = None,
            normalize_value: bool = True) -> None:
        dtype = torch.get_default_dtype() if dtype is None else dtype
        tensor_options = dict(dtype=dtype, device=device, requires_grad=False)
        self._normalize_value = normalize_value
        self._var = torch.zeros(shape, **tensor_options)
        self._total_weight = torch.zeros((), **tensor_options)

    @property
    def dtype(self) -> torch.dtype:
        """dtype of the accumulator."""
        return self._var.dtype

    @property
    def value(self) -> torch.Tensor:
        """Current average (normalized) or a copy of the raw sum."""
        if not self._normalize_value:
            return self._var.clone()
        return self._var / self._total_weight

    @value.setter
    def value(self, new_value) -> None:
        if not self._normalize_value:
            self._var.data = new_value
            return
        # NOTE(review): this multiplies by the boolean flag (i.e. by 1), so the
        # raw accumulator is overwritten with `new_value` as-is and the total
        # weight is left untouched -- confirm this is intended.
        self._var.data = new_value * self._normalize_value

    def add_to_average(self, value: torch.Tensor, decay: float = 1.0, weight: float = 1.0) -> None:
        """Decay the running sum and weight in place, then add the new sample."""
        self._var.mul_(decay).add_(value)
        self._total_weight.mul_(decay).add_(weight)

    def reset(self) -> None:
        """Replace sum and weight with fresh zero tensors."""
        self._var = torch.zeros_like(self._var)
        self._total_weight = torch.zeros_like(self._total_weight)
def center(x: torch.Tensor) -> torch.Tensor:
    """Subtract the mean over the first dimension from ``x``."""
    column_mean = x.mean(dim=0, keepdim=True)
    return x - column_mean
def compute_cov(tensor: torch.Tensor, tensor_right: torch.Tensor = None, normalizer=None) -> torch.Tensor:
    """Empirical second moment of the rows of a 2D tensor.

    Intended for data whose true row mean is zero, in which case the second
    moment equals the covariance.

    Args:
        tensor: A 2D Tensor.
        tensor_right: Optional 2D Tensor. If given, ``tensor^T @ tensor_right``
            is computed instead of ``tensor^T @ tensor``.
        normalizer: Optional scalar divisor (defaults to the number of rows of
            ``tensor``).

    Returns:
        ``tensor^T @ tensor / normalizer``, symmetrized, or the cross moment
        ``tensor^T @ tensor_right / normalizer`` when ``tensor_right`` is given.
    """
    assert len(tensor.shape) == 2
    if normalizer is None:
        normalizer = tensor.shape[0]
    if tensor_right is not None:
        cross_moment = (tensor.T @ tensor_right) / normalizer
        return cross_moment
    second_moment = tensor.T @ tensor / normalizer
    # Average with the transpose to remove round-off asymmetry.
    return (second_moment + second_moment.T) / 2.0
def append_homog(tensor: torch.Tensor, homog_value: float = 1.) -> torch.Tensor:
    """Append a homogeneous coordinate along the last dimension.

    Args:
        tensor: A Tensor.
        homog_value: Value of the appended coordinate. (Default: 1.0)

    Returns:
        A Tensor equal to the input but one larger in the last dimension; the
        new entries hold ``homog_value``.
    """
    column_shape = [*tensor.shape[:-1], 1]
    ones_column = torch.ones(column_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, ones_column * homog_value], -1)