BatchNorm initialization #955
Comments
Yeah, that's a good point. It has less effect on the […]. I could also change it for the existing approach (ema), but that would change what batch norm does: while it would likely improve performance in many cases, it would change how default batch norm behaves, which I'm not sure meets the requirement of "bit for bit backwards compatibility" (#675 (comment)). But we could just add a note saying that if you want the old init, you can run a `tree_at`-style command (sketched below).
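For reference, a rough sketch of what overriding the initial running statistics might look like. This is not a confirmed recipe: the exact command wasn't given in the thread, the attribute name `state_index` and the `(running_mean, running_var)` layout are assumptions about `eqx.nn.BatchNorm`'s internals (they may differ between Equinox versions), and it uses `State.set` rather than a literal `eqx.tree_at` call, since the running statistics live in the state object rather than the module.

```python
import equinox as eqx
import jax.numpy as jnp

# Hypothetical sketch: `state_index` and the (running_mean, running_var) tuple
# layout are assumptions about BatchNorm's internals, not a confirmed API.
bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(input_size=3, axis_name="batch")

running_mean, running_var = state.get(bn.state_index)
# Restore the old zero-initialised variance (or use jnp.ones_like for the proposed default).
state = state.set(bn.state_index, (running_mean, jnp.zeros_like(running_var)))
```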
I would argue that this is a bugfix that is (reasonably) backwards compatible, but that's just my 2 cents.
First of all, thank you for your great library!
The problem
If I read your code correctly, the variance in BatchNorm is initialized to zero, meaning that if you use BatchNorm for inference without any training (think: testing a random model's performance on a dataset, or the first iteration of an RL algorithm), BatchNorm scales all values by $\frac{1}{\sqrt{\mathrm{eps}}} = 10^4$. Both PyTorch and Flax use $\sigma^2 = 1$ in their code. I think this is a much saner default, considering that $\sigma^2 = 0$ doesn't really make sense.
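To make the failure mode concrete, here is a minimal standalone sketch in plain `jax.numpy` (not the library's actual BatchNorm code) of what inference-time normalisation does when the stored variance is zero versus one; the epsilon value is just an illustrative choice, the blow-up factor is $1/\sqrt{\mathrm{eps}}$ whatever it is.

```python
import jax.numpy as jnp

eps = 1e-5                      # illustrative epsilon; the scale below is 1/sqrt(eps)
x = jnp.array([0.1, -0.2, 0.3])
running_mean = jnp.zeros_like(x)

# Zero-initialised variance: an untrained model in inference mode scales inputs by ~1/sqrt(eps).
running_var = jnp.zeros_like(x)
print((x - running_mean) / jnp.sqrt(running_var + eps))  # ~[ 31.6, -63.2,  94.9]

# Variance initialised to one (the PyTorch/Flax default): roughly the identity before training.
running_var = jnp.ones_like(x)
print((x - running_mean) / jnp.sqrt(running_var + eps))  # ~[ 0.1, -0.2,  0.3]
```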
Tl;dr: initialize the running variance in BatchNorm to $\sigma^2 = 1$ (as PyTorch and Flax do) instead of $\sigma^2 = 0$.