-
-
Notifications
You must be signed in to change notification settings - Fork 144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
BatchNorm training instability fix #675
base: main
Are you sure you want to change the base?
Conversation
…bility when using running statistics.
equinox/nn/_batch_norm.py
Outdated
self.inference = inference | ||
self.axis_name = axis_name | ||
self.input_size = input_size | ||
self.eps = eps | ||
self.channelwise_affine = channelwise_affine | ||
self.momentum = momentum | ||
self.warmup_period = max(1, warmup_period) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the max
? Perhaps it would be better to just error out on values that are too small?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warmup_period=0 seemed natural for off - Changed to just check and error out
self.inference = inference | ||
self.axis_name = axis_name | ||
self.input_size = input_size | ||
self.eps = eps | ||
self.channelwise_affine = channelwise_affine | ||
self.momentum = momentum | ||
self.warmup_period = max(1, warmup_period) | ||
|
||
@jax.named_scope("eqx.nn.BatchNorm") | ||
def __call__( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not completely obvious to me that the ema
implementation, with default arguments, reproduces the previous behaviour. (For example, we have warmup_period=1000
by default?)
Can you add some comments explaining what each approach
corresponds to?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ema with warmup_period=1 approximately reproduces previous behavior. As I noted the start is different because of how the running statistics are initially populated. With warmup_period=1 there's no interpolation between the batch and running stats - the running stats are always used as with the previous behavior. I can give an exact replication with an extra approach if necessary.
Added some to the documentation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think an exact replication is probably important for the default behaviour, just because I'd like to be sure that we're bit-for-bit backward compatible.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, it was different enough that I added it as "ema_compatibility"
. I changed the warning to rather strongly recommend against using "ema_compatibility"
. I haven't found a use case where I wouldn't expect to see the instability (at least with a larger learning rate) but that could very much be due to a lack of imagination on my part.. That part can definitely change if needed
equinox/nn/_batch_norm.py
Outdated
state = state.set(self.first_time_index, jnp.array(False)) | ||
momentum = self.momentum | ||
zero_frac = state.get(self.zero_frac_index) | ||
zero_frac *= momentum |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stylistic nit: I tend not to use the inplace operations in JAX code. This (a) fits with the functional style a bit better, and (b) emphasises that we're definitely falling back to the zero_frac = zero_frac * momentum
interpretation of the syntax. (Gosh, why does Python has two different meanings for the same syntax?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
makes sense, done
equinox/nn/_batch_norm.py
Outdated
|
||
batch_mean, batch_var = jax.vmap(_stats)(x) | ||
running_mean, running_var = state.get(self.state_index) | ||
momentum = self.momentum | ||
running_mean = (1 - momentum) * batch_mean + momentum * running_mean | ||
running_var = (1 - momentum) * batch_var + momentum * running_var |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These don't appear to be used on the batch
branch. I think the lines here can be reorganised to keep each approach only using the things it needs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are used by the batch branch when we're in inference mode so they still need to be computed and stored
equinox/nn/_batch_norm.py
Outdated
warmup_count = state.get(self.count_index) | ||
warmup_count = jnp.minimum(warmup_count + 1, self.warmup_period) | ||
state = state.set(self.count_index, warmup_count) | ||
|
||
warmup_frac = warmup_count / self.warmup_period | ||
norm_mean = zero_frac * batch_mean + running_mean | ||
norm_mean = (1.0 - warmup_frac) * batch_mean + warmup_frac * norm_mean | ||
norm_var = zero_frac * batch_var + running_var | ||
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm definitely going to have to sit down and grok what's going on here more carefully! As above it would be good to have some comments / docstrings / references / etc. describing what each approach is meant to do.
(C.f. something like the MultiheadAttention
docstring for an example on how to use LaTeX if it'd be helpful.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added some commentary and tried making it a bit cleaner.
But overall batch mode should follow the cited paper. Ema follows the prior behavior but changes the initialization of the running stats and adds interpolation so it can be stable while training.
debias_coef = (axis_size) / jnp.maximum(axis_size - 1, self.eps) | ||
running_var = ( | ||
1 - momentum | ||
) * debias_coef * batch_var + momentum * running_var |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I neglected to use unbiased variance so corrected that here
@@ -202,8 +259,15 @@ def _stats(y): | |||
norm_var = zero_frac * batch_var + running_var | |||
norm_var = (1.0 - warmup_frac) * batch_var + warmup_frac * norm_var | |||
else: | |||
axis_size = jax.lax.psum(jnp.array(1.0), self.axis_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm using this to get the length of the "batch" axis - but not sure it's the best / correct way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this is the correct way! IIRC psum(1)
is actually special-cased for this purpose.
equinox/nn/_batch_norm.py
Outdated
- `approach`: The approach to use for the running statistics. If `approach=None` | ||
a warning will be raised and approach will default to `"batch"`. During | ||
training `"batch"` only uses batch statisics while`"ema"` uses the running | ||
statistics. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So continuing from my previous comment -- probably the default should be ema
if approach=None
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay! Sorry for taking so long to getting back around to reviewing this.
Lmk once you're happy that the previous behaviour is replicated by default, and I'll sit down with a pen and paper and satisfy myself that the calculations all look reasonable!
All good - I got caught up in other things myself! From my tests the replication is exact now. It added another approach that is very similar to |
2905390
to
a386f27
Compare
This is in reference to issue 659.
I modified BatchNorm to have two approaches
"batch"
and"ema"
."batch"
just uses the batch statistics during training time. If approach is not specified it defaults to"batch"
with a warning. It's robust and seems to be the standard choice - it's far less likely to kill a model just by adding it."ema"
is based of the smooth start method in the above issue. So keep a running mean and variance but instead of renormalizing Adam style the parts of the running averages that are zeroed are filled with the batch statistics. The problem is it's still not robust - the momentum parameter is simultaneously specifying a warmup period (when we're expecting the input distribution to change significantly) and how long we want the running average to be. So I added a linear warmup period.Now for any choice of momentum there seems to be a
warmup_period
choice that will give good results. And validation performance was at least as good as with batch mode for my tests. However, I don't see a good default forwarmup_period
.Some considerations:
approach="batch"
and the commonaxis_name="batch"
is a little awkwardapproach="batch"
if desiredLet me know what you think or if any changes or tests need to be added