How to evaluate derivatives of diffrax-solved, equilibrated functions? #432

Open
francesco-innocenti opened this issue Jun 3, 2024 · 4 comments

@francesco-innocenti

Hi!

This is a follow-up on #181. The use case is to evaluate derivatives (e.g. the gradient or Hessian) of some loss function $\mathcal{L}$ with respect to some variable $\theta$, at the equilibrium of that loss with respect to some other variable $y$, i.e. where $\partial \mathcal{L}/\partial y \approx 0$. Mathematically this would be something like

$\LARGE{\frac{\partial \mathcal{L}(y; \theta)}{\partial \theta}|_{\frac{\partial \mathcal{L}}{\partial y}\approx 0}}$

In code, building on your snippet from #181:

import jax
import diffrax

def L(y, theta):  # some loss
    ...

def dLdy(t, y, args):  # vector field for the gradient system; args carries theta
    return -jax.grad(L)(y, args)

def solve_y(y0, theta):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(dLdy),
        y0=y0,
        args=theta,
        ...
    )
    return sol.ys

def dLdtheta(y, theta):  # gradient with respect to theta (argument 1)
    return jax.grad(L, argnums=1)(y, theta)

Given these, I could just solve for y and then take the gradient wrt theta, like so:

y_sol = solve_y(y0, theta)
theta_grad = dLdtheta(y_sol, theta)

However, this ignores the dependence of y on theta that arises through the integration of the gradient system. So ideally I would like to take the gradient of a loss that itself solves for y:

def equilibrated_L(y0, theta):  # equilibrated loss
    y_sol = solve_y(y0, theta)
    ...
    return L(y_sol, theta)

def dLdtheta(y0, theta):  # gradient through the solve, with respect to theta
    return jax.grad(equilibrated_L, argnums=1)(y0, theta)

theta_grad = dLdtheta(y0, theta)

But with this approach I get the error: TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
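For reference, the pattern of differentiating through the solve can be sketched without Diffrax. This is only a minimal illustration under stated assumptions: the quadratic loss, the fixed-step explicit Euler loop standing in for `diffrax.diffeqsolve`, and the `step` and `num_steps` parameters are all hypothetical choices, not part of the original snippet.

```python
import jax
import jax.numpy as jnp

def L(y, theta):
    # Toy loss (assumption for illustration): minimised over y at y = theta.
    return 0.5 * jnp.sum((y - theta) ** 2) + 0.5 * jnp.sum(theta ** 2)

def solve_y(y0, theta, step=0.1, num_steps=200):
    # Explicit Euler on dy/dt = -dL/dy; a stand-in for diffrax.diffeqsolve.
    def euler_step(y, _):
        return y - step * jax.grad(L)(y, theta), None
    y_final, _ = jax.lax.scan(euler_step, y0, None, length=num_steps)
    return y_final

def equilibrated_L(y0, theta):
    # Loss evaluated at the (approximate) equilibrium reached from y0.
    return L(solve_y(y0, theta), theta)

y0 = jnp.zeros(2)
theta = jnp.array([1.0, -2.0])

# Reverse-mode gradient that backpropagates through every Euler step.
grad_through_solve = jax.grad(equilibrated_L, argnums=1)(y0, theta)

# Gradient with y held fixed at the solved value (the simpler approach above).
y_star = solve_y(y0, theta)
grad_at_fixed_y = jax.grad(L, argnums=1)(y_star, theta)
```

In this toy problem both gradients come out close to `theta`, because the $\partial \mathcal{L}/\partial y$ term in the chain rule vanishes at equilibrium. Second derivatives do not share this property, so a Hessian generally does need to be taken through the solve.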

Hope all of that makes sense. Maybe I am missing something. For example, I wonder whether this could be a use case for an adjoint method?

Thanks!

@patrick-kidger
Owner

What you're doing looks reasonable to me.

The error you're getting is coming from using either jax.jvp or jax.jacfwd.

Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff. By default, Diffrax provides custom reverse-mode autodiff for diffeqsolve. You might like to try the solution of patrick-kidger/optimistix#51 (comment), which provides an alternate adjoint method that supports forward mode only instead.
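The same kind of error can be reproduced in plain JAX, independently of Diffrax. The sketch below uses a hypothetical `square` function with a hand-written reverse-mode rule: `jax.grad` goes through the custom rule, while `jax.jvp` is rejected. The exception is caught broadly because the exact exception type is assumed here from the message quoted in this thread rather than guaranteed across JAX versions.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def square(x):
    return x ** 2

def square_fwd(x):
    # Forward pass: return the output plus residuals needed by the backward pass.
    return x ** 2, x

def square_bwd(x, cotangent):
    # Backward pass: one cotangent per input of `square`.
    return (2.0 * x * cotangent,)

square.defvjp(square_fwd, square_bwd)

x = jnp.float32(3.0)

# Reverse mode uses the custom rule and works.
reverse_grad = jax.grad(square)(x)

# Forward mode has no rule to use, so JAX refuses.
try:
    jax.jvp(square, (x,), (jnp.ones_like(x),))
    forward_mode_error = None
except Exception as err:  # a TypeError in the versions discussed here
    forward_mode_error = err
```

Wrapping the same function in `jax.custom_jvp` instead would flip the situation: the forward-mode rule would be the one autodiff sees, and reverse mode would be derived from it.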

@lockwo
Contributor

lockwo commented Jun 4, 2024

> Unfortunately there's no good way in JAX to create something that has both custom forward-mode and custom reverse-mode autodiff.

Why is this? This doesn't seem fundamentally impossible.

@patrick-kidger
Owner

jax.custom_jvp and jax.custom_vjp cannot both be applied: one of them has to be on the outside, and this is the one that autodiff sees.

The only way to do this is to define a custom primitive, but that's not exactly easy! In particular for our use case here, where our custom autodiff is a loop over steps (eqx.internal.while_loop(..., kind="checkpointed")).

@johannahaffner
Contributor

jax.hessian is implemented as jacfwd(jacrev(function)), so it applies forward mode on top of reverse mode.

jacrev(jacrev(function)) and jacfwd(jacfwd(function)) both work, so you can pick whichever matches the adjoint of your choice.
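As a quick sanity check, the different compositions give the same Hessian on an ordinary function. The function `f` and the input `x` below are arbitrary choices for illustration. The compositions differ only in which autodiff modes they demand from the function being differentiated.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary smooth scalar function (assumption for illustration).
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.5, -1.0, 2.0])

h_default = jax.hessian(f)(x)              # forward-over-reverse
h_fwd_rev = jax.jacfwd(jax.jacrev(f))(x)   # the same composition, spelled out
h_rev_rev = jax.jacrev(jax.jacrev(f))(x)   # reverse mode only
h_fwd_fwd = jax.jacfwd(jax.jacfwd(f))(x)   # forward mode only
```

For a solve that only supports reverse mode, `jacrev(jacrev(...))` is the candidate to try. For a forward-mode-only adjoint, it is `jacfwd(jacfwd(...))`.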
