Reversible Solvers #528

Open
wants to merge 4 commits into
base: main

Conversation

sammccallum

Hey Patrick,

Here's an implementation of Reversible Solvers! This includes:

  1. make any AbstractRungeKutta method in diffrax algebraically reversible - see diffrax.Reversible
  2. backpropagate through the solve in constant memory and get exact gradients (up to floating point errors) - see diffrax.ReversibleAdjoint
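A minimal sketch of the idea in plain Python, for readers who want to see why the forward pass can be undone exactly. The `y1` update mirrors the line quoted later in the diff; the `z1` update is not shown in this excerpt and is an assumption, as are the helper names `euler_step`, `reversible_forward` and `reversible_backward`. Explicit Euler stands in for any one-step base solver.

```python
def euler_step(f, t0, t1, y):
    """One explicit Euler step from t0 to t1 (also valid when t1 < t0)."""
    return y + (t1 - t0) * f(t0, y)


def reversible_forward(step, f, t0, t1, y0, z0, l=0.999):
    """Advance the coupled pair (y, z) from t0 to t1."""
    step_z0 = step(f, t0, t1, z0)      # base solver applied to z, forward in time
    y1 = l * (y0 - z0) + step_z0       # same form as the `y1 = ...` line in the diff
    step_y1 = step(f, t1, t0, y1)      # base solver applied to y1, backward in time
    z1 = y1 + z0 - step_y1             # assumed z update (not in the excerpt)
    return y1, z1


def reversible_backward(step, f, t0, t1, y1, z1, l=0.999):
    """Reconstruct (y0, z0) from (y1, z1) by inverting the two updates."""
    step_y1 = step(f, t1, t0, y1)
    z0 = z1 - y1 + step_y1
    step_z0 = step(f, t0, t1, z0)
    y0 = (y1 - step_z0) / l + z0
    return y0, z0


f = lambda t, y: y                     # toy vector field dy/dt = y
ts = [0.01 * i for i in range(101)]
intervals = list(zip(ts[:-1], ts[1:]))

y, z = 1.0, 1.0
for t0, t1 in intervals:
    y, z = reversible_forward(euler_step, f, t0, t1, y, z)
y_final = y

# Walk the same intervals in reverse; only the final (y, z) pair is needed.
for t0, t1 in reversed(intervals):
    y, z = reversible_backward(euler_step, f, t0, t1, y, z)
y_recovered, z_recovered = y, z
```

The backward loop stores nothing from the forward pass apart from the final pair and the step times, which is where the constant-memory claim comes from. The recovered values match the initial condition only up to floating point error.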

Main details I should highlight here:

  1. The current implementation relies on the _SolverState type of AbstractRungeKutta methods. Specifically, as the reversible method switches between evaluating the vector field at y and z, we ensure the fsal is correct by evaluating the vector field outside of the base Runge Kutta step. In principle this is unnecessary but required to fit with the behaviour of AbstractRungeKutta solvers; any ideas for how to avoid this?

  2. To backpropagate through the reversible solve we require knowledge of the ts that the solver visited. As this is not known a priori for adaptive step sizes, I've added a (teeny weeny) bit of infrastructure to the State in _integrate.py. This allows us to save the ts that the solver stepped to which we make available to ReversibleAdjoint as a residual. The added State follows exactly the implementation of saving dense_ts and is only triggered when adjoint=ReversibleAdjoint.
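To make point 2 concrete, here is a rough plain-Python sketch of recording the visited times in a fixed-size buffer and then replaying the intervals in reverse. It is not the `_integrate.py` code: the function names are hypothetical, the `dt *= 1.5` line is only a stand-in for a step size controller, and the buffer layout (filled with `inf`, length `max_steps + 1`) is borrowed from the snippet quoted further down in the review.

```python
import math


def solve_recording_ts(f, t0, t1, y0, dt0, max_steps):
    """Euler loop with a changing step size that records every time it visits."""
    ts = [math.inf] * (max_steps + 1)   # inf marks slots that were never used
    save_index = 0
    ts[save_index] = t0
    t, y, dt = t0, y0, dt0
    steps = 0
    while t < t1 and steps < max_steps:
        t_next = min(t + dt, t1)
        y = y + (t_next - t) * f(t, y)
        t = t_next
        save_index += 1
        ts[save_index] = t              # residual needed by the backward pass
        dt *= 1.5                       # stand-in for an adaptive controller
        steps += 1
    return y, ts, save_index


def visited_intervals_reversed(ts, save_index):
    """Intervals (t_prev, t_next) in the order a backward pass would use them."""
    return [(ts[i - 1], ts[i]) for i in range(save_index, 0, -1)]


y1, ts, n = solve_recording_ts(lambda t, y: -y, 0.0, 1.0, 1.0, 0.1, max_steps=16)
intervals = visited_intervals_reversed(ts, n)
```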

Best,
Sam

Owner

@patrick-kidger patrick-kidger left a comment

This is really well done. I've left some comments but it's mostly around broader structural/testing/documentation stuff.

I've commented on your point 1 inline, and I think what you've done for point 2 looks good to me!

`adjoint=diffrax.ReversibleAdjoint()`.
"""

solver: AbstractRungeKutta
Owner

Are implicit RK methods handled here as well? According to this annotation they are but I don't think I see them in the tests.

What is it about RK methods that privileges them here btw? IIUC I think any single-step method should work?

Owner

Also what about Euler, which isn't implemented as an AbstractRungeKutta but which does have the correct properties?

(It's done separately to be able to use as example code for how to write a solver.)

Author

Yep, you're right - any single step method should work. The reversible solver now works with any AbstractSolver.

See the discussion on fsal for more info.

Comment on lines 63 to 64
def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
    return self.solver.strong_order(terms)
Owner

Do you expect this technique to work for SDEs? If so then do call that out explicitly in the docstring, to reassure people! :)

(In particular I'm thinking of the asynchronous leapfrog method, which to our surprise did not work for SDEs...)

Author

We have no theory here (does James' intuition count?), but numerically it works for SDEs! I've added SDEs to the docstring.

There's the detail that the second solver step (that steps backwards in time) should use the same Brownian increment as the first solver step. I believe this is handled by VirtualBrownianTree.

Author

@sammccallum sammccallum Nov 27, 2024

Just added a check and test for UnsafeBrownianPath in light of the above.

And thinking about this further, we require the same conditions as BacksolveAdjoint; namely that the solver converges to the Stratonovich solution, so I've added a check and test for this.

Owner

Hmm FWIW my intuition is also that this should work for SDEs. Sounds like a follow-up paper to be written :)

But... until that theory exists, I think I'd feel more comfortable issuing an error here instead, to try to minimize the possibility of footguns. Most users treat solvers like oracles, and I try to cater to that unfootgunable UX!

Comment on lines 34 to 35
`adjoint=diffrax.ReversibleAdjoint()`.
"""
Owner

Do go ahead and add a couple of references here! (See how we've done it in the other solvers.) At the very least include both your paper and the various earlier pieces of work. Also make sure whatever you put here works with diffrax.citation, so that folks have an easy way to cite you :)

What happens if I use just ReversibleAdjoint with a different solver? What happens if I use Reversible with a different adjoint? Is this safe to use with adaptive time stepping? The docstring here needs to make clear what a user should expect to happen as this interacts with the other components of Diffrax!

Author

Thanks, good point - added.

Removing Reversible from public API helps with control here.

Owner

Can you add a test checking how this interacts with events? It's not immediately obvious to me that this will actually do the right thing.

Also, it would be good to see some 'negative tests' checking that the appropriate error is raised if Reversible is used in conjunction with e.g. SemiImplicitEuler, or any other method that isn't supported.

Author

Events seem to work on the forward reversible solve but raise the same error as BacksolveAdjoint on the backward solve. I've added a catch to raise an error if you try to use ReversibleAdjoint with events.

Negative tests for incompatible solvers, events and saveats have been added.

Comment on lines 1034 to 1038
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
    raise ValueError(
        "Can only use `adjoint=ReversibleAdjoint()` with "
        "`saveat=SaveAt(t1=True)`."
    )
Owner

It will probably not take long until someone asks to use this alongside SaveAt(ts=...)!

I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts. Do you have any ideas for this?

(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)

Owner

FWIW I imagine SaveAt(steps=True) is probably much easier.

Author

I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a tricky one.

Not a solution, but some thoughts:

The ReversibleAdjoint computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.

In principle, the interpolated ys are just a function of the stepped-to ys. We can therefore calculate gradients for the stepped-to ys and let AD handle the rest. This would require the interpolation routine to be separate to the solve routine, but I understand the memory drawbacks of this setup.

I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!

Owner

@patrick-kidger patrick-kidger Jan 1, 2025

On your thoughts -- I think this is exactly right. In principle we should be able to just solve backwards, and once we have the relevant y-values we can (re)interpolate the same solution we originally provided, and then pull the cotangents backwards through that computation via autodiff. Code-wise that may be somewhat fiddly, but if you're willing to take it on then I expect that'll actually be a really useful use-case.

I'm not sure if this would be done by decoupling solve from interpolation. I expect it would be some _, vjp_fn = jax.vjp(compute_interpolation); vjp_fn(y_cotangent) calls inside your while loop on the backward pass.
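As a rough illustration of what "pulling cotangents back through the interpolation" means, the sketch below writes out the pullback by hand for a linear interpolant in plain Python. In the real code this would be whatever `jax.vjp` returns for the solver's own interpolation, which need not be linear; the function names and the toy data are assumptions made for the example.

```python
import bisect


def interpolate(t, t0, t1, y0, y1):
    """Linear interpolation between two stepped-to values."""
    theta = (t - t0) / (t1 - t0)
    return (1 - theta) * y0 + theta * y1


def interpolate_vjp(t, t0, t1, y_cotangent):
    """Hand-written pullback of `interpolate` with respect to (y0, y1)."""
    theta = (t - t0) / (t1 - t0)
    return (1 - theta) * y_cotangent, theta * y_cotangent


def pull_back_saved_cotangents(step_ts, save_ts, save_cotangents):
    """Accumulate cotangents of interpolated outputs onto the stepped-to ys."""
    grad_ys = [0.0] * len(step_ts)
    for t, ct in zip(save_ts, save_cotangents):
        # Index of the step interval [step_ts[i], step_ts[i + 1]] containing t.
        i = min(bisect.bisect_right(step_ts, t) - 1, len(step_ts) - 2)
        g0, g1 = interpolate_vjp(t, step_ts[i], step_ts[i + 1], ct)
        grad_ys[i] += g0
        grad_ys[i + 1] += g1
    return grad_ys


step_ts = [0.0, 1.0, 2.0]          # times the solver stepped to
save_ts = [0.25, 1.5]              # times requested via something like SaveAt(ts=...)
save_cotangents = [1.0, 2.0]       # cotangents of the interpolated outputs
grad_ys = pull_back_saved_cotangents(step_ts, save_ts, save_cotangents)
```

The resulting `grad_ys` are the extra cotangents that the backward while loop would add to the running adjoint of `y` each time it passes the corresponding step.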

Comment on lines 1040 to 1044
if not isinstance(solver, Reversible):
    raise ValueError(
        "Can only use `adjoint=ReversibleAdjoint()` with "
        "`Reversible()` solver."
    )
Owner

Could we perhaps remove Reversible from the public API altogether, and just have solver = Reversible(solver) here? Make the Reversible solver an implementation detail of the adjoint.

Author

I really like this idea :D

I've removed Reversible from the public API and any AbstractSolver passed to ReversibleAdjoint is auto-wrapped. There is now a _Reversible class within the _adjoint module that is exclusively used by ReversibleAdjoint. Do you think this is an appropriate home for the _Reversible class or should I keep it elsewhere?

Owner

For consistency I'd probably put it in _solvers/reversible.py -- I've generally tended to organize things in this way, e.g. _terms.py::AdjointTerm is used as part of BacksolveAdjoint.

But that's only for consistency, one could imagine an alternate layout where these all lived next to their consumers instead.

    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
    (first_step, f0), z0 = solver_state
Owner

This will fail for non-FSAL Runge-Kutta solvers.
(Can you add a test for one of those to be sure we get correct behaviour?)

Author

See comment below.

# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?

f0 = self.func(terms, t0, z0, args)
Owner

I'm not sure this is really okay -- AbstractSolver.func is something we try to use approximately never: it basically exists just to handle initial step size selection and steady state finding, which are both fairly heuristic and pretty far off the beaten path.

If I understand correctly, the issue is that your y1 isn't quite the value that is returned from a single step, so the FSAL property does not hold, and as such you need to reevaluate f0? If so then I think you should be able to avoid this issue by ensuring that the RK solvers are used in non-FSAL form. This is one of the most complicated corners of the codebase, but take a look at the comment starting here:

Author

We now disable the FSAL and SSAL properties in the _Reversible init method (if a RK solver is used).

With this we can now make any AbstractSolver reversible and we pass around the _Reversible solver state by (original_solver_state, z_n). We also never unpack the original_solver_state, so don't need to assume any structure.

Owner

Awesome! It makes me very happy that these are all possible, that really simplifies our lives ^^

Comment on lines 95 to 103
step_z0, z_error, dense_info, _, result1 = self.solver.step(
    terms, t0, t1, z0, args, (first_step, f0), made_jump
)
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω

f0 = self.func(terms, t1, y1, args)
step_y1, y_error, _, _, result2 = self.solver.step(
    terms, t1, t0, y1, args, (first_step, f0), made_jump
)
Owner

On these two evaluations of .step -- take a look at eqx.internal.scan_trick, which might allow you to collapse these two callsites into one. That can be used to halve compilation time!

Author

CMIIW, I'm not sure we can use the scan trick here as the function return signature is different for each solver step?

That is, we only want to update the original_solver_state and dense_info when taking the forward-in-time step. So we don't return these on the backward-in-time step. IIUC, collapsing the two calls into one would require the returned carry to be the same on both calls?

Owner

So it's usually possible to make this work even in arbitrary cases by welding two copies of your state together and then using a jnp.where or lax.cond to route between them based on which step you're on.

That said this is a pretty fiddly optimization, and I probably shouldn't have suggested it just yet! Once we're happy with everything else then we could do this later, but until then the un-scan-trick'd code is much easier to read.

commit ec1ebac
Author: Sam McCallum <[email protected]>
Date:   Wed Nov 27 08:46:55 2024 +0000

    tidy up function arguments

commit 7b66f46
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 18:13:11 2024 +0000

    beefy tests

commit e713b5d
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:29:26 2024 +0000

    update references

commit 9acf6e0
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:12:26 2024 +0000

    test incorrect solver

commit 861aa97
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 13:05:05 2024 +0000

    catch already reversible solvers

commit 4b8b4c0
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 12:37:03 2024 +0000

    error estimate may be pytree

commit 0b01210
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 12:36:24 2024 +0000

    tests

commit 5435ab2
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:17:09 2024 +0000

    Revert "leapfrog not compatible"

    This reverts commit d88e732.

commit d88e732
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:15:32 2024 +0000

    leapfrog not compatible

commit 6e3f2de
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 11:13:30 2024 +0000

    pytree state

commit 3fa6432
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 10:28:26 2024 +0000

    docs

commit 2bfe820
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:34:36 2024 +0000

    remove reversible.py solver file

commit e7856d3
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:33:52 2024 +0000

    fix tests for relative import

commit 24d1935
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 09:18:05 2024 +0000

    private reversible

commit 8a7448e
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:56:40 2024 +0000

    remove debug print

commit 0391bc1
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:28:41 2024 +0000

    tests

commit 81a9a57
Author: Sam McCallum <[email protected]>
Date:   Tue Nov 26 08:23:41 2024 +0000

    more tests

commit 89f5731
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 20:52:51 2024 +0000

    test implicit solvers + SDEs

commit f30f47e
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 20:44:54 2024 +0000

    remove t0, t1, solver_state tangents

commit b903176
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 16:56:01 2024 +0000

    docs

commit acaa35f
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 12:56:50 2024 +0000

    better steps=True

commit 621e6f4
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 10:28:19 2024 +0000

    remove ifs in grad_step loop

commit 7dfb8e3
Author: Sam McCallum <[email protected]>
Date:   Mon Nov 25 09:15:18 2024 +0000

    Disable fsal, ssal properties to allow any solver to be made reversible

commit f160295
Author: Sam McCallum <[email protected]>
Date:   Fri Nov 22 15:09:57 2024 +0000

    tests

commit f327f66
Author: Sam McCallum <[email protected]>
Date:   Fri Nov 22 13:53:56 2024 +0000

    ReversibleAdjoint compatible with SaveAt(steps=True)

Reversible Solvers (v2)

Changes:
- `Reversible` solver is hidden from public API and automatically used with `ReversibleAdjoint`
- compatible with any `AbstractSolver`, except methods that are already algebraically reversible
- can now use `SaveAt(steps=True)`
- works with ODEs/CDEs/SDEs
- improved docs
- improved tests
@sammccallum
Author

Thanks very much for the review and suggestions!

I'll reply to individual comments inline but here is an overview:

  1. I really like the idea for removing Reversible from the public API and just making the adjoint auto-wrap the original solver. It feels very JAX-like, as we just augment the forward-trace so that we can backpropagate with improved properties. This has been added.

  2. Can now use any AbstractSolver, apart from solvers that are already algebraically reversible (raises a ValueError if passed).

  3. Added functionality for SaveAt(steps=True).

  4. Improved docs so that people know what they can use ReversibleAdjoint with (any solver, ODEs/CDEs/SDEs, adaptive time steps).

  5. Improved tests

Owner

@patrick-kidger patrick-kidger left a comment

Okay, sorry for the (long) delay! 😅

So this mostly looks like it's in a good place to me. I think my one big concern is handling solver_state -- as each solver is free to define any kind of state it likes, then (a) there need not be any way to reconstruct it reversibly (so the backward pass goes wrong), and (b) it may encode information from previous y,z-values that aren't actually the ones being propagated (so the forward pass goes wrong).

I think right now you're dodging this just because of the particular choice of solvers being considered, i.e. (forced-to-be-)non-FSAL Runge--Kutta solvers.

I think we've got a couple of possible options:

  • We could expand the solver API a little bit. For example a method to construct the previous state reversibly, and a flag to indicate that the y-value has changed and the solver might need to consider its state invalidated.
  • We could also special-case down to e.g. just AbstractRungeKutta, and hardwire all the things about it that we know how to handle.

I am weakly leaning towards the first option, as it could allow us to handle 'already reversible' solvers like ReversibleHeun or LeapfrogMidpoint (c.f. #541).

  • We can introduce an AbstractReversibleSolver, subclassed by ReversibleHeun, LeapfrogMidpoint and your _Reversible.
  • Then ReversibleAdjoint uses just the API provided by AbstractReversibleSolver, without making assumptions about precisely which one.
  • _Reversible could afford to consume only those kinds of one-step solvers it knows how to handle (e.g. non-FSAL RK methods), and if other special use-cases arise later then we can always make more subclasses of AbstractReversibleSolver.
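One possible shape for such an interface, sketched in plain Python under heavy assumptions: the class and method names are hypothetical, the step size is fixed, and a leapfrog-midpoint-style update is used only because its two-value state makes the "reconstruct the previous state" method easy to write down.

```python
from abc import ABC, abstractmethod


class AbstractReversibleSolverSketch(ABC):
    """Hypothetical interface: a step plus its algebraic inverse."""

    @abstractmethod
    def step(self, f, t0, t1, state):
        """Advance the solver state from t0 to t1."""

    @abstractmethod
    def backward_step(self, f, t0, t1, state):
        """Reconstruct the state at t0 from the state at t1."""


class LeapfrogMidpointSketch(AbstractReversibleSolverSketch):
    # State is (y_prev, y_curr): the values one step behind t0 and at t0.
    def step(self, f, t0, t1, state):
        y_prev, y_curr = state
        y_next = y_prev + 2 * (t1 - t0) * f(t0, y_curr)
        return (y_curr, y_next)

    def backward_step(self, f, t0, t1, state):
        y_curr, y_next = state
        y_prev = y_next - 2 * (t1 - t0) * f(t0, y_curr)
        return (y_prev, y_curr)


def round_trip(solver, f, ts, state):
    """Uses only the abstract interface, as an adjoint would."""
    intervals = list(zip(ts[:-1], ts[1:]))
    for t0, t1 in intervals:
        state = solver.step(f, t0, t1, state)
    final_state = state
    for t0, t1 in reversed(intervals):
        state = solver.backward_step(f, t0, t1, state)
    return final_state, state


ts = [0.01 * i for i in range(51)]
initial_state = (1.01, 1.0)
final_state, recovered_state = round_trip(
    LeapfrogMidpointSketch(), lambda t, y: -y, ts, initial_state
)
```

`round_trip` never inspects the state it is handed, which is the property the review is asking for: the adjoint should not need to know which concrete reversible solver it is driving.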

But I also realise that figuring out these details makes for pretty complicated research questions, and I don't want to presume upon your appetite for tackling them!

Let me know what you think / also happy to have a longer chat about this via email etc if it's easier.

#
# Information for reversible adjoint (save ts)
#
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]]
Owner

By the way, I think this should be registered as a buffer in _outer_buffers used here:

cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers

As for what buffers are, see also some discussion on buffers here from one of Andraz's PRs:

#484 (comment)

+ the docs for eqxi.while_loop

Comment on lines +1344 to +1354
# Reversible info
if isinstance(adjoint, ReversibleAdjoint):
    if max_steps is None:
        raise ValueError(
            "`max_steps=None` is incompatible with `ReversibleAdjoint`"
        )
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0
else:
    reversible_ts = None
    reversible_save_index = None
Owner

I think what might be simpler here is to have:

if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(...)
    reversible_save_index = 0

so that we are always saving this information if possible.

The benefit of this is that in principle someone else could write their own ReversibleAdjoint2 and have it work without needing to be special-cased here inside the main diffeqsolve implementation: it would just consume the information made available to it.

Finally the ValueError can be moved inside the implementation of ReversibleAdjoint, if the necessary reversible_ts information is not available.

This shouldn't really impose any performance penalty (a very small compile-time one only) because for any other adjoint method it will just be DCE'd.
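A small plain-Python sketch of the restructuring suggested above, with the allocation made unconditional on the adjoint type and the error moved into the adjoint. The function and class names are hypothetical, and a list filled with `math.inf` stands in for the `jnp.full(...)` buffer.

```python
import math


def init_reversible_ts(max_steps):
    """Allocate the ts buffer whenever a bound on the number of steps exists."""
    if max_steps is None:
        return None, None
    reversible_ts = [math.inf] * (max_steps + 1)
    reversible_save_index = 0
    return reversible_ts, reversible_save_index


class ReversibleAdjointSketch:
    def check_residuals(self, reversible_ts):
        # The adjoint, not the main solve routine, decides that it cannot
        # run without the saved step times.
        if reversible_ts is None:
            raise ValueError(
                "`max_steps=None` is incompatible with `ReversibleAdjoint`"
            )


buffer, index = init_reversible_ts(4)
no_buffer, no_index = init_reversible_ts(None)

try:
    ReversibleAdjointSketch().check_residuals(no_buffer)
    raised = False
except ValueError:
    raised = True
```

With this split, a third-party adjoint could consume the same buffer without the main solve needing an `isinstance` check for it.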

```
"""

l: float = 0.999
Owner

Can we pick a more descriptive name for this parameter? E.g. when writing an optimizer then mathematically we may conventionally use λ but in code we would often use a variable with a name like learning_rate.

Comment on lines +1111 to +1120
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
if (
    eqx.tree_equal(saveat, SaveAt(t1=True)) is not True
    and eqx.tree_equal(saveat, SaveAt(steps=True)) is not True
    and eqx.tree_equal(saveat, SaveAt(t0=True, steps=True)) is not True
):
    raise ValueError(
        "Can only use `diffrax.ReversibleAdjoint` with "
        "`saveat=SaveAt(t1=True)` or `saveat=SaveAt(steps=True)`."
    )
Owner

I'm guessing that SaveAt(steps=True, t1=True) or SaveAt(t0=True, steps=True, t1=True) should also be allowed?

Comment on lines +1153 to +1163
solver = _Reversible(solver, self.l)
tprev = init_state.tprev
tnext = init_state.tnext
y = init_state.y

init_state = eqx.tree_at(
    lambda s: s.solver_state,
    init_state,
    solver.init(terms, tprev, tnext, y, args),
    is_leaf=_is_none,
)
Owner

Note that the _Reversible.init here will re-call the underlying original solver.init, I think unnecessarily? I think we should be able to do just init_state = eqx.tree_at(lambda s: s.solver_state, init_state, (init_state.solver_state, init_state.y)), and then set class _Reversible: def init(...): assert False.

Comment on lines +1229 to +1231
@property
def term_compatible_contr_kwargs(self):
    return self.solver.term_compatible_contr_kwargs
Owner

I think this, and term_structure, should not be necessary -- they are checked at the start of diffeqsolve, which we are already past. I think it may be neater to set these to assert False?

Comment on lines +1255 to +1257
if isinstance(self.solver, AbstractRungeKutta):
    object.__setattr__(self.solver.tableau, "fsal", False)
    object.__setattr__(self.solver.tableau, "ssal", False)
Owner

Note that this will produce a bug if we do:

solver = Tsit5()
diffeqsolve(..., solver, ReversibleAdjoint())
diffeqsolve(..., solver, RecursiveCheckpointAdjoint())

as the solver is modified in-place.
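The hazard can be reproduced without diffrax using frozen dataclasses as stand-ins for the solver and its tableau (both class names are hypothetical). The copying variant is shown only for contrast; it is not the fix proposed in this review.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class TableauSketch:
    fsal: bool = True
    ssal: bool = True


@dataclass(frozen=True)
class SolverSketch:
    tableau: TableauSketch


def disable_fsal_in_place(solver):
    # Bypasses the frozen dataclass, so every holder of `solver` sees the change.
    object.__setattr__(solver.tableau, "fsal", False)
    object.__setattr__(solver.tableau, "ssal", False)
    return solver


def disable_fsal_copy(solver):
    # Builds new objects and leaves the caller's solver untouched.
    new_tableau = replace(solver.tableau, fsal=False, ssal=False)
    return replace(solver, tableau=new_tableau)


shared = SolverSketch(TableauSketch())

copied = disable_fsal_copy(shared)
fsal_after_copy = shared.tableau.fsal        # still True

disable_fsal_in_place(shared)
fsal_after_mutation = shared.tableau.fsal    # now False for later solves too
```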

I think I have a better solution: can we unconditionally pass made_jump=True into self.solver.step? This is our API to indicate to solvers that something has changed, and that their state may be out-of-date. Technically right now it's used to indicate jumps in the vector field, but we could re-use it (or add another flag) to indicate exogenous jumps in y.
Alternatively it may be safer for now to only allow AbstractRungeKutta here, and not general solvers -- it's not clear to me that any of this will really work with multi-step solvers like LeapfrogMidpoint, for example.
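To illustrate why a flag like `made_jump` helps, here is a toy plain-Python solver that caches the vector field evaluation at the end of each step, FSAL-style, and reuses it at the start of the next. The function `fsal_euler_step` and its signature are assumptions for the example and do not correspond to diffrax's `step` API.

```python
def fsal_euler_step(f, t0, t1, y0, cached_f0, made_jump):
    """Explicit Euler that reuses f(t0, y0) cached by the previous step."""
    if made_jump or cached_f0 is None:
        f0 = f(t0, y0)          # state was invalidated, so re-evaluate
    else:
        f0 = cached_f0          # trust the cache from the previous step
    y1 = y0 + (t1 - t0) * f0
    return y1, f(t1, y1)        # cache the evaluation at the new point


f = lambda t, y: -y

y1, cache = fsal_euler_step(f, 0.0, 0.1, 1.0, None, False)

# Something outside the solver now replaces y (as the reversible update does).
y_changed = 0.5

stale, _ = fsal_euler_step(f, 0.1, 0.2, y_changed, cache, False)
fresh, _ = fsal_euler_step(f, 0.1, 0.2, y_changed, cache, True)
```

With `made_jump=False` the step silently uses the derivative at the old `y1`; with `made_jump=True` it re-evaluates at the value actually being propagated, at the cost of one extra function evaluation.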

solver_state = (original_solver_state, z1)
result = update_result(result1, result2)

return y1, _add_maybe_none(z_error, y_error), dense_info, solver_state, result
Owner

Why add the error estimates together?

grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
del diff_args, diff_terms, diff_z1

def grad_step(state):
Owner

Hmm, I'm trying to figure out how solver_state is handled here. I don't think it is correct?

solver_state is some completely arbitrary information that is propagated forward step-by-step, internal to the solver. In particular we don't have an API for reconstructing this backwards in time reversibly.

@sammccallum
Author

A description of the performance bug with the current ReversibleAdjoint:

  • vmapping over diffeqsolve creates a significant slow down

From the example below, we see that for batch_size=1 ReversibleAdjoint is faster than RecursiveCheckpointAdjoint (as expected). But for batch_size=1000 ReversibleAdjoint is slower than RecursiveCheckpointAdjoint.

import time

import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)


class VectorField(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, y_dim, width_size, depth, key):
        self.mlp = eqx.nn.MLP(y_dim, y_dim, width_size, depth, key=key)

    def __call__(self, t, y, args):
        return self.mlp(y)


@eqx.filter_jit
def solve(model, y0, adjoint):
    term = dfx.ODETerm(model)
    solver = dfx.Euler()
    t0 = 0.0
    t1 = 10.0
    dt0 = 0.01
    sol = dfx.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=dfx.SaveAt(t1=True),
        adjoint=adjoint,
        max_steps=1000,
    )
    return sol.ys


@eqx.filter_value_and_grad
def grad_loss(model, y0, adjoint):
    ys = eqx.filter_vmap(solve, in_axes=(None, 0, None))(model, y0, adjoint)
    return jnp.mean(ys**2)


def measure_runtime(y0, model, adjoint):
    tic = time.time()
    loss, grads = grad_loss(model, y0, adjoint)
    toc = time.time()
    print(f"Compile time: {(toc - tic):.5f}")

    repeats = 10
    tic = time.time()
    for i in range(repeats):
        loss, grads = jax.block_until_ready(grad_loss(model, y0, adjoint))
    toc = time.time()
    print(f"Runtime: {((toc - tic) / repeats):.5f}")


y_dim = 100
width_size = 100
depth = 4
model = VectorField(y_dim, width_size, depth, key=jr.PRNGKey(10))

print("Batch Size = 1")
print("--------------")
y0 = jnp.ones((1, y_dim))
print("Recursive")
adjoint = dfx.RecursiveCheckpointAdjoint()
measure_runtime(y0, model, adjoint)
print("Reversible")
adjoint = dfx.ReversibleAdjoint()
measure_runtime(y0, model, adjoint)

print("\nBatch Size = 1000")
print("-----------------")
y0 = jnp.ones((1000, y_dim))
print("Recursive")
adjoint = dfx.RecursiveCheckpointAdjoint()
measure_runtime(y0, model, adjoint)
print("Reversible")
adjoint = dfx.ReversibleAdjoint()
measure_runtime(y0, model, adjoint)

Batch Size = 1
--------------
Recursive
Compile time: 3.99825
Runtime: 0.37017
Reversible
Compile time: 1.60509
Runtime: 0.18221

Batch Size = 1000
-----------------
Recursive
Compile time: 5.35720
Runtime: 3.15009
Reversible
Compile time: 4.87897
Runtime: 3.89896

FWIW, I don't think this is specifically a problem with vmap, but a problem of scale. For example, if we keep batch_size=1 but increase the width_size of the MLP then a similar slow down is observed.

Width Size = 100
--------------
Recursive
Compile time: 3.75692
Runtime: 0.33692
Reversible
Compile time: 1.63070
Runtime: 0.18080

Width Size = 1000
--------------
Recursive
Compile time: 3.18077
Runtime: 0.61647
Reversible
Compile time: 2.06301
Runtime: 0.65720

In principle, this is very wrong as we are only changing the cost of each step, not the number of steps $n$. The complexity remains a comparison between $O(n)$ for Reversible and $O(n \log n)$ for RecursiveCheckpointing, where $n$ is fixed. Any ideas would be massively appreciated!

(The quoted runtimes are on GPU)

@patrick-kidger
Owner

Hmm, that's rather weird! I've just taken a look over both your MWE and your code and nothing immediately jumps out at me as wrong.

A few immediate thoughts about how you might tackle this:

  • Can you try doing all of this without ever explicitly calling jax.vmap? Just setting up multiple independent systems altogether and treating them as a single ODE solve. When we do vmap-of-loop or vmap-of-cond then there can be some performance pessimations that occur, and I notice that you're unconditionally using vmap even on the batch_size=1 case.
  • Use a profiler. If you want an example of what debugging with that looks like, here is a recent Equinox issue on a similar topic. FWIW for slowdowns of this nature, it's moderately common that the underlying root cause is a buffer somewhere getting unnecessarily copied between loop iterations; something to keep an eye out for.
  • Check the asymptotics of how the time varies as your problem size changes.
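For the last suggestion, a simple way to summarise scaling is the slope of a least-squares fit on a log-log plot of runtime against problem size. The sketch below is plain Python; `time_callable` and `loglog_slope` are hypothetical helper names, and the two timing series are synthetic, generated from formulas purely to show what the slope looks like for linear and n log n growth. They are not measurements from this PR.

```python
import math
import time


def time_callable(fn, repeats=10):
    """Average wall-clock time of fn() over a few repeats."""
    tic = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - tic) / repeats


def loglog_slope(sizes, times):
    """Least-squares slope of log(time) against log(size)."""
    xs = [math.log(s) for s in sizes]
    ys = [math.log(t) for t in times]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    num = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    den = sum((x - x_mean) ** 2 for x in xs)
    return num / den


sizes = [100, 200, 400, 800]

# Synthetic data only: ideal linear and n log n cost models.
linear_times = [1e-3 * n for n in sizes]
nlogn_times = [1e-3 * n * math.log(n) for n in sizes]

linear_slope = loglog_slope(sizes, linear_times)
nlogn_slope = loglog_slope(sizes, nlogn_times)
```

In practice the `times` list would come from calling `time_callable` on the already-compiled gradient function at each batch size or width, with the result blocked on before the timer stops.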
