Zero implicit gradients when using ImplicitAdjoint with CG solver #54
Hey there! Thanks for the issue. Would you be able to condense your code down to a single MWE? (Preferably around 20 lines of code.) For example we probably don't need the details of your training loop, the fact that it's batched, etc. Moreover I'm afraid this code won't run -- if nothing else, it currently doesn't have any import statements.
Hi @patrick-kidger,

I updated the original post with a condensed MWE. In the above code, when I use `optx.ImplicitAdjoint` the meta-gradients come out as zero, whereas the equivalent implicit differentiation in JAXopt gives non-zero gradients. Because of this, I started wondering if there is a fundamental difference in how implicit adjoints are calculated in the two packages. My instinct is that the mismatch might have to do with the handling of higher-order terms, but I am curious to hear your opinion and whether it is something that can be quickly patched.
So in a minimisation problem, the solution stays constant as you perturb the initial parameters. Regardless of where you start, you should expect to converge to the same solution! So in fact a zero gradient is what is expected. (Imagine finding `argmin_x x^2`: no matter where you start, the answer is always `x = 0`.)

The fact that you get a nonzero gradient via `optx.RecursiveCheckpointAdjoint` reflects the fact that you only run finitely many solver steps, so the not-yet-fully-converged iterate still depends slightly on the starting point.

The fact that JAXopt appears to do otherwise is possibly a bug in JAXopt. (?)

That aside, some comments on your implementation:
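Spelled out in symbols (a sketch of the standard implicit-function-theorem argument, with $f$ the inner loss and $\theta$ the meta-parameters, which in the MWE enter only as the initial iterate):

$$
y^*(\theta) = \operatorname{argmin}_y f(y), \qquad \nabla_y f\bigl(y^*\bigr) = 0 .
$$

Since $\theta$ does not appear in the optimality condition, implicit differentiation gives

$$
\nabla_y^2 f\bigl(y^*\bigr)\,\frac{\partial y^*}{\partial \theta} = 0
\quad\Longrightarrow\quad
\frac{\partial y^*}{\partial \theta} = 0
$$

whenever the Hessian is invertible, so an implicit adjoint returning exactly zero is the mathematically correct answer. A solve truncated after finitely many steps, by contrast, still depends on its starting point, which is why differentiating through the unrolled iterations produces a small non-zero gradient.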
Hi Patrick,

Thank you for your thorough response. It is true that the solution to a minimisation problem is independent of the initial parameters and should lead to zero gradients. As you noted, the gradients from my condensed example should indeed be zero; the non-zero gradients in the JAXopt example come from details that I left out of my condensed version. So, there is no bug in JAXopt, and it only appeared so because of my incomplete explanation.

The iMAML paper that I am following uses CG because it avoids forming a Hessian matrix. It also seems that CG-like iterative solvers are used quite extensively within JAXopt - as far as I understand this again has to do with their matrix-free nature. Optimistix, on the other hand, seems to use direct solvers by default.
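To illustrate the matrix-free point (a sketch, not code from either package; the loss and variable names below are placeholders): CG only ever needs Hessian-vector products, which JAX provides without materialising the Hessian.

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Stand-in inner loss; any smooth scalar function of the parameters works.
    return jnp.sum((w - 1.0) ** 2) + 0.1 * jnp.sum(w**4)

def hvp(w, v):
    # Hessian-vector product via forward-over-reverse: no Hessian is ever formed.
    return jax.jvp(jax.grad(loss), (w,), (v,))[1]

w_star = jnp.ones(5)        # pretend this is the converged inner solution
b = jax.grad(loss)(w_star)  # right-hand side of the implicit linear system
# Solve H x = b with (matrix-free) conjugate gradients, capped at 20 iterations.
x, _ = jax.scipy.sparse.linalg.cg(lambda v: hvp(w_star, v), b, maxiter=20)
print(x)
```

The iMAML meta-gradient involves exactly this kind of linear solve against the Hessian of the inner loss, which is why the paper reaches for CG rather than a dense factorisation.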
Hi @patrick-kidger and @packquickly,
I was trying to implement the following meta-learning example from JAXopt in Optimistix: Few-shot Adaptation with Model Agnostic Meta-Learning. However, I ran into an issue with implicit differentiation through the inner loop. The example below runs well when using `optx.RecursiveCheckpointAdjoint`, but when I try to recreate the iMAML setup by using `optx.ImplicitAdjoint` with a CG solver with 20 steps, all the meta-gradients are zero and the meta-optimiser doesn't change at all during training. Could you please help me identify the issue with the code? It seems to come down to an implementation detail of implicit adjoints that differs between JAXopt and Optimistix.

Here is an MWE:
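(A minimal sketch in the spirit of such an MWE rather than the exact original code: the toy least-squares inner problem, solver choices, and tolerances below are placeholders.)

```python
import jax
import jax.numpy as jnp
import lineax as lx
import optimistix as optx

def inner_loss(params, args):
    # Toy task loss; note the meta-parameters only enter the solve below as the
    # initial iterate, not inside this loss.
    x, y = args
    return jnp.mean((x @ params - y) ** 2)

def outer_loss(meta_params, x, y, adjoint):
    solver = optx.BFGS(rtol=1e-6, atol=1e-6)
    sol = optx.minimise(
        inner_loss,
        solver,
        meta_params,  # meta-parameters used as the starting point of the inner solve
        args=(x, y),
        adjoint=adjoint,
        max_steps=200,
    )
    return jnp.sum(sol.value**2)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (16, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
meta_params = jnp.zeros(3)

unrolled = optx.RecursiveCheckpointAdjoint()
implicit = optx.ImplicitAdjoint(
    linear_solver=lx.CG(rtol=1e-6, atol=1e-6, max_steps=20),
    tags=lx.positive_semidefinite_tag,  # tag the operator so CG can be applied to it
)

print(jax.grad(outer_loss)(meta_params, x, y, unrolled))  # typically small but non-zero
print(jax.grad(outer_loss)(meta_params, x, y, implicit))  # exactly zero
```

The second gradient coming out as exactly zero reproduces the behaviour described above: in this toy setup the fixed point genuinely does not depend on the initial iterate.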