
AbstractGaussNewton now supports reverse-autodiff for Jacobians. #51

Merged 1 commit into main from gauss-newton-jacrev on May 1, 2024

Conversation

patrick-kidger
Owner

In particular this is useful when the underlying function only supports reverse-mode autodifferentiation due to a `jax.custom_vjp`; see #50.
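For illustration (a toy example, not taken from the issue itself), a function like the one below only defines a reverse-mode rule, so forward-mode autodiff through it fails while reverse-mode works:

import jax

@jax.custom_vjp
def f(x):
    return x ** 2

def f_fwd(x):
    # Forward pass: return the primal output and residuals for the backward pass.
    return f(x), x

def f_bwd(x, ct):
    # Reverse-mode rule: d(x**2)/dx = 2 * x.
    return (2.0 * x * ct,)

f.defvjp(f_fwd, f_bwd)

jax.jacrev(f)(3.0)    # fine
# jax.jacfwd(f)(3.0)  # TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.
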
@johannahaffner
Contributor

johannahaffner commented Apr 29, 2024

Hi Patrick,

I tried this branch, and it did not work for my use case (parameter estimation for an ODE). I tried both LevenbergMarquardt and GaussNewton in combination with DirectAdjoint or RecursiveCheckpointAdjoint, and I get

TypeError: can't apply forward-mode autodiff (jvp) to a custom_vjp function.

with both. I have an MWE if that would be useful for you.

During the install, this is the commit it resolved to:

Switched to a new branch 'gauss-newton-jacrev'
branch 'gauss-newton-jacrev' set up to track 'origin/gauss-newton-jacrev'.
Resolved https://github.com/patrick-kidger/optimistix to commit 776820485fd9df320d3089bcd302f2f69124cf14

Hope you had a nice weekend!

Johanna

@johannahaffner
Contributor

johannahaffner commented Apr 30, 2024

Never mind, it works! I checked out #50 and realized that I need to pass options=dict(jac="bwd") to least_squares.
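For reference, passing the option looks roughly like this (the residual function and data here are placeholders for my actual setup):

import jax.numpy as jnp
import optimistix as optx

def residuals(params, args):
    # Placeholder residuals standing in for my ODE fit.
    data = args
    return data - params * jnp.arange(5.0)

solver = optx.LevenbergMarquardt(rtol=1e-03, atol=1e-06)
sol = optx.least_squares(
    residuals,
    solver,
    jnp.array(1.0),              # initial parameter guess
    args=2.0 * jnp.arange(5.0),  # "data"
    options=dict(jac="bwd"),     # build the Jacobian with reverse-mode autodiff
)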

patrick-kidger merged commit a1743c1 into main on May 1, 2024
2 checks passed
@patrick-kidger
Owner Author

Awesome, I'm glad to hear it! I've just merged this in, so this will appear in the next release of Optimistix. :)

@johannahaffner
Contributor

johannahaffner commented May 10, 2024

Hi! Not sure where else to post this, but I wanted to note it somewhere: on my problem, I see a dramatic slowdown when using reverse-mode autodiff inside least_squares.

Specifically:

  • Using dfx.DirectAdjoint, I get 0.286 s per trajectory if passing options=dict(jac="fwd") and
  • 29.8 s per trajectory if passing options=dict(jac="bwd"), which drops to
  • 8.11 s per trajectory if using dfx.RecursiveCheckpointAdjoint instead of DirectAdjoint.

The documentation does note that forward-mode is usually more efficient, so this might just be further confirmation of that observation :)

The solver I used is LevenbergMarquardt(atol=1e-06, rtol=1e-03).

@patrick-kidger
Owner Author

Hey there!

Ah, indeed this might have been the case. It was hard to guess which was the greater overhead: DirectAdjoint + forward mode, or RecursiveCheckpointAdjoint + reverse mode.

FWIW this is motivating me to consider adding a forward-mode specific "adjoint". (I'd been holding off on this in the hopes that support could be added to RecursiveCheckpointAdjoint but that's something which requires changes in JAX itself.) Let me see if I can throw something together and we can see how it runs.

@patrick-kidger
Owner Author

patrick-kidger commented May 11, 2024

Okay, completely untested, but something like the following should probably work:

class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax")
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state

by passing diffeqsolve(..., adjoint=ForwardMode()). Then the resulting diffeqsolve should be forward-mode autodifferentiable.
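For example (again untested; this assumes the imports and the ForwardMode definition above, with a toy exponential-decay ODE standing in for yours):

term = diffrax.ODETerm(lambda t, y, args: -y)
sol = diffrax.diffeqsolve(
    term,
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=1.0,
    adjoint=ForwardMode(),  # the adjoint defined above
)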

Let me know how this goes!

@johannahaffner
Contributor

johannahaffner commented May 14, 2024

Awesome, thank you! I'm happy to give it a proper go tomorrow. So far I am getting this error

.../site-packages/diffrax/_integrate.py:458, in loop()
    456     return new_state
--> 458 final_state = outer_while_loop(
    459     cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
    460 )
    462 def _save_t1(subsaveat, save_state):

.../contextlib.py:81, in inner()
     80 with self._recreate_cm():
---> 81     return func(*args, **kwds)

.../site-packages/equinox/internal/_loop/loop.py:103, in while_loop()
    102 del cond_fun, body_fun, init_val
--> 103 _, _, _, final_val = lax.while_loop(cond_fun_, body_fun_, init_val_)
    104 return final_val

JaxStackTraceBeforeTransformation: ValueError: Reverse-mode differentiation does not work for
lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using
fori_loop with static start/stop.

I'm a little unsure where the reverse-mode differentiation is coming from here (this is with options=dict(jac="fwd") in least_squares). But also happy to dig into it a little :)

BTW, I am assuming that _is_none = lambda x: x is None, as in the example given here.

@johannahaffner
Contributor

Ok, mini update: I haven't been able to dig into it any further, but I did get better benchmarks on my problem.
I use jit(vmap(...)) to parallelize over the individuals, so the runtime is determined by the trajectories that take longest to fit.
The sum of the maximum number of steps taken in each iteration works out to be quite close to the total runtime divided by the time it takes to simulate the data (376 vs. 380).
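(For context, the parallelisation pattern is roughly the following; fit_one is just a stand-in for my actual per-individual estimator.)

import jax
import jax.numpy as jnp

def fit_one(k0, data):
    # Stand-in for fitting one individual's trajectory from the initial guess k0.
    return k0 + jnp.mean(data)

# One row per individual; all individuals are fitted in parallel inside a single jit.
fit_all = jax.jit(jax.vmap(fit_one))
results = fit_all(jnp.ones(4), jnp.zeros((4, 50)))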

It looks like forward-mode differentiation on DirectAdjoint is already super fast, at least as far as optimistix is concerned. So unless the ForwardModeAdjoint also speeds up the ODE solving, I would not expect too many performance gains to come from this.

@patrick-kidger
Owner Author

So unless the ForwardModeAdjoint also speeds up the ODE solving

Indeed it does! (Granted, possibly not by that much. Depends on your problem.)

Do you have a quick MWE demonstrating the above crash?

@johannahaffner
Contributor

Ah, even faster ODEs! Nice, that would be cool.
I tried this on my real data; I will make an MWE after the weekend!

@johannahaffner
Contributor

johannahaffner commented May 22, 2024

Here comes the MWE! The adjoint works fine in diffrax. But optimistix attempts to reverse-mode differentiate through it.

import jax.numpy as jnp

import equinox as eqx
import diffrax
import optimistix

import functools

_is_none = lambda x: x is None  # This function is needed in the forward adjoint

class ForwardMode(diffrax.AbstractAdjoint):
    def loop(
        self,
        *,
        solver,
        throw,
        passed_solver_state,
        passed_controller_state,
        **kwargs,
    ):
        del throw, passed_solver_state, passed_controller_state
        inner_while_loop = functools.partial(diffrax._adjoint._inner_loop, kind="lax")
        outer_while_loop = functools.partial(diffrax._adjoint._outer_loop, kind="lax") 
        # Support forward-mode autodiff.
        # TODO: remove this hack once we can JVP through custom_vjps.
        if isinstance(solver, diffrax.AbstractRungeKutta) and solver.scan_kind is None:
            solver = eqx.tree_at(
                lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none
            )
        final_state = self._loop(
            solver=solver,
            inner_while_loop=inner_while_loop,
            outer_while_loop=outer_while_loop,
            **kwargs,
        )
        return final_state

class ToyModel(eqx.Module):
    """Toy model that provides a simple interface to generate an ODE solution,
    subject to its parameters.
    """
    _term: diffrax.ODETerm
    _t0: float
    _t1: float
    _dt0: float
    _y0: float
    _saveat: diffrax.SaveAt
    _solver: diffrax.AbstractERK
    _adjoint: diffrax.AbstractAdjoint

    def __init__(self, ode_model, initial_state, times, solver, adjoint):
        self._term = diffrax.ODETerm(ode_model)
        self._y0 = initial_state
        
        self._t0 = times[0]
        self._t1 = times[-1]
        self._dt0 = 0.01
        self._saveat = diffrax.SaveAt(ts=times)
        
        self._solver = solver
        self._adjoint = adjoint

    def __call__(self, param):
        sol = diffrax.diffeqsolve(
            self._term, 
            self._solver, 
            self._t0, self._t1, self._dt0, self._y0, 
            args=param, 
            saveat=self._saveat, 
            adjoint=self._adjoint,
        )
        return sol.ys

def estimate_parameters(initial_guess, model, data, solver, solver_options: dict = dict(jac='fwd')):
    """Function that estimates the parameters."""

    def residuals(param, args):
        model, data = args
        fit = model(param)
        res = data - fit
        return res

    sol = optimistix.least_squares(
        residuals,
        solver,
        initial_guess,
        args=(model, data),
        options=solver_options,
    )
    return sol

# Create the model
def dydt(t, y, k):  # Toy ODE
    return - k * y
t = jnp.linspace(0, 10, 50)  
y0 = 10.
model = ToyModel(dydt, y0, t, diffrax.Tsit5(), ForwardMode())

# Solve ODE
k = 0.5  # True value
ode_solution = model(k)  # This runs without issue

# Now try solving for the parameters
k0 = 0.1
solver = optimistix.LevenbergMarquardt(atol=1e-09, rtol=1e-06)
lm_solution = estimate_parameters(k0, model, ode_solution, solver)  # This fails
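As a sanity check (not strictly part of the MWE): forward-mode differentiating the diffrax solve on its own works with this adjoint, so the reverse-mode path really does come from inside optimistix.

import jax  # assumed alongside the imports above

dys_dk = jax.jacfwd(model)(k)  # runs without issue with adjoint=ForwardMode()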

patrick-kidger deleted the gauss-newton-jacrev branch on May 25, 2024, 22:34
patrick-kidger added a commit that referenced this pull request May 26, 2024
This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment).

Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!
@patrick-kidger
Owner Author

Can you try #61? Hopefully that should fix things :)

(It's also revealed an unrelated bug with complex numbers, but I don't think that should affect you.)

@johannahaffner
Contributor

Hi Patrick,

Amazing, thank you for the quick fix!
I tried it on my real data; it works like a charm and delivers a handsome 3x speedup.

@johannahaffner
Contributor

johannahaffner commented May 26, 2024

delivers a handsome 3x speedup.

Whoops, my bad, I forgot to control for whether my laptop was plugged in :)

Using DirectAdjoint is 52% slower than using ForwardMode for the parameter estimation, and solving just the ODE is 42% slower with DirectAdjoint.

Still pretty nice! I'll keep using the ForwardMode above.

patrick-kidger added a commit that referenced this pull request Aug 17, 2024
This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment).

Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!