BatchNorm training instability fix #675

Open · wants to merge 6 commits into main

Changes from 1 commit
78 changes: 71 additions & 7 deletions equinox/nn/_batch_norm.py
@@ -41,6 +41,61 @@ class BatchNorm(StatefulLayer, strict=True):
statistics updated. During inference then just the running statistics are used.
Whether the model is in training or inference mode should be toggled using
[`equinox.nn.inference_mode`][].

With `approach = "batch"` during training the batch mean and variance are used
for normalization. For inference the exponential running mean and ubiased
variance are used for normalization in accordance with the cited paper below:

$\text{TrainStats}_t = \text{BatchStats}_t$

$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$

With `approach = "ema"` exponential running means and variances are kept. During
training the batch statistics are used to fill in the running statistics until
they are populated. In addition a linear iterpolation is used between the batch
and running statistics over the `warmup_period`. During inference the running
statistics are used for normalization:



$\text{WarmupFrac}_t = \text{min} \left(1.0, \frac{t}{\text{WarmupPeriod}} \right)$

$\text{TrainStats}_t = (1.0 - \text{WarmupFrac}_t) * \text{BatchStats}_t +
\text{WarmupFrac}_t * \left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}\text{BatchStats}_i$

$\text{InferenceStats}_t = \frac{\left(1.0 - m\right)\sum_{i=0}^{t}m^{t-i}
\text{BatchStats}_i}{\text{max} \left(1.0 - m^{t+1}, \varepsilon \right)}$


$\text{Note: } \frac{(1.0 - m)\sum_{i=0}^{t}m^{t-i}}{1.0 - m^{t+1}} =
\frac{(1.0 - m)\sum_{i=0}^{t}m^{i}}{1.0 - m^{t+1}}$
$= \frac{(1.0 - m)\frac{1.0 - m^{t+1}}{1.0 - m}}{1.0 - m^{t+1}} = 1$


??? cite

[Batch Normalization: Accelerating Deep Network Training by Reducing
Internal Covariate Shift](https://arxiv.org/abs/1502.03167)

```bibtex
@article{DBLP:journals/corr/IoffeS15,
author = {Sergey Ioffe and
Christian Szegedy},
title = {Batch Normalization: Accelerating Deep Network Training
by Reducing Internal Covariate Shift},
journal = {CoRR},
volume = {abs/1502.03167},
year = {2015},
url = {http://arxiv.org/abs/1502.03167},
eprinttype = {arXiv},
eprint = {1502.03167},
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```

""" # noqa: E501

weight: Optional[Float[Array, "input_size"]]
@@ -67,7 +122,7 @@ def __init__(
eps: float = 1e-5,
channelwise_affine: bool = True,
momentum: float = 0.99,
warmup_period: int = 1000,
warmup_period: int = 1,
inference: bool = False,
dtype=None,
):
Expand All @@ -86,8 +141,8 @@ def __init__(
parameters.
- `momentum`: The rate at which to update the running statistics. Should be a
value between 0 and 1 exclusive.
- `warmup_period`: The period to warm up the running statistics. Only used when
`approach=\"ema\"`.
- `warmup_period`: The interpolation period between batch and running
statistics. Only used when `approach=\"ema\"`.
- `inference`: If `False` then the batch means and variances will be calculated
and used to update the running statistics. If `True` then the running
statistics are directly used for normalisation. This may be toggled with
Expand All @@ -107,6 +162,9 @@ def __init__(
raise ValueError(f"approach must be one of {valid_approaches}")
self.approach = approach

if warmup_period < 1:
raise ValueError("warmup_period must be >= 1")

if channelwise_affine:
self.weight = jnp.ones((input_size,))
self.bias = jnp.zeros((input_size,))
Expand All @@ -128,7 +186,7 @@ def __init__(
self.eps = eps
self.channelwise_affine = channelwise_affine
self.momentum = momentum
self.warmup_period = max(1, warmup_period)
self.warmup_period = warmup_period

@jax.named_scope("eqx.nn.BatchNorm")
def __call__(
Owner
It's not completely obvious to me that the ema implementation, with default arguments, reproduces the previous behaviour. (For example, we have warmup_period=1000 by default?)

Can you add some comments explaining what each approach corresponds to?

Author
ema with warmup_period=1 approximately reproduces previous behavior. As I noted the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used as with the previous behavior. I can give an exact replication with an extra approach if necessary.

Added some to the documentation
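
To make the `warmup_period=1` case concrete, here is a minimal standalone sketch (illustrative variable names, not the exact implementation) of the interpolation described in the docstring. With `warmup_period=1` the warmup fraction is 1.0 from the first step, so only the zero-debiased running statistics are used:

```python
import jax.numpy as jnp

def ema_train_stat(batch_stat, running_stat, zero_frac, step, warmup_period):
    # Fraction of the warmup that has elapsed; with warmup_period=1 this is
    # always 1.0 from the first training step onwards.
    warmup_frac = jnp.minimum(step, warmup_period) / warmup_period
    # Back-fill the not-yet-populated part of the running statistic with the
    # batch statistic, then interpolate between batch and running statistics.
    filled = zero_frac * batch_stat + running_stat
    return (1.0 - warmup_frac) * batch_stat + warmup_frac * filled
```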

Owner
I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.

Author
Makes sense, it was different enough that I added it as "ema_compatibility". I changed the warning to rather strongly recommend against using "ema_compatibility". I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate), but that could very much be due to a lack of imagination on my part. That part can definitely change if needed.

@@ -182,16 +240,15 @@ def _stats(y):

momentum = self.momentum
zero_frac = state.get(self.zero_frac_index)
zero_frac *= momentum
zero_frac = zero_frac * momentum
state = state.set(self.zero_frac_index, zero_frac)

batch_mean, batch_var = jax.vmap(_stats)(x)
running_mean, running_var = state.get(self.state_index)
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
running_var = (1 - momentum) * batch_var + momentum * running_var
state = state.set(self.state_index, (running_mean, running_var))

if self.approach == "ema":
running_var = (1 - momentum) * batch_var + momentum * running_var
warmup_count = state.get(self.count_index)
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period)
state = state.set(self.count_index, warmup_count)
@@ -202,8 +259,15 @@ def _stats(y):
norm_var = zero_frac * batch_var + running_var
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var
else:
axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name)
Author
I'm using this to get the length of the "batch" axis - but not sure it's the best / correct way

Owner
I think this is the correct way! IIRC psum(1) is actually special-cased for this purpose.
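
As a quick illustration (a standalone sketch, not part of this diff): under `jax.vmap` with an `axis_name`, `psum` of a constant reduces over the mapped axis and therefore returns its size:

```python
import jax
import jax.numpy as jnp

def axis_size(_):
    # psum sums the constant 1.0 over every element of the named axis,
    # which yields the axis length.
    return jax.lax.psum(jnp.array(1.0), axis_name="batch")

sizes = jax.vmap(axis_size, axis_name="batch")(jnp.zeros(8))
# sizes == [8., 8., 8., 8., 8., 8., 8., 8.]
```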

debias_coef = (axis_size) / jnp.maximum(axis_size - 1, self.eps)
running_var = (
1 - momentum
) * debias_coef * batch_var + momentum * running_var
Author
I neglected to use unbiased variance so corrected that here
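
For reference, a small standalone sketch of the correction applied above: scaling the biased batch variance by n / (n - 1) (Bessel's correction) recovers the unbiased estimate, which is what `debias_coef` does (with `jnp.maximum(..., eps)` guarding against a batch size of 1):

```python
import jax.numpy as jnp

x = jnp.arange(10.0).reshape(5, 2)       # batch of 5 samples, 2 features
n = x.shape[0]
biased_var = jnp.var(x, axis=0)          # divides by n
unbiased_var = biased_var * n / (n - 1)  # Bessel's correction
assert jnp.allclose(unbiased_var, jnp.var(x, axis=0, ddof=1))
```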

norm_mean, norm_var = batch_mean, batch_var

state = state.set(self.state_index, (running_mean, running_var))

def _norm(y, m, v, w, b):
out = (y - m) / jnp.sqrt(v + self.eps)
if self.channelwise_affine:
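Taken together, a minimal usage sketch of the new `approach` argument (mirroring the tests below; argument names are as they appear in this PR's diff):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

bn = eqx.nn.BatchNorm(5, "batch", approach="batch", momentum=0.9)
state = eqx.nn.State(bn)
# BatchNorm is stateful, so it is vmapped with a named batch axis and the
# state is threaded through each call.
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))

x = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
out, state = vbn(x, state)  # training mode: normalised with batch statistics
```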
32 changes: 28 additions & 4 deletions tests/test_nn.py
@@ -830,9 +830,14 @@ def test_batch_norm(getkey):
bn = eqx.nn.BatchNorm(5, "batch")
assert bn.approach == "batch"

with pytest.raises(ValueError):
bn = eqx.nn.BatchNorm(5, "batch", approach="ema", warmup_period=0)

# Test initialization
bn_momentum = 0.99
bn = eqx.nn.BatchNorm(5, "batch", approach="ema", momentum=bn_momentum)
bn = eqx.nn.BatchNorm(
5, "batch", approach="ema", warmup_period=10, momentum=bn_momentum
)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
running_mean, running_var = state.get(bn.state_index)
@@ -890,8 +895,7 @@ def test_batch_norm(getkey):
assert running_mean.shape == (6,)
assert running_var.shape == (6,)

# Test that it normalises

# Test that approach=ema normalises
x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test
bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, approach="ema")
state = eqx.nn.State(bn)
@@ -902,6 +906,27 @@
)
assert jnp.allclose(out, true_out)

# Test that approach=batch normalises in training mode
bn = eqx.nn.BatchNorm(
5, "batch", channelwise_affine=False, approach="batch", momentum=0.9
)
state = eqx.nn.State(bn)
vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn(x1alt, state)
true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt(
jnp.var(x1alt, axis=0, keepdims=True) + 1e-5
)
assert jnp.allclose(out, true_out)
# Test that approach=batch normalises in inference mode
bn_inf = eqx.nn.inference_mode(bn, value=True)
vbn_inf = jax.vmap(bn_inf, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vbn_inf(x1alt, state)
debias_coef = x1alt.shape[0] / (x1alt.shape[0] - 1)
true_out = (x1alt - jnp.mean(x1alt, axis=0, keepdims=True)) / jnp.sqrt(
debias_coef * jnp.var(x1alt, axis=0, keepdims=True) + 1e-5
)
assert jnp.allclose(out, true_out)

# Test that the statistics update during training
out, state = vbn(x1, state)
running_mean, running_var = state.get(bn.state_index)
@@ -913,7 +938,6 @@
assert not jnp.allclose(running_var, running_var2)

# Test that the statistics don't update at inference

ibn = eqx.nn.inference_mode(bn, value=True)
vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None))
out, state = vibn(4 * x1 + 20, state)