Reversible Solvers #528

Open · wants to merge 4 commits into main
Changes from 1 commit
2 changes: 2 additions & 0 deletions diffrax/__init__.py
@@ -6,6 +6,7 @@
DirectAdjoint as DirectAdjoint,
ImplicitAdjoint as ImplicitAdjoint,
RecursiveCheckpointAdjoint as RecursiveCheckpointAdjoint,
ReversibleAdjoint as ReversibleAdjoint,
)
from ._autocitation import citation as citation, citation_rules as citation_rules
from ._brownian import (
@@ -101,6 +102,7 @@
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
Reversible as Reversible,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
210 changes: 209 additions & 1 deletion diffrax/_adjoint.py
@@ -16,7 +16,12 @@

from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
from ._solver import (
AbstractItoSolver,
AbstractRungeKutta,
AbstractStratonovichSolver,
Reversible,
)
from ._term import AbstractTerm, AdjointTerm


@@ -852,3 +857,206 @@ def loop(
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats


# Reversible Adjoint custom vjp computes gradients w.r.t.
# - y, corresponding to the initial state;
# - args, corresponding to explicit parameters;
# - terms, corresponding to implicit parameters as part of the vector field.
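
The pattern below is Equinox's filter_custom_vjp: the differentiable inputs are packed into the first positional argument, the forward rule returns the primal output plus residuals, and the backward rule returns cotangents matching that first argument. A toy sketch of the same registration pattern, assuming Equinox's API as exercised below:

import equinox as eqx
import jax.numpy as jnp


@eqx.filter_custom_vjp
def square(x):
    return x**2


@square.def_fwd
def square_fwd(perturbed, x):
    del perturbed
    return x**2, x  # (primal output, residuals for the backward pass)


@square.def_bwd
def square_bwd(residuals, grad_out, perturbed, x):
    del perturbed, x
    return 2 * residuals * grad_out  # cotangent w.r.t. x


print(eqx.filter_grad(square)(jnp.array(3.0)))  # 6.0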


@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, init_state, **kwargs):
del throw
y, args, terms = y__args__terms
init_state = eqx.tree_at(lambda s: s.y, init_state, y)
del y
return self._loop(
args=args,
terms=terms,
init_state=init_state,
inner_while_loop=ft.partial(_inner_loop, kind="lax"),
outer_while_loop=ft.partial(_outer_loop, kind="lax"),
**kwargs,
)


@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
del perturbed
final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
ts = final_state.reversible_ts
ts_final_index = final_state.reversible_save_index
y1 = final_state.save_state.ys[-1]
solver_state1 = final_state.solver_state
return (final_state, aux_stats), (ts, ts_final_index, y1, solver_state1)


@_loop_reversible.def_bwd
def _loop_reversible_bwd(
residuals,
grad_final_state__aux_stats,
perturbed,
y__args__terms,
*,
self,
solver,
event,
t0,
t1,
dt0,
init_state,
progress_meter,
**kwargs,
):
assert event is None

del perturbed, init_state, t1, progress_meter, self, kwargs
ts, ts_final_index, y1, solver_state1 = residuals
original_solver_state, z1 = solver_state1
del residuals, solver_state1

grad_final_state, _ = grad_final_state__aux_stats
# ReversibleAdjoint currently only allows SaveAt(t1=True) so grad_y1 should have
# the same structure as y1.
grad_y1 = grad_final_state.save_state.ys[-1]
grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1)
del grad_final_state, grad_final_state__aux_stats

y, args, terms = y__args__terms
del y__args__terms

diff_args = eqx.filter(args, eqx.is_inexact_array)
diff_terms = eqx.filter(terms, eqx.is_inexact_array)
diff_z1 = eqx.filter(z1, eqx.is_inexact_array)
grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
del diff_args, diff_terms, diff_z1

def grad_step(state):
Owner comment:

Hmm, I'm trying to figure out how solver_state is handled here. I don't think it is correct?

solver_state is some completely arbitrary information that is propagated forward step-by-step, internal to the solver. In particular we don't have an API for reconstructing this backwards in time reversibly.

def solver_step(terms, t0, t1, y1, args):
step, _, _, _, _ = solver.solver.step(
terms, t0, t1, y1, args, (first_step, f0), False
)
return step

ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state
(first_step, f0), z1 = solver_state

t1 = ts[ts_index]
t0 = ts[ts_index - 1]
ts_index = ts_index - 1

# TODO The solver steps switch between evaluating from z0
# and y1. Therefore, we re-evaluate f0 outside of the base
# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?

f0 = solver.func(terms, t1, y1, args)
step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args)
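# Reconstruct z0 from (z1, y1): inverts z1 = z0 + y1 - step(t1 -> t0, y1).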
z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω

f0 = solver.func(terms, t0, z0, args)
step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args)

y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω

grad_step_y1 = vjp_fun_y1(grad_z1)
grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω

grad_step_z0 = vjp_fun_z0(grad_y1)
grad_y0 = (solver.l * ω(grad_y1)).ω
grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω

grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω
grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω

return (
ts_index,
y0,
((first_step, f0), z0),
grad_y0,
grad_z0,
grad_args,
grad_terms,
)

def cond_fun(state):
ts_index = state[0]
return ts_index > 0

state = (
ts_final_index,
y1,
(original_solver_state, z1),
grad_y1,
grad_z1,
grad_args,
grad_terms,
)

state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
_, _, _, grad_y0, grad_z0, grad_args, grad_terms = state
return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms
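

Taken together, grad_step walks the trajectory backwards: each iteration reconstructs (y0, z0) from (y1, z1) exactly, then accumulates cotangents through the two vjps. A toy sketch of the coupled recurrence being inverted, inferred from the algebra above rather than quoted from the PR (phi stands in for one base-solver step and l for the solver's coupling parameter):

def forward_step(phi, l, t0, t1, y0, z0):
    y1 = l * (y0 - z0) + phi(t0, t1, z0)
    z1 = z0 + y1 - phi(t1, t0, y1)
    return y1, z1


def backward_step(phi, l, t0, t1, y1, z1):
    # Exactly inverts forward_step, up to floating point error.
    z0 = z1 - y1 + phi(t1, t0, y1)
    y0 = (y1 - phi(t0, t1, z0)) / l + z0
    return y0, z0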


class ReversibleAdjoint(AbstractAdjoint):
"""
Backpropagate through [`diffrax.diffeqsolve`][] when using the
[`diffrax.Reversible`][] solver.

This method uses very low memory and computes exact gradients (up to floating
point errors).

This will compute gradients with respect to the `terms`, `y0` and `args` arguments
passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
respect to anything else (for example `t0`, or arguments passed via closure), then
a `CustomVJPException` will be raised. See also
[this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
entry.
"""

def loop(
self,
*,
args,
terms,
solver,
saveat,
init_state,
passed_solver_state,
passed_controller_state,
event,
**kwargs,
):
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`saveat=SaveAt(t1=True)`."
)
Owner comment:

It will probably not take long until someone asks to use this alongside SaveAt(ts=...)!

I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts. Do you have any ideas for this?

(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)

Owner comment:

FWIW I imagine SaveAt(steps=True) is probably much easier.

Author comment:

I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a tricky one.

Not a solution, but some thoughts:

The ReversibleAdjoint computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.

In principle, the interpolated ys are just a function of the stepped-to ys. We can therefore calculate gradients for the stepped-to ys and let AD handle the rest. This would require the interpolation routine to be separate from the solve routine, but I understand the memory drawbacks of this setup.

I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!

Owner comment (patrick-kidger, Jan 1, 2025):

On your thoughts -- I think this is exactly right. In principle we should be able to just solve backwards, and once we have the relevant y-values we can (re)interpolate the same solution we originally provided, and then pull the cotangents backwards through that computation via autodiff. Code-wise that may be somewhat fiddly, but if you're willing to take it on then I expect that'll actually be a really useful use-case.

I'm not sure if this would be done by decoupling solve from interpolation. I expect it would be some _, vjp_fn = jax.vjp(compute_interpolation); vjp_fn(y_cotangent) calls inside your while loop on the backward pass.
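
A rough sketch of that pattern (names hypothetical; compute_interpolation stands in for whatever routine maps the stepped-to values to interpolated outputs):

import jax
import jax.numpy as jnp


def compute_interpolation(y0, y1):
    # Placeholder interpolant: midpoint of a linear interpolation between steps.
    return 0.5 * (y0 + y1)


y0 = jnp.array([1.0])
y1 = jnp.array([2.0])
y_cotangent = jnp.array([1.0])

# On the backward pass: recompute the interpolation and pull the output
# cotangent back onto the stepped-to values; AD handles the rest.
_, vjp_fn = jax.vjp(compute_interpolation, y0, y1)
grad_y0, grad_y1 = vjp_fn(y_cotangent)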


if not isinstance(solver, Reversible):
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`Reversible()` solver."
)
Owner comment:

Could we perhaps remove Reversible from the public API altogether, and just have solver = Reversible(solver) here? Make the Reversible solver an implementation detail of the adjoint.

Author comment:

I really like this idea :D

I've removed Reversible from the public API and any AbstractSolver passed to ReversibleAdjoint is auto-wrapped. There is now a _Reversible class within the _adjoint module that is exclusively used by ReversibleAdjoint. Do you think this is an appropriate home for the _Reversible class or should I keep it elsewhere?

Owner comment:

For consistency I'd probably put it in _solver/reversible.py -- I've generally tended to organize things in this way, e.g. _term.py::AdjointTerm is used as part of BacksolveAdjoint.

But that's only for consistency, one could imagine an alternate layout where these all lived next to their consumers instead.


y = init_state.y
init_state = eqx.tree_at(lambda s: s.y, init_state, object())
init_state = _nondiff_solver_controller_state(
self, init_state, passed_solver_state, passed_controller_state
)

final_state, aux_stats = _loop_reversible(
(y, args, terms),
self=self,
saveat=saveat,
init_state=init_state,
solver=solver,
event=event,
**kwargs,
)
final_state = _only_transpose_ys(final_state)
return final_state, aux_stats
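
For orientation, a minimal sketch of how this adjoint is driven at this stage of the PR, in which the solver must be wrapped in Reversible (the exact Reversible constructor arguments are an assumption):

import diffrax
import jax
import jax.numpy as jnp


def vector_field(t, y, args):
    return -args * y


@jax.jit
@jax.grad
def loss(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(vector_field),
        solver=diffrax.Reversible(diffrax.Tsit5()),  # assumed wrapping constructor
        t0=0.0,
        t1=1.0,
        dt0=0.01,
        y0=y0,
        args=2.0,
        saveat=diffrax.SaveAt(t1=True),  # the only SaveAt this adjoint accepts
        adjoint=diffrax.ReversibleAdjoint(),
    )
    return jnp.sum(sol.ys[-1] ** 2)


print(loss(jnp.array([1.0])))
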
37 changes: 36 additions & 1 deletion diffrax/_integrate.py
@@ -23,7 +23,7 @@
import optimistix as optx
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint, ReversibleAdjoint
from ._custom_types import (
BoolScalarLike,
BufferDenseInfos,
@@ -110,6 +110,11 @@ class State(eqx.Module):
event_dense_info: Optional[DenseInfo]
event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]]
event_mask: Optional[PyTree[BoolScalarLike]]
#
# Information for reversible adjoint (save ts)
#
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]]
Owner comment:

By the way, I think this should be registered as a buffer in _outer_buffers used here:

cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers

As for what buffers are, see also some discussion on buffers here from one of Andraz's PRs:

#484 (comment)

Plus the docs for eqxi.while_loop.

reversible_save_index: Optional[IntScalarLike]


def _is_none(x: Any) -> bool:
@@ -293,6 +298,11 @@ def loop(
dense_ts = dense_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts)

if init_state.reversible_ts is not None:
reversible_ts = init_state.reversible_ts
reversible_ts = reversible_ts.at[0].set(t0)
init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts)

def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state)
@@ -574,6 +584,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
result,
)

reversible_ts = state.reversible_ts
reversible_save_index = state.reversible_save_index

if state.reversible_ts is not None:
reversible_ts = maybe_inplace(
reversible_save_index + 1, tprev, reversible_ts
)
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0)

new_state = State(
y=y,
tprev=tprev,
@@ -595,6 +614,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType]
reversible_save_index=reversible_save_index,
)

return (
@@ -1320,6 +1341,18 @@ def _outer_cond_fn(cond_fn_i):
)
del had_event, event_structure, event_mask_leaves, event_values__mask

# Reversible info
if isinstance(adjoint, ReversibleAdjoint):
if max_steps is None:
raise ValueError(
"`max_steps=None` is incompatible with `ReversibleAdjoint`"
)
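# max_steps + 1 entries: index 0 holds t0; each kept step then writes its
# end time at reversible_save_index + 1.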
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
reversible_save_index = 0
else:
reversible_ts = None
reversible_save_index = None
Owner comment (on lines +1344 to +1354):

I think what might be simpler here is to have:

if max_steps is None:
    reversible_ts = None
    reversible_save_index = None
else:
    reversible_ts = jnp.full(...)
    reversible_save_index = 0

so that we are always saving this information if possible.

The benefit of this is that in principle someone else could write their own ReversibleAdjoint2 and have it work without needing to be special-cased here inside the main diffeqsolve implementation: it would just consume the information made available to it.

Finally, the ValueError can be moved inside the implementation of ReversibleAdjoint, raised there if the necessary reversible_ts information is not available.

This shouldn't really impose any performance penalty (a very small compile-time one only) because for any other adjoint method it will just be DCE'd.


# Initialise state
init_state = State(
y=y0,
@@ -1342,6 +1375,8 @@ def _outer_cond_fn(cond_fn_i):
event_dense_info=event_dense_info,
event_values=event_values,
event_mask=event_mask,
reversible_ts=reversible_ts,
reversible_save_index=reversible_save_index,
)

#
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
@@ -27,6 +27,7 @@
StratonovichMilstein as StratonovichMilstein,
)
from .ralston import Ralston as Ralston
from .reversible import Reversible as Reversible
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .runge_kutta import (
AbstractDIRK as AbstractDIRK,