syncbn.py
import torch
import torch.distributed as dist
from torch.autograd import Function
from torch.nn.modules.batchnorm import _BatchNorm


class sync_batch_norm(Function):
    """
    A version of batch normalization that aggregates the activation statistics across all processes.

    This needs to be a custom autograd.Function, because the processes also have to communicate on the
    backward pass: every activation contributes to the shared batch statistics, so the loss gradients
    of all examples affect the gradient of each activation.

    For a quick tutorial on torch.autograd.function, see
    https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    """
    @staticmethod
    def forward(ctx, input, running_mean, running_std, eps: float, momentum: float):
        # Compute local statistics, sync them across processes, then normalize the input.
        # One straightforward implementation: all_reduce the per-feature sum, sum of squares and
        # example count in a single call, then derive the global mean and std from the totals.
        num_features = input.shape[1]
        local_count = torch.tensor([input.shape[0]], dtype=input.dtype, device=input.device)
        stats = torch.cat([input.sum(dim=0), (input * input).sum(dim=0), local_count])
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)

        count = stats[-1]
        mean = stats[:num_features] / count
        var = stats[num_features : 2 * num_features] / count - mean * mean
        std = torch.sqrt(var + eps)

        # Exponential moving average of the running statistics; note that running_std stores
        # sqrt(var + eps), so the evaluation path can divide by it directly.
        running_mean.mul_(1 - momentum).add_(momentum * mean)
        running_std.mul_(1 - momentum).add_(momentum * std)

        normalized = (input - mean) / std
        # Store the quantities needed on the backward pass
        ctx.save_for_backward(normalized, std, count)
        return normalized

    @staticmethod
    def backward(ctx, grad_output):
        # For y = (x - mean) / std (no affine), the standard batch-norm backward formula is
        #   dL/dx = (dL/dy - mean(dL/dy) - y * mean(dL/dy * y)) / std,
        # where the means run over the *global* batch, hence one more all_reduce.
        normalized, std, count = ctx.saved_tensors
        num_features = normalized.shape[1]

        grad_sums = torch.cat([grad_output.sum(dim=0), (grad_output * normalized).sum(dim=0)])
        dist.all_reduce(grad_sums, op=dist.ReduceOp.SUM)
        grad_mean = grad_sums[:num_features] / count
        grad_dot_mean = grad_sums[num_features:] / count

        grad_input = (grad_output - grad_mean - normalized * grad_dot_mean) / std

        # Don't forget to return a tuple of gradients wrt all arguments of `forward`:
        # only `input` needs one; the running statistics, eps and momentum do not.
        return grad_input, None, None, None, None
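

# A small sanity check one can run after `dist.init_process_group(...)`: with a world size of 1,
# all_reduce leaves the local statistics unchanged, so the custom Function above should match
# torch.nn.functional.batch_norm on the same batch. The helper name and tolerance below are
# illustrative only, not part of any required interface.
def _check_against_reference(batch: torch.Tensor, eps: float = 1e-5, momentum: float = 0.1) -> bool:
    running_mean = torch.zeros(batch.shape[1])
    running_std = torch.ones(batch.shape[1])
    ours = sync_batch_norm.apply(batch, running_mean, running_std, eps, momentum)
    reference = torch.nn.functional.batch_norm(batch, None, None, training=True, momentum=momentum, eps=eps)
    return torch.allclose(ours, reference, atol=1e-4)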


class SyncBatchNorm(_BatchNorm):
    """
    Applies Batch Normalization to the input (over the 0 axis), aggregating the activation statistics
    across all processes. You can assume that there are no affine operations in this layer.
    """

    def __init__(self, num_features: int, eps: float = 1e-5, momentum: float = 0.1):
        super().__init__(
            num_features,
            eps,
            momentum,
            affine=False,
            track_running_stats=True,
            device=None,
            dtype=None,
        )
        # Per-feature running statistics used at evaluation time; running_std plays the role of the
        # parent class's running_var (and already accounts for eps, see sync_batch_norm.forward).
        self.running_mean = torch.zeros((num_features,))
        self.running_std = torch.ones((num_features,))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.training:
            # Normalize with the synchronized batch statistics (and update the running ones)
            # using `sync_batch_norm` from above.
            return sync_batch_norm.apply(
                input, self.running_mean, self.running_std, self.eps, self.momentum
            )
        # At evaluation time, use the stored running statistics instead.
        return (input - self.running_mean) / self.running_std
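

# Minimal usage sketch. It assumes the file is launched with e.g.
# `torchrun --nproc_per_node=2 syncbn.py`, which sets the environment variables that
# `init_process_group` reads; the backend, batch shape and feature count are illustrative.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    torch.manual_seed(rank)  # each process draws a different local batch
    local_batch = torch.randn(8, 4)

    layer = SyncBatchNorm(num_features=4)
    output = layer(local_batch)

    # The statistics are aggregated over all processes, so every rank normalizes its local
    # batch with the same global mean and std.
    print(f"rank {rank}: per-feature output mean = {output.mean(dim=0).tolist()}")

    if dist.get_world_size() == 1:
        # Single-process case: the layer should agree with the PyTorch reference (see helper above).
        assert _check_against_reference(local_batch)

    dist.destroy_process_group()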