-
-
Notifications
You must be signed in to change notification settings - Fork 150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BatchNorm training instability fix #675
base: main
Are you sure you want to change the base?
Changes from all commits
346e472
cc021f6
6bfa77f
e7705b6
7a99c7f
96f4a96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
import warnings | ||
from collections.abc import Hashable, Sequence | ||
from typing import Optional, Union | ||
from typing import Literal, Optional, Union | ||
|
||
import jax | ||
import jax.lax as lax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array, Bool, Float, PRNGKeyArray | ||
from jaxtyping import Array, Float, Int, PRNGKeyArray | ||
|
||
from .._misc import default_floating_dtype | ||
from .._module import field | ||
|
@@ -40,28 +41,92 @@ class BatchNorm(StatefulLayer, strict=True): | |
statistics updated. During inference then just the running statistics are used. | ||
Whether the model is in training or inference mode should be toggled using | ||
[`equinox.nn.inference_mode`][]. | ||
|
||
With `approach = "batch"` during training the batch mean and variance are used | ||
for normalization. For inference the exponential running mean and ubiased | ||
variance are used for normalization in accordance with the cited paper below. | ||
Let `m` be momentum: | ||
|
||
$\text{TrainStats}_t = \text{BatchStats}_t$ | ||
|
||
$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i} | ||
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$ | ||
|
||
With `approach = "ema"` exponential running means and variances are kept. During | ||
training the batch statistics are used to fill in the running statistics until | ||
they are populated. In addition a linear iterpolation is used between the batch | ||
and running statistics over the `warmup_period`. During inference the running | ||
statistics are used for normalization: | ||
|
||
|
||
|
||
$\text{WarmupFrac}_t = \text{min} \left(1.0, \frac{t}{\text{WarmupPeriod}} \right)$ | ||
|
||
$\text{TrainStats}_t = (1.0 - \text{WarmupFrac}_t) * BatchStats_t + | ||
\text{WarmupFrac}_t * \left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}\text{BatchStats}_i$ | ||
|
||
$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i} | ||
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$ | ||
|
||
|
||
$\text{Note: } \frac{(1.0 - m)\sum_{i=0}^{t}m^{t-i}}{1.0 - m^{t+1}} = | ||
\frac{(1.0 - m)\sum_{i=0}^{t}m^{i}}{1.0 - m^{t+1}}$ | ||
$= \frac{(1.0 - m)\frac{1.0 - m^{t+1}}{1.0 - m}}{1.0 - m^{t+1}} = 1$ | ||
|
||
`approach = "ema_compatibility"` reproduces the original equinox BatchNorm | ||
behavior. It often results in training instabilities and `approach = "batch"` | ||
or `"ema"` is recommended. | ||
|
||
??? cite | ||
|
||
[Batch Normalization: Accelerating Deep Network Training by Reducing | ||
Internal Covariate Shift](https://arxiv.org/abs/1502.03167) | ||
|
||
```bibtex | ||
@article{DBLP:journals/corr/IoffeS15, | ||
author = {Sergey Ioffe and | ||
Christian Szegedy}, | ||
title = {Batch Normalization: Accelerating Deep Network Training | ||
by Reducing Internal Covariate Shift}, | ||
journal = {CoRR}, | ||
volume = {abs/1502.03167}, | ||
year = {2015}, | ||
url = {http://arxiv.org/abs/1502.03167}, | ||
eprinttype = {arXiv}, | ||
eprint = {1502.03167}, | ||
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, | ||
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
``` | ||
|
||
""" # noqa: E501 | ||
|
||
weight: Optional[Float[Array, "input_size"]] | ||
bias: Optional[Float[Array, "input_size"]] | ||
first_time_index: StateIndex[Bool[Array, ""]] | ||
count_index: StateIndex[Int[Array, ""]] | ||
state_index: StateIndex[ | ||
tuple[Float[Array, "input_size"], Float[Array, "input_size"]] | ||
] | ||
zero_frac_index: StateIndex[Float[Array, ""]] | ||
axis_name: Union[Hashable, Sequence[Hashable]] | ||
inference: bool | ||
input_size: int = field(static=True) | ||
approach: Literal["batch", "ema", "ema_compatibility"] = field(static=True) | ||
eps: float = field(static=True) | ||
channelwise_affine: bool = field(static=True) | ||
momentum: float = field(static=True) | ||
warmup_period: int = field(static=True) | ||
|
||
def __init__( | ||
self, | ||
input_size: int, | ||
axis_name: Union[Hashable, Sequence[Hashable]], | ||
approach: Optional[Literal["batch", "ema", "ema_compatibility"]] = None, | ||
eps: float = 1e-5, | ||
channelwise_affine: bool = True, | ||
momentum: float = 0.99, | ||
warmup_period: int = 1, | ||
inference: bool = False, | ||
dtype=None, | ||
): | ||
|
@@ -71,11 +136,17 @@ def __init__( | |
- `axis_name`: The name of the batch axis to compute statistics over, as passed | ||
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a | ||
tuple or a list) of names, to compute statistics over multiple named axes. | ||
- `approach`: The approach to use for the running statistics. If `approach=None` | ||
a warning will be raised and approach will default to `"ema_compatibility"`. | ||
During training `"batch"` only uses batch statisics while`"ema"` and | ||
`"ema_compatibility"` uses the running statistics. | ||
- `eps`: Value added to the denominator for numerical stability. | ||
- `channelwise_affine`: Whether the module has learnable channel-wise affine | ||
parameters. | ||
- `momentum`: The rate at which to update the running statistics. Should be a | ||
value between 0 and 1 exclusive. | ||
- `warmup_period`: The interpolation period between batch and running | ||
statistics. Only used when `approach=\"ema\"`. | ||
- `inference`: If `False` then the batch means and variances will be calculated | ||
and used to update the running statistics. If `True` then the running | ||
statistics are directly used for normalisation. This may be toggled with | ||
|
@@ -86,26 +157,46 @@ def __init__( | |
64-bit mode. | ||
""" | ||
|
||
if approach is None: | ||
warnings.warn( | ||
"BatchNorm approach is None, defaults to " | ||
'approach="ema_compatibility". This is not recommended as ' | ||
'it can lead to training instability. Use "batch" or ' | ||
'alternatively "ema" with appropriately selected warmup ' | ||
"instead." | ||
) | ||
approach = "ema_compatibility" | ||
|
||
valid_approaches = {"batch", "ema", "ema_compatibility"} | ||
if approach not in valid_approaches: | ||
raise ValueError(f"approach must be one of {valid_approaches}") | ||
self.approach = approach | ||
|
||
if warmup_period < 1: | ||
raise ValueError("warmup_period must be >= 1") | ||
|
||
if channelwise_affine: | ||
self.weight = jnp.ones((input_size,)) | ||
self.bias = jnp.zeros((input_size,)) | ||
else: | ||
self.weight = None | ||
self.bias = None | ||
self.first_time_index = StateIndex(jnp.array(True)) | ||
self.count_index = StateIndex(jnp.array(0, dtype=jnp.int32)) | ||
if dtype is None: | ||
dtype = default_floating_dtype() | ||
init_buffers = ( | ||
jnp.empty((input_size,), dtype=dtype), | ||
jnp.empty((input_size,), dtype=dtype), | ||
jnp.zeros((input_size,), dtype=dtype), | ||
jnp.zeros((input_size,), dtype=dtype), | ||
) | ||
self.state_index = StateIndex(init_buffers) | ||
self.zero_frac_index = StateIndex(jnp.array(1.0, dtype=dtype)) | ||
self.inference = inference | ||
self.axis_name = axis_name | ||
self.input_size = input_size | ||
self.eps = eps | ||
self.channelwise_affine = channelwise_affine | ||
self.momentum = momentum | ||
self.warmup_period = warmup_period | ||
|
||
@jax.named_scope("eqx.nn.BatchNorm") | ||
def __call__( | ||
|
@@ -143,7 +234,11 @@ def __call__( | |
if inference is None: | ||
inference = self.inference | ||
if inference: | ||
# renormalize running stats to account for the zeroed part | ||
zero_frac = state.get(self.zero_frac_index) | ||
running_mean, running_var = state.get(self.state_index) | ||
norm_mean = running_mean / jnp.maximum(1.0 - zero_frac, self.eps) | ||
norm_var = running_var / jnp.maximum(1.0 - zero_frac, self.eps) | ||
else: | ||
|
||
def _stats(y): | ||
|
@@ -154,16 +249,50 @@ def _stats(y): | |
var = jnp.maximum(0.0, var) | ||
return mean, var | ||
|
||
first_time = state.get(self.first_time_index) | ||
state = state.set(self.first_time_index, jnp.array(False)) | ||
|
||
momentum = self.momentum | ||
batch_mean, batch_var = jax.vmap(_stats)(x) | ||
zero_frac = state.get(self.zero_frac_index) | ||
running_mean, running_var = state.get(self.state_index) | ||
momentum = self.momentum | ||
running_mean = (1 - momentum) * batch_mean + momentum * running_mean | ||
running_var = (1 - momentum) * batch_var + momentum * running_var | ||
running_mean = lax.select(first_time, batch_mean, running_mean) | ||
running_var = lax.select(first_time, batch_var, running_var) | ||
|
||
if self.approach == "ema": | ||
zero_frac = zero_frac * momentum | ||
running_mean = (1 - momentum) * batch_mean + momentum * running_mean | ||
running_var = (1 - momentum) * batch_var + momentum * running_var | ||
warmup_count = state.get(self.count_index) | ||
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period) | ||
state = state.set(self.count_index, warmup_count) | ||
|
||
# fill in unpopulated part of running stats with batch stats | ||
warmup_frac = warmup_count / self.warmup_period | ||
norm_mean = zero_frac * batch_mean + running_mean | ||
norm_var = zero_frac * batch_var + running_var | ||
|
||
# apply warmup interpolation between batch and running statistics | ||
norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean | ||
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var | ||
|
||
elif self.approach == "ema_compatibility": | ||
running_mean = (1 - momentum) * batch_mean + momentum * running_mean | ||
running_var = (1 - momentum) * batch_var + momentum * running_var | ||
running_mean = lax.select(zero_frac == 1.0, batch_mean, running_mean) | ||
running_var = lax.select(zero_frac == 1.0, batch_var, running_var) | ||
norm_mean, norm_var = running_mean, running_var | ||
zero_frac = 0.0 * zero_frac | ||
|
||
else: | ||
zero_frac = zero_frac * momentum | ||
running_mean = (1 - momentum) * batch_mean + momentum * running_mean | ||
# calculate unbiased variance for saving | ||
axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm using this to get the length of the "batch" axis - but not sure it's the best / correct way There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is the correct way! IIRC |
||
debias_coef = (axis_size) / jnp.maximum(axis_size - 1, self.eps) | ||
running_var = ( | ||
1 - momentum | ||
) * debias_coef * batch_var + momentum * running_var | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I neglected to use unbiased variance so corrected that here |
||
|
||
# just use batch statistics when not in inference mode | ||
norm_mean, norm_var = batch_mean, batch_var | ||
|
||
state = state.set(self.zero_frac_index, zero_frac) | ||
state = state.set(self.state_index, (running_mean, running_var)) | ||
|
||
def _norm(y, m, v, w, b): | ||
|
@@ -172,5 +301,5 @@ def _norm(y, m, v, w, b): | |
out = out * w + b | ||
return out | ||
|
||
out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) | ||
out = jax.vmap(_norm)(x, norm_mean, norm_var, self.weight, self.bias) | ||
return out, state |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not completely obvious to me that the
ema
implementation, with default arguments, reproduces the previous behaviour. (For example, we havewarmup_period=1000
by default?)Can you add some comments explaining what each
approach
corresponds to?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ema with warmup_period=1 approximately reproduces previous behavior. As I noted the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used as with the previous behavior. I can give an exact replication with an extra approach if necessary.
Added some to the documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, it was different enough that I added it as
"ema_compatibility"
. I changed the warning to rather strongly recommend against using"ema_compatibility"
. I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate) but that could very much be due to a lack of imagination on my part.. That part can definitely change if needed