Incompatibility of least_squares and custom_vjp #50
In particular this is useful when the underlying function only supports reverse-mode autodifferentiation due to a `jax.custom_vjp`, see #50
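The limitation can be reproduced in plain JAX. The sketch below uses a toy `jax.custom_vjp` function (`sin_with_custom_vjp` is a hypothetical name standing in for a reverse-mode-only routine such as an ODE solve). Reverse-mode transformations such as `jax.grad` and `jax.jacrev` work on it, while forward-mode ones such as `jax.jvp` (and therefore `jax.jacfwd`) raise a `TypeError`.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a routine that only defines a custom
# reverse-mode (VJP) rule and no forward-mode (JVP) rule.
@jax.custom_vjp
def sin_with_custom_vjp(x):
    return jnp.sin(x)


def _fwd(x):
    # Primal output, plus the residual saved for the backward pass.
    return jnp.sin(x), jnp.cos(x)


def _bwd(cos_x, g):
    # Cotangent rule: d/dx sin(x) = cos(x).
    return (cos_x * g,)


sin_with_custom_vjp.defvjp(_fwd, _bwd)

# Reverse mode works: scalar gradient and a full Jacobian built from VJPs.
grad_at_zero = jax.grad(sin_with_custom_vjp)(0.0)
jac_rev = jax.jacrev(sin_with_custom_vjp)(jnp.array([0.0, 1.0]))

# Forward mode is rejected for custom_vjp functions.
try:
    jax.jvp(sin_with_custom_vjp, (0.0,), (1.0,))
    forward_mode_ok = True
except TypeError as err:
    forward_mode_ok = False
    print("jvp failed:", err)

print("grad at 0:", grad_at_zero)
print("forward mode supported:", forward_mode_ok)
```

Any solver that builds its Jacobian with forward mode will therefore fail on such a function, whereas a reverse-mode Jacobian is still available.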
Yup, you're completely correct in your diagnosis: Diffrax has a … We have essentially two possible fixes: offer a way for Diffrax to use forward-mode autodifferentiation, or offer a way for Optimistix to use reverse-mode. For now I've just added the latter, in #51. Try using Optimistix from that branch and see if it solves your problem! You'll need to pass … (I'd like to add better forward-mode support for Diffrax, but the best way of doing this is really dependent on JAX just adding direct support for …)
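The reverse-mode route can be illustrated without Optimistix. The residual Jacobian is assembled from VJPs with `jax.jacrev`, which works through a `jax.custom_vjp`, and is then used in a damped Gauss-Newton (Levenberg-Marquardt) update. This is a minimal sketch for illustration only, not Optimistix's implementation; the model, synthetic data, and helper names (`smooth_abs`, `residuals`, `levenberg_marquardt`) are all hypothetical.

```python
import jax
import jax.numpy as jnp


# Hypothetical reverse-mode-only model component: sqrt(x^2 + 1).
@jax.custom_vjp
def smooth_abs(x):
    return jnp.sqrt(x**2 + 1.0)


def _smooth_abs_fwd(x):
    y = jnp.sqrt(x**2 + 1.0)
    return y, (x, y)


def _smooth_abs_bwd(res, g):
    x, y = res
    return (g * x / y,)


smooth_abs.defvjp(_smooth_abs_fwd, _smooth_abs_bwd)

# Synthetic data generated from known parameters a=2.0, b=0.5.
t = jnp.linspace(0.0, 1.0, 5)
target = jnp.sqrt((2.0 * t + 0.5) ** 2 + 1.0)


def residuals(params):
    a, b = params
    return smooth_abs(a * t + b) - target


def loss(params):
    return 0.5 * jnp.sum(residuals(params) ** 2)


def levenberg_marquardt(params, num_steps=50, damping=1e-2):
    current = loss(params)
    for _ in range(num_steps):
        r = residuals(params)
        # Reverse-mode Jacobian (one VJP per residual), so custom_vjp is fine.
        J = jax.jacrev(residuals)(params)
        lhs = J.T @ J + damping * jnp.eye(params.shape[0])
        step = jnp.linalg.solve(lhs, -J.T @ r)
        candidate = params + step
        candidate_loss = loss(candidate)
        if candidate_loss < current:
            # Accept the step and move towards pure Gauss-Newton.
            params, current = candidate, candidate_loss
            damping = damping * 0.5
        else:
            # Reject the step and move towards gradient descent.
            damping = damping * 10.0
    return params, current


initial_params = jnp.array([1.0, 0.0])
initial_loss = loss(initial_params)
fitted_params, final_loss = levenberg_marquardt(initial_params)
print("initial loss:", initial_loss, "final loss:", final_loss)
print("fitted params:", fitted_params)
```

The trade-off is cost: `jax.jacrev` needs one VJP per residual, whereas `jax.jacfwd` needs one JVP per parameter, so reverse mode is typically the more expensive choice when there are many more residuals than parameters.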
Amazing, works as intended (at least for the simple example I've tried)!
I'm running into some trouble applying `optimistix.least_squares(fn, LevenbergMarquardt(...), x0)` to certain problems. From the error message below, my understanding of the root cause is that forward-mode autodiff cannot be used on functions decorated with `jax.custom_vjp`. In my case I am using `diffrax` to solve an ODE within `fn(...)`, which I think might be causing the problem.

Is my basic understanding correct? Are there specific constraints / assumptions that `fn(...)` must follow for `optimistix.least_squares` to work (e.g. it cannot use `jax.custom_vjp`)? Is there any way around this?

The error I get is:
The full code to reproduce the error is below. By the way, I get the same problem when trying to use `jaxopt.LevenbergMarquardt` on this problem.