BatchNorm training instability fix #675
@@ -1,10 +1,11 @@
import warnings
from collections.abc import Hashable, Sequence
from typing import Optional, Union
from typing import Literal, Optional, Union

import jax
import jax.lax as lax
import jax.numpy as jnp
from jaxtyping import Array, Bool, Float, PRNGKeyArray
from jaxtyping import Array, Float, Int, PRNGKeyArray

from .._misc import default_floating_dtype
from .._module import field
@@ -44,24 +45,29 @@ class BatchNorm(StatefulLayer, strict=True):

    weight: Optional[Float[Array, "input_size"]]
    bias: Optional[Float[Array, "input_size"]]
    first_time_index: StateIndex[Bool[Array, ""]]
    count_index: StateIndex[Int[Array, ""]]
    state_index: StateIndex[
        tuple[Float[Array, "input_size"], Float[Array, "input_size"]]
    ]
    zero_frac_index: StateIndex[Float[Array, ""]]
    axis_name: Union[Hashable, Sequence[Hashable]]
    inference: bool
    input_size: int = field(static=True)
    approach: Literal["batch", "ema"] = field(static=True)
    eps: float = field(static=True)
    channelwise_affine: bool = field(static=True)
    momentum: float = field(static=True)
    warmup_period: int = field(static=True)

    def __init__(
        self,
        input_size: int,
        axis_name: Union[Hashable, Sequence[Hashable]],
        approach: Optional[Literal["batch", "ema"]] = None,
        eps: float = 1e-5,
        channelwise_affine: bool = True,
        momentum: float = 0.99,
        warmup_period: int = 1000,
        inference: bool = False,
        dtype=None,
    ):
@@ -71,11 +77,17 @@ def __init__(
        - `axis_name`: The name of the batch axis to compute statistics over, as passed
            to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
            tuple or a list) of names, to compute statistics over multiple named axes.
        - `approach`: The approach to use for the running statistics. If `approach=None`
            then a warning is raised and the approach defaults to `"batch"`. During
            training, `"batch"` uses only the batch statistics, whilst `"ema"` uses the
            running statistics.
        - `eps`: Value added to the denominator for numerical stability.
        - `channelwise_affine`: Whether the module has learnable channel-wise affine
            parameters.
        - `momentum`: The rate at which to update the running statistics. Should be a
            value between 0 and 1 exclusive.
        - `warmup_period`: The number of steps over which to warm up the running
            statistics. Only used when `approach="ema"`.
        - `inference`: If `False` then the batch means and variances will be calculated
            and used to update the running statistics. If `True` then the running
            statistics are directly used for normalisation. This may be toggled with
@@ -86,26 +98,37 @@ def __init__(
            64-bit mode.
        """

        if approach is None:
            warnings.warn('BatchNorm approach is None, defaults to approach="batch"')
            approach = "batch"

        valid_approaches = {"batch", "ema"}
        if approach not in valid_approaches:
            raise ValueError(f"approach must be one of {valid_approaches}")
        self.approach = approach

        if channelwise_affine:
            self.weight = jnp.ones((input_size,))
            self.bias = jnp.zeros((input_size,))
        else:
            self.weight = None
            self.bias = None
        self.first_time_index = StateIndex(jnp.array(True))
        self.count_index = StateIndex(jnp.array(0, dtype=jnp.int32))
        if dtype is None:
            dtype = default_floating_dtype()
        init_buffers = (
            jnp.empty((input_size,), dtype=dtype),
            jnp.empty((input_size,), dtype=dtype),
            jnp.zeros((input_size,), dtype=dtype),
            jnp.zeros((input_size,), dtype=dtype),
        )
        self.state_index = StateIndex(init_buffers)
        self.zero_frac_index = StateIndex(jnp.array(1.0, dtype=dtype))
        self.inference = inference
        self.axis_name = axis_name
        self.input_size = input_size
        self.eps = eps
        self.channelwise_affine = channelwise_affine
        self.momentum = momentum
        self.warmup_period = max(1, warmup_period)
Reviewer: Why the `max(1, warmup_period)`?

Author: `warmup_period=0` seemed natural for "off" - changed to just check and error out.
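The "check and error out" change mentioned in the reply is not visible in this two-commit view of the diff. A hypothetical sketch of what such validation could look like (illustrative only, not the PR's actual code):

```python
# Hypothetical replacement for `max(1, warmup_period)` -- reject invalid values instead.
if warmup_period < 1:
    raise ValueError("warmup_period must be an integer of at least 1")
self.warmup_period = warmup_period
```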

    @jax.named_scope("eqx.nn.BatchNorm")
    def __call__(
Reviewer: It's not completely obvious to me that the new behaviour reproduces the old one. Can you add some comments explaining what each approach does?

Author: `ema` with `warmup_period=1` approximately reproduces previous behavior. As I noted, the start is different because of how the running statistics are initially populated. With `warmup_period=1` there's no interpolation between the batch and running stats - the running stats are always used, as with the previous behavior. I can give an exact replication with an extra approach if necessary. Added some to the documentation.

Reviewer: I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.

Author: Makes sense, it was different enough that I added it as a separate approach.
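Aside on the state plumbing used above: each `StateIndex` created in `__init__` names a buffer whose value lives in an `eqx.nn.State` object threaded through `__call__` via `state.get`/`state.set`. A toy layer illustrating that pattern (my own sketch, not part of this PR):

```python
import equinox as eqx
import jax.numpy as jnp

class RunningCount(eqx.Module):
    """Toy layer: counts how many times it has been called, via a StateIndex buffer."""

    count_index: eqx.nn.StateIndex

    def __init__(self):
        # The argument to StateIndex is the buffer's initial value.
        self.count_index = eqx.nn.StateIndex(jnp.array(0, dtype=jnp.int32))

    def __call__(self, x, state):
        count = state.get(self.count_index) + 1
        state = state.set(self.count_index, count)  # returns an updated State
        return count, state

layer, state = eqx.nn.make_with_state(RunningCount)()
count, state = layer(jnp.ones(3), state)
count, state = layer(jnp.ones(3), state)
print(count)  # 2
```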
@@ -143,7 +166,10 @@ def __call__(
        if inference is None:
            inference = self.inference
        if inference:
            zero_frac = state.get(self.zero_frac_index)
            running_mean, running_var = state.get(self.state_index)
            norm_mean = running_mean / jnp.maximum(1.0 - zero_frac, self.eps)
            norm_var = running_var / jnp.maximum(1.0 - zero_frac, self.eps)
        else:

            def _stats(y):
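One way to read the inference branch above: `zero_frac` starts at 1 and is multiplied by `momentum` once per training step, so it is the fraction of the exponential moving average still attributable to the zero initialisation, and dividing by `1 - zero_frac` is the standard EMA bias correction. A small self-contained sketch with illustrative numbers (not the module itself):

```python
import jax.numpy as jnp

momentum = 0.99
true_mean = 5.0

running_mean = jnp.zeros(())  # zero-initialised running statistic
zero_frac = jnp.ones(())      # fraction of the EMA still owed to the zero init

for step in range(100):
    batch_mean = true_mean  # pretend every batch has the same mean
    zero_frac = zero_frac * momentum
    running_mean = (1 - momentum) * batch_mean + momentum * running_mean

# The raw EMA is biased towards zero early in training...
print(float(running_mean))                      # ~3.17 after 100 steps
# ...but dividing by (1 - zero_frac) removes that bias, as in the inference branch.
print(float(running_mean / (1.0 - zero_frac)))  # ~5.0
```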
@@ -154,23 +180,35 @@ def _stats(y):
                var = jnp.maximum(0.0, var)
                return mean, var

            first_time = state.get(self.first_time_index)
            state = state.set(self.first_time_index, jnp.array(False))
            momentum = self.momentum
            zero_frac = state.get(self.zero_frac_index)
            zero_frac *= momentum
Reviewer: Stylistic nit: I tend not to use the in-place operations in JAX code. This (a) fits with the functional style a bit better, and (b) emphasises that we're definitely falling back to the out-of-place operation anyway.

Author: Makes sense, done.
            state = state.set(self.zero_frac_index, zero_frac)

            batch_mean, batch_var = jax.vmap(_stats)(x)
            running_mean, running_var = state.get(self.state_index)
            momentum = self.momentum
            running_mean = (1 - momentum) * batch_mean + momentum * running_mean
            running_var = (1 - momentum) * batch_var + momentum * running_var
Reviewer: These don't appear to be used on the `"batch"` branch?

Author: These are used by the `"batch"` branch when we're in inference mode, so they still need to be computed and stored.
            running_mean = lax.select(first_time, batch_mean, running_mean)
            running_var = lax.select(first_time, batch_var, running_var)
            state = state.set(self.state_index, (running_mean, running_var))

            if self.approach == "ema":
                warmup_count = state.get(self.count_index)
                warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period)
                state = state.set(self.count_index, warmup_count)

                warmup_frac = warmup_count / self.warmup_period
                norm_mean = zero_frac * batch_mean + running_mean
                norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean
                norm_var = zero_frac * batch_var + running_var
                norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var
Reviewer: I'm definitely going to have to sit down and grok what's going on here more carefully! As above, it would be good to have some comments / docstrings / references / etc. describing what each approach is meant to do. (C.f. something like the ...)

Author: Added some commentary and tried making it a bit cleaner. But overall, batch mode should follow the cited paper. Ema follows the prior behavior, but changes the initialization of the running stats and adds interpolation so it can be stable while training.
            else:
                norm_mean, norm_var = batch_mean, batch_var

        def _norm(y, m, v, w, b):
            out = (y - m) / jnp.sqrt(v + self.eps)
            if self.channelwise_affine:
                out = out * w + b
            return out

        out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
        out = jax.vmap(_norm)(x, norm_mean, norm_var, self.weight, self.bias)
        return out, state
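To summarise the training branch above (my reading of this revision of the diff, not necessarily its final form): both approaches update the running statistics in the same way, but `"batch"` normalises with the current batch statistics, while `"ema"` normalises with the running statistics after filling in their zero-initialised fraction and warming up from batch statistics over `warmup_period` steps. A condensed, standalone restatement:

```python
import jax.numpy as jnp

def batchnorm_training_stats(batch_mean, batch_var, running, zero_frac, warmup_count,
                             approach, momentum=0.99, warmup_period=1000):
    """Condensed restatement of the training branch in this revision (my reading)."""
    running_mean, running_var = running
    # Updated regardless of approach, so the running stats are ready for inference.
    zero_frac = zero_frac * momentum
    running_mean = (1 - momentum) * batch_mean + momentum * running_mean
    running_var = (1 - momentum) * batch_var + momentum * running_var

    if approach == "ema":
        warmup_count = jnp.minimum(warmup_count + 1, warmup_period)
        warmup_frac = warmup_count / warmup_period
        # Fill the still-zero-initialised fraction of the EMA with the batch stats,
        # then interpolate from pure batch stats towards that over the warmup period.
        norm_mean = zero_frac * batch_mean + running_mean
        norm_mean = (1 - warmup_frac) * batch_mean + warmup_frac * norm_mean
        norm_var = zero_frac * batch_var + running_var
        norm_var = (1 - warmup_frac) * batch_var + warmup_frac * norm_var
    else:  # "batch": normalise with the current batch statistics alone.
        norm_mean, norm_var = batch_mean, batch_var

    return norm_mean, norm_var, (running_mean, running_var), zero_frac, warmup_count
```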
Reviewer: So continuing from my previous comment -- probably the default should be `ema` if `approach=None`.