
BatchNorm initialization #955

Open
ZagButNoZig opened this issue Feb 22, 2025 · 3 comments

@ZagButNoZig

First of all, thank you for your great library!

The problem

  • Initialize a model with multiple batch norms (e.g. ResNet)
  • Set model to inference mode
  • Observe numerical instability (values > 10^9 after a few ResNet-Blocks)

If I read your code correctly, the variance in BatchNorm is initialized to zero, meaning that if you use BatchNorm at inference time without any training (think: testing a random model's performance on a dataset, or the first iteration of an RL algorithm), BatchNorm scales all values by $\frac{1}{\sqrt{\text{eps}}} = 10^4$. Both PyTorch and Flax use $\sigma^2 = 1$ in their code. I think this is a much saner default, considering that $\sigma^2 = 0$ doesn't really make sense.

Tl;dr:

  • BatchNorm causes instability due to the initial value of the variance
  • Both Flax and PyTorch use variance = 1 as their initial value
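
For reproduction, something like the following sketch shows the blow-up. It assumes the current stateful `eqx.nn.BatchNorm` API (`eqx.nn.make_with_state`, `eqx.nn.inference_mode`), and the model is an illustrative Linear+BatchNorm stack rather than a real ResNet:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr


class Net(eqx.Module):
    layers: list

    def __init__(self, width, depth, key):
        keys = jr.split(key, depth)
        self.layers = []
        for k in keys:
            self.layers.append(eqx.nn.Linear(width, width, key=k))
            self.layers.append(eqx.nn.BatchNorm(input_size=width, axis_name="batch"))

    def __call__(self, x, state):
        for layer in self.layers:
            if isinstance(layer, eqx.nn.BatchNorm):
                x, state = layer(x, state)
            else:
                x = layer(x)
        return x, state


model, state = eqx.nn.make_with_state(Net)(width=64, depth=8, key=jr.PRNGKey(0))
model = eqx.nn.inference_mode(model)  # eval mode, no training step has happened

x = jr.normal(jr.PRNGKey(1), (32, 64))  # a batch of random inputs
out, _ = jax.vmap(
    model, axis_name="batch", in_axes=(0, None), out_axes=(0, None)
)(x, state)
print(jnp.abs(out).max())  # huge values: the running variance starts at zero
```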
@patrick-kidger
Owner

I think that's a good spot! Tagging @lockwo and #948 here -- we're thinking of reworking BatchNorm. @lockwo WDYT?

@lockwo
Contributor

lockwo commented Feb 24, 2025

Yeah, that's a good point. It has less effect on the batch mode of the PR (since the weighting is based entirely on the current batch), but it will still impact evaluation. I've changed the initialization to ones for the std in that case. I believe haiku (which is what I based the implementation on) does it slightly differently, which is why I didn't have it originally (https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/batch_norm.py#L125, https://github.com/google-deepmind/dm-haiku/blob/main/haiku/_src/moving_averages.py#L91).
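
For context, here is a generic exponential-moving-average update for the running variance with a zero initial value (a sketch of plain EMA behaviour in general, not of haiku's exact scheme): with momentum close to one, the running estimate stays far below the true variance for the first several batches, which is what makes the zero init problematic at evaluation time.

```python
# Generic EMA update for the running variance; values are illustrative.
def ema_update(running_var, batch_var, momentum=0.99):
    return momentum * running_var + (1 - momentum) * batch_var

running_var = 0.0  # zero init
for _ in range(10):
    running_var = ema_update(running_var, batch_var=1.0)
print(running_var)  # ~0.096, still far below the true variance of 1.0
```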

I can also change it for the existing approach (EMA), but that would change what batch norm does. While it would likely improve performance in many cases, it would change the default behaviour, which I'm not sure meets the requirement of "bit for bit backwards compatibility" (#675 (comment)). We could just add a note saying that if you want the old init, you can run this tree_at command.
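
Purely for illustration, that kind of tree_at override might look roughly like the sketch below. Only `eqx.tree_at` itself is a real Equinox function here; the field name `state_index`, its assumed `.init` attribute, and the `(mean, var)` tuple layout are guesses about the implementation, not the confirmed API.

```python
import jax.numpy as jnp
import equinox as eqx

def restore_zero_init(layer: eqx.nn.BatchNorm) -> eqx.nn.BatchNorm:
    # Hypothetical: assume the layer stores its initial running (mean, var)
    # as the default value of a StateIndex field called `state_index`.
    zeros = jnp.zeros(layer.input_size)
    return eqx.tree_at(
        lambda l: l.state_index.init,  # assumed location of the initial stats
        layer,
        (zeros, zeros),
    )
```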

@ZagButNoZig
Author

ZagButNoZig commented Feb 25, 2025

> I can also change it for the existing approach (EMA), but that would change what batch norm does. While it would likely improve performance in many cases, it would change the default behaviour, which I'm not sure meets the requirement of "bit for bit backwards compatibility" (#675 (comment)). We could just add a note saying that if you want the old init, you can run this tree_at command.

I would argue that this is a bugfix that is (reasonably) backwards compatible, but that's just my 2 cents.
