-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathkfac_utils.py
More file actions
197 lines (168 loc) · 7.05 KB
/
kfac_utils.py
File metadata and controls
197 lines (168 loc) · 7.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import torch
import torch.nn as nn
import torch.nn.functional as F
def try_contiguous(x):
    """Return ``x`` with a contiguous memory layout.

    :param x: a tensor (anything exposing ``is_contiguous()`` / ``contiguous()``)
    :return: ``x`` itself when it is already contiguous, otherwise the result
             of ``x.contiguous()``
    """
    return x if x.is_contiguous() else x.contiguous()
def _extract_patches(x, kernel_size, stride, padding):
    """
    Cut a batch of feature maps into flattened sliding-window patches.

    :param x: The input feature maps. (batch_size, in_c, h, w)
    :param kernel_size: the kernel size of the conv filter (tuple of two elements)
    :param stride: the stride of conv operation (tuple of two elements)
    :param padding: amount of padding per spatial dim (tuple of two elements: h, w)
    :return: (batch_size, out_h, out_w, in_c*kh*kw)
    """
    pad_h, pad_w = padding[0], padding[1]
    if pad_h + pad_w > 0:
        # F.pad consumes the pad widths starting from the LAST dimension:
        # (w_left, w_right, h_top, h_bottom). `.data` drops the autograd
        # history of the padded copy.
        x = F.pad(x, (pad_w, pad_w, pad_h, pad_h)).data
    # Slide over height, then width: (b, c, out_h, out_w, kh, kw).
    windows = x.unfold(2, kernel_size[0], stride[0])
    windows = windows.unfold(3, kernel_size[1], stride[1])
    # Move channels behind the spatial output dims: (b, out_h, out_w, c, kh, kw).
    windows = windows.permute(0, 2, 3, 1, 4, 5).contiguous()
    patch_len = windows.size(3) * windows.size(4) * windows.size(5)
    return windows.view(windows.size(0), windows.size(1), windows.size(2), patch_len)
def update_running_stat(aa, m_aa, stat_decay):
    """Fold the new statistic ``aa`` into the running statistic ``m_aa`` in place.

    Only in-place operators are used so no extra buffer is allocated.

    :param aa: the freshly computed statistic
    :param m_aa: the running statistic; modified in place
    :param stat_decay: weight kept on the old value. ``1.0`` switches to plain
                       accumulation (``m_aa += aa``, i.e. a running sum);
                       any other value yields
                       ``stat_decay * m_aa + (1 - stat_decay) * aa``
    :return: None
    """
    if stat_decay == 1.0:
        # Plain accumulation; no decay is applied.
        m_aa += aa
        return
    # Exponential moving average, done as scale -> add -> rescale so that
    # every step is an in-place update of m_aa.
    new_weight = 1 - stat_decay
    m_aa *= stat_decay / new_weight
    m_aa += aa
    m_aa *= new_weight
class ComputeMatGrad:
    """Builds, per batch element, the matrix-shaped gradient of a layer from
    its input and the gradient w.r.t. its output (Linear and Conv2d only)."""

    @classmethod
    def __call__(cls, input, grad_output, layer):
        """
        Dispatch on the layer type.

        :param input: the layer input
        :param grad_output: gradient w.r.t. the layer output
        :param layer: an nn.Linear or nn.Conv2d module
        :return: the result of :meth:`linear` or :meth:`conv2d`
        :raises NotImplementedError: for any other layer type
        """
        if isinstance(layer, nn.Linear):
            return cls.linear(input, grad_output, layer)
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(input, grad_output, layer)
        raise NotImplementedError

    @staticmethod
    def linear(input, grad_output, layer):
        """
        :param input: batch_size * input_dim
        :param grad_output: batch_size * output_dim
        :param layer: [nn.module] output_dim * input_dim
        :return: batch_size * output_dim * (input_dim + [1 if with bias])
        """
        with torch.no_grad():
            if layer.bias is not None:
                # A trailing column of ones makes the bias part of the weight matrix.
                ones = input.new_ones(input.size(0), 1)
                input = torch.cat([input, ones], 1)
            # Batched outer product: (b, out, 1) x (b, 1, in) -> (b, out, in).
            return torch.bmm(grad_output.unsqueeze(2), input.unsqueeze(1))

    @staticmethod
    def conv2d(input, grad_output, layer):
        """
        :param input: batch_size * in_c * in_h * in_w
        :param grad_output: batch_size * out_c * h * w
        :param layer: nn.module (kernel_size, stride, padding and bias are read)
        :return: batch_size * out_c * (in_c*k_h*k_w + [1 if with bias])
        """
        with torch.no_grad():
            patches = _extract_patches(input, layer.kernel_size, layer.stride, layer.padding)
            patches = patches.view(-1, patches.size(-1))  # (b*hw) * in_c*kh*kw
            # Channels last, then merge the spatial dims: b * hw * out_c.
            grad_output = grad_output.permute(0, 2, 3, 1)
            n_batch, out_c = grad_output.size(0), grad_output.size(-1)
            grad_output = try_contiguous(grad_output).view(n_batch, -1, out_c)
            if layer.bias is not None:
                # A trailing column of ones makes the bias part of the weight matrix.
                ones = patches.new_ones(patches.size(0), 1)
                patches = torch.cat([patches, ones], 1)
            patches = patches.view(n_batch, -1, patches.size(-1))  # b * hw * feat
            # Sum the per-location outer products over the hw axis.
            return torch.einsum('abm,abn->amn', grad_output, patches)
class ComputeCovA:
    """Computes the second-moment matrix of a layer's input activations
    (Linear and Conv2d only)."""

    @classmethod
    def compute_cov_a(cls, a, layer):
        """Named alias for calling the handler: see :meth:`__call__`."""
        return cls.__call__(a, layer)

    @classmethod
    def __call__(cls, a, layer):
        """
        Dispatch on the layer type.

        :param a: the layer's input activations
        :param layer: the corresponding layer
        :return: the activation covariance, or None for unsupported layers
        """
        if isinstance(layer, nn.Linear):
            return cls.linear(a, layer)
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(a, layer)
        # FIXME(CW): for extension to other layers.
        # Unsupported layers yield None instead of raising.
        return None

    @staticmethod
    def conv2d(a, layer):
        """
        :param a: batch_size * in_c * in_h * in_w
        :param layer: nn.Conv2d (kernel_size, stride, padding and bias are read)
        :return: square matrix of side in_c*kh*kw (+ 1 if with bias)
        """
        batch_size = a.size(0)
        patches = _extract_patches(a, layer.kernel_size, layer.stride, layer.padding)
        spatial_size = patches.size(1) * patches.size(2)
        patches = patches.view(-1, patches.size(-1))
        if layer.bias is not None:
            # Trailing ones column stands in for the bias input.
            ones = patches.new_ones(patches.size(0), 1)
            patches = torch.cat([patches, ones], 1)
        # Note that the ones column is scaled here as well.
        patches = patches / spatial_size
        # FIXME(CW): do we need to divide the output feature map's size?
        return patches.t() @ (patches / batch_size)

    @staticmethod
    def linear(a, layer):
        """
        :param a: (batch_size, ..., in_dim); any middle dims (e.g. a sequence
                  length) are folded into the sample axis
        :param layer: nn.Linear (only ``bias`` is read)
        :return: square matrix of side in_dim (+ 1 if with bias)
        """
        # Every row along the last dim counts as one independent sample.
        rows = a.reshape(-1, a.size(-1))
        n_eff = rows.size(0)
        if layer.bias is not None:
            # Trailing ones column stands in for the bias input.
            ones = rows.new_ones(n_eff, 1)
            rows = torch.cat([rows, ones], 1)
        # Average of the per-row outer products.
        return rows.t() @ (rows / n_eff)
class ComputeCovG:
    """Computes the second-moment matrix of the gradients w.r.t. a layer's
    outputs (Linear and Conv2d only)."""

    @classmethod
    def compute_cov_g(cls, g, layer, batch_averaged=False):
        """
        :param g: gradient
        :param layer: the corresponding layer
        :param batch_averaged: if the gradient is already averaged with the batch size?
        :return: the gradient covariance, or None for unsupported layers
        """
        return cls.__call__(g, layer, batch_averaged)

    @classmethod
    def __call__(cls, g, layer, batch_averaged):
        """Dispatch on the layer type; unsupported layers yield None."""
        if isinstance(layer, nn.Conv2d):
            return cls.conv2d(g, layer, batch_averaged)
        if isinstance(layer, nn.Linear):
            return cls.linear(g, layer, batch_averaged)
        return None

    @staticmethod
    def conv2d(g, layer, batch_averaged):
        """
        :param g: batch_size * n_filters * out_h * out_w
        :param layer: the corresponding nn.Conv2d (not read here)
        :param batch_averaged: whether ``g`` was already divided by batch_size
        :return: n_filters * n_filters
        """
        # n_filters is actually the output dimension (analogous to Linear layer)
        spatial_size = g.size(2) * g.size(3)
        batch_size = g.shape[0]
        # Channels last, then one row per (sample, location).
        g = g.permute(0, 2, 3, 1).contiguous()
        g = g.view(-1, g.size(-1))
        if batch_averaged:
            # Undo the 1/batch_size factor.
            g = g * batch_size
        g = g * spatial_size
        return g.t() @ (g / g.size(0))

    @staticmethod
    def linear(g, layer, batch_averaged):
        """
        :param g: (batch_size, ..., out_dim); any middle dims (e.g. a sequence
                  length) are folded into the sample axis
        :param layer: the corresponding nn.Linear (not read here)
        :param batch_averaged: whether ``g`` was already divided by the number
                               of samples
        :return: out_dim * out_dim
        """
        # Every row along the last dim counts as one independent sample.
        g = g.reshape(-1, g.size(-1))
        n_eff = g.size(0)  # N_eff = B * T after flattening
        if batch_averaged:
            # Each row is (true per-sample gradient) / N_eff, so
            #   (1/N) * sum_i (N g_i)(N g_i)^T  ==  g^T @ (g * N).
            # This mirrors the rescaling done in conv2d. Previously this
            # product was computed and then overwritten by the
            # non-averaged formula, so the flag had no effect.
            # NOTE(review): assumes the averaging ran over all N_eff rows
            # (B*T for sequence inputs) -- confirm against the loss reduction.
            return g.t() @ (g * n_eff)
        # Rows are true per-sample gradients: E[g g^T] = g^T g / N_eff.
        return g.t() @ (g / n_eff)