How to evaluate derivatives of diffrax-solved, equilibrated functions? #432
Comments
What you're doing looks reasonable to me. The error you're getting is coming from using forward-mode autodiff (e.g. `jax.jacfwd`, or `jax.hessian`, which uses it internally) on the solve. Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff. By default, Diffrax provides custom reverse-mode autodiff for `diffeqsolve`.
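For reference, this is easy to reproduce outside of Diffrax with a toy `jax.custom_vjp` function (a minimal sketch, nothing Diffrax-specific):

```python
import jax
import jax.numpy as jnp


# A function with *only* a custom reverse-mode (VJP) rule, analogous to how
# diffeqsolve's default adjoint is set up.
@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, g):
    return (cos_x * g,)


f.defvjp(f_fwd, f_bwd)

jax.grad(f)(1.0)      # reverse mode: fine
# jax.jacfwd(f)(1.0)  # forward mode: TypeError: can't apply forward-mode
#                     # autodiff (jvp) to a custom_vjp function.
# jax.hessian(f)(1.0) # hessian = jacfwd(jacrev(...)), so it fails too.
```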
Why is this? This doesn't seem fundamentally impossible.
The only way to do this is to define a custom primitive, but that's not exactly easy! In particular for our use case here, where our custom autodiff is a loop over steps.
Hi!
This is a follow-up on #181. The use case is to evaluate the derivatives (e.g. gradient, Hessian) of some loss function $\mathcal{L}$ with respect to some variable $\theta$, at the equilibrium of the loss's gradient with respect to some other variable $y$, i.e. where $\partial \mathcal{L}/\partial y \approx 0$. Mathematically this would be something like
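(writing $y^*(\theta)$ for the equilibrium value of $y$)

$$\nabla_\theta\, \mathcal{L}\big(y^*(\theta), \theta\big), \qquad \text{where } y^*(\theta) \text{ satisfies } \frac{\partial \mathcal{L}}{\partial y}\big(y^*(\theta), \theta\big) \approx 0 .$$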
In code, building on your snippet from #181
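Roughly, my setup is something like this (a sketch with a toy quadratic loss standing in for the real $\mathcal{L}$; the details differ from the actual snippet in #181):

```python
import jax
import jax.numpy as jnp
import diffrax


def loss(y, theta):
    # Toy quadratic loss; its equilibrium in y is y* = theta.
    return jnp.sum((y - theta) ** 2) + jnp.sum(theta**2)


def vector_field(t, y, theta):
    # Gradient flow dy/dt = -dL/dy drives y towards a stationary point of the loss.
    return -jax.grad(loss, argnums=0)(y, theta)


def solve_for_y(theta, y0):
    # Integrate the gradient flow for a long time as a stand-in for "until equilibrium".
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=100.0,
        dt0=0.01,
        y0=y0,
        args=theta,
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
        max_steps=100_000,
    )
    return sol.ys[-1]  # y at (approximately) the equilibrium
```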
Given these, I could just solve for y and then take the gradient wrt theta, like so
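Something like this, reusing the toy `loss` and `solve_for_y` from the sketch above:

```python
def staged_grad(theta, y0):
    y_star = solve_for_y(theta, y0)
    # Differentiate only the explicit theta-dependence of the loss,
    # treating y_star as a constant.
    return jax.grad(lambda th: loss(y_star, th))(theta)


theta = jnp.array([1.0, 2.0])
y0 = jnp.zeros(2)
g = staged_grad(theta, y0)  # runs fine, but ignores dy*/dtheta
```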
However, this ignores the dependency of y on theta that arises through the integration of the gradient system. So ideally I would like to take the gradient of a loss that solves for y internally.
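In other words, something like this (again a sketch built on the toy functions above):

```python
def equilibrium_loss(theta, y0):
    y_star = solve_for_y(theta, y0)  # y* now depends on theta through the solve
    return loss(y_star, theta)


g = jax.grad(equilibrium_loss)(theta, y0)   # reverse mode: works
# jax.hessian(equilibrium_loss)(theta, y0)  # hessian uses forward mode internally
#                                           # and raises the TypeError quoted below
```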
But using this approach I get a `TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function`. Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?
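For instance, would swapping the adjoint help? If I read the docs correctly, `diffrax.DirectAdjoint` is supposed to support forward-mode autodiff, so perhaps something like the following (untested sketch; I haven't checked that this actually makes the Hessian work):

```python
def solve_for_y_direct(theta, y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        diffrax.Tsit5(),
        t0=0.0,
        t1=100.0,
        dt0=0.01,
        y0=y0,
        args=theta,
        stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-6),
        max_steps=10_000,
        adjoint=diffrax.DirectAdjoint(),  # differentiates directly through the solver steps
    )
    return sol.ys[-1]
```

Or perhaps `diffrax.ImplicitAdjoint`, which differentiates the steady state via the implicit function theorem, is closer to what I'm after?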
Thanks!