-
-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Reversible Solvers #528
base: main
Are you sure you want to change the base?
Reversible Solvers #528
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,12 @@ | |
|
||
from ._heuristics import is_sde, is_unsafe_sde | ||
from ._saveat import save_y, SaveAt, SubSaveAt | ||
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver | ||
from ._solver import ( | ||
AbstractItoSolver, | ||
AbstractRungeKutta, | ||
AbstractStratonovichSolver, | ||
Reversible, | ||
) | ||
from ._term import AbstractTerm, AdjointTerm | ||
|
||
|
||
|
@@ -852,3 +857,206 @@ def loop( | |
) | ||
final_state = _only_transpose_ys(final_state) | ||
return final_state, aux_stats | ||
|
||
|
||
# Reversible Adjoint custom vjp computes gradients w.r.t. | ||
# - y, corresponding to the initial state; | ||
# - args, corresponding to explicit parameters; | ||
# - terms, corresponding to implicit parameters as part of the vector field. | ||
|
||
|
||
@eqx.filter_custom_vjp
def _loop_reversible(y__args__terms, *, self, throw, init_state, **kwargs):
    """Forward solve for the reversible adjoint.

    `(y0, args, terms)` are bundled into one differentiable argument so that the
    custom vjp defined below can produce gradients with respect to all three.
    """
    del throw
    y0, vf_args, vf_terms = y__args__terms
    # Re-insert the (differentiable) initial condition into the solver state.
    state_with_y = eqx.tree_at(lambda s: s.y, init_state, y0)
    return self._loop(
        args=vf_args,
        terms=vf_terms,
        init_state=state_with_y,
        inner_while_loop=ft.partial(_inner_loop, kind="lax"),
        outer_while_loop=ft.partial(_outer_loop, kind="lax"),
        **kwargs,
    )
|
||
|
||
@_loop_reversible.def_fwd
def _loop_reversible_fwd(perturbed, y__args__terms, **kwargs):
    """Forward rule: run the solve and stash residuals for the backward pass."""
    del perturbed
    final_state, aux_stats = _loop_reversible(y__args__terms, **kwargs)
    # Residuals: the accepted step times, the index of the last accepted time,
    # the final solution value, and the final solver state.
    residuals = (
        final_state.reversible_ts,
        final_state.reversible_save_index,
        final_state.save_state.ys[-1],
        final_state.solver_state,
    )
    return (final_state, aux_stats), residuals
|
||
|
||
@_loop_reversible.def_bwd
def _loop_reversible_bwd(
    residuals,
    grad_final_state__aux_stats,
    perturbed,
    y__args__terms,
    *,
    self,
    solver,
    event,
    t0,
    t1,
    dt0,
    init_state,
    progress_meter,
    **kwargs,
):
    """Backward rule for the reversible adjoint.

    Reconstructs the forward trajectory exactly (up to floating point error) by
    stepping the reversible solver backwards in time, and accumulates gradients
    with respect to `y0`, `args` and `terms` along the way.
    """
    # Events are not supported with this adjoint; the forward residuals assume
    # the solve ran to completion.
    assert event is None

    del perturbed, init_state, t1, progress_meter, self, kwargs
    # ts: accepted step times; ts_final_index: index of the last accepted time;
    # y1: final solution value; solver_state1: final solver state.
    ts, ts_final_index, y1, solver_state1 = residuals
    original_solver_state, z1 = solver_state1
    del residuals, solver_state1

    grad_final_state, _ = grad_final_state__aux_stats
    # ReversibleAdjoint currently only allows SaveAt(t1=True) so grad_y1 should have
    # the same structure as y1.
    grad_y1 = grad_final_state.save_state.ys[-1]
    grad_y1 = jtu.tree_map(_materialise_none, y1, grad_y1)
    del grad_final_state, grad_final_state__aux_stats

    y, args, terms = y__args__terms
    del y__args__terms

    # Zero-initialise gradient accumulators over the inexact-array leaves only.
    diff_args = eqx.filter(args, eqx.is_inexact_array)
    diff_terms = eqx.filter(terms, eqx.is_inexact_array)
    diff_z1 = eqx.filter(z1, eqx.is_inexact_array)
    grad_args = jtu.tree_map(jnp.zeros_like, diff_args)
    grad_terms = jtu.tree_map(jnp.zeros_like, diff_terms)
    grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
    del diff_args, diff_terms, diff_z1

    def grad_step(state):
        # One backward step: invert the reversible update to recover (y0, z0)
        # from (y1, z1), then pull cotangents through both half-steps.
        def solver_step(terms, t0, t1, y1, args):
            # `first_step`/`f0` are closed over so that vjps are taken only with
            # respect to the five explicit positional arguments.
            step, _, _, _, _ = solver.solver.step(
                terms, t0, t1, y1, args, (first_step, f0), False
            )
            return step

        ts_index, y1, solver_state, grad_y1, grad_z1, grad_args, grad_terms = state
        (first_step, f0), z1 = solver_state

        # Walk the saved times backwards: step from t1 down to t0.
        t1 = ts[ts_index]
        t0 = ts[ts_index - 1]
        ts_index = ts_index - 1

        # TODO The solver steps switch between evaluating from z0
        # and y1. Therefore, we re-evaluate f0 outside of the base
        # solver to ensure the vf is correct.
        # Can we avoid this re-evaluation?

        f0 = solver.func(terms, t1, y1, args)
        # Invert the z-update: the forward pass computed z1 = z0 + y1 - step(y1).
        step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args)
        z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω

        f0 = solver.func(terms, t0, z0, args)
        # Invert the y-update, using the coupling parameter solver.l.
        step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args)

        y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω

        # vjp tuples are indexed by solver_step's positional args:
        # [0]=terms, [3]=y, [4]=args.
        grad_step_y1 = vjp_fun_y1(grad_z1)
        grad_y1 = (ω(grad_y1) + ω(grad_z1) - ω(grad_step_y1[3])).ω

        grad_step_z0 = vjp_fun_z0(grad_y1)
        grad_y0 = (solver.l * ω(grad_y1)).ω
        grad_z0 = (ω(grad_z1) - solver.l * ω(grad_y1) + ω(grad_step_z0[3])).ω

        grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω
        grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω

        return (
            ts_index,
            y0,
            ((first_step, f0), z0),
            grad_y0,
            grad_z0,
            grad_args,
            grad_terms,
        )

    def cond_fun(state):
        # Loop until we have stepped all the way back to the initial time.
        ts_index = state[0]
        return ts_index > 0

    state = (
        ts_final_index,
        y1,
        (original_solver_state, z1),
        grad_y1,
        grad_z1,
        grad_args,
        grad_terms,
    )

    state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
    _, _, _, grad_y0, grad_z0, grad_args, grad_terms = state
    # The initial condition enters the forward pass through both y0 and z0
    # (presumably z0 is initialised from y0 — TODO confirm against Reversible).
    return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms
|
||
|
||
class ReversibleAdjoint(AbstractAdjoint):
    """
    Backpropagate through [`diffrax.diffeqsolve`][] when using the
    [`diffrax.Reversible`][] solver.

    This method implies very low memory usage and exact gradient calculation (up to
    floating point errors).

    This will compute gradients with respect to the `terms`, `y0` and `args` arguments
    passed to [`diffrax.diffeqsolve`][]. If you attempt to compute gradients with
    respect to anything else (for example `t0`, or arguments passed via closure), then
    a `CustomVJPException` will be raised. See also
    [this FAQ](../../further_details/faq/#im-getting-a-customvjpexception)
    entry.
    """

    def loop(
        self,
        *,
        args,
        terms,
        solver,
        saveat,
        init_state,
        passed_solver_state,
        passed_controller_state,
        event,
        **kwargs,
    ):
        # `is` check because this may return a Tracer from SaveAt(ts=<array>)
        if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
            raise ValueError(
                "Can only use `adjoint=ReversibleAdjoint()` with "
                "`saveat=SaveAt(t1=True)`."
            )
        # NOTE(review): thread preserved from PR discussion — supporting
        # `SaveAt(ts=...)` as well would be a solid nice-to-have (not a merge
        # prerequisite). Idea: solve backwards to recover the y-values, then
        # (re)interpolate the same solution and pull cotangents through that
        # computation via autodiff; code-wise fiddly but useful.

        if not isinstance(solver, Reversible):
            raise ValueError(
                "Can only use `adjoint=ReversibleAdjoint()` with "
                "`Reversible()` solver."
            )
        # NOTE(review): thread preserved from PR discussion — a suggestion to
        # remove an explicit wrapper requirement was adopted; placement of the
        # moved definition debated for consistency (exact identifiers were lost
        # in extraction).

        # Swap the initial condition out of the state (replaced by a sentinel)
        # so it is only differentiated via the custom vjp's explicit argument.
        y = init_state.y
        init_state = eqx.tree_at(lambda s: s.y, init_state, object())
        init_state = _nondiff_solver_controller_state(
            self, init_state, passed_solver_state, passed_controller_state
        )

        final_state, aux_stats = _loop_reversible(
            (y, args, terms),
            self=self,
            saveat=saveat,
            init_state=init_state,
            solver=solver,
            event=event,
            **kwargs,
        )
        final_state = _only_transpose_ys(final_state)
        return final_state, aux_stats
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -23,7 +23,7 @@ | |||
import optimistix as optx | ||||
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real | ||||
|
||||
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint | ||||
from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint, ReversibleAdjoint | ||||
from ._custom_types import ( | ||||
BoolScalarLike, | ||||
BufferDenseInfos, | ||||
|
@@ -110,6 +110,11 @@ class State(eqx.Module): | |||
event_dense_info: Optional[DenseInfo] | ||||
event_values: Optional[PyTree[Union[BoolScalarLike, RealScalarLike]]] | ||||
event_mask: Optional[PyTree[BoolScalarLike]] | ||||
# | ||||
# Information for reversible adjoint (save ts) | ||||
# | ||||
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]] | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. By the way, I think this should be registered as a buffer in Line 627 in 467d95f
As for what buffers are, see also the discussion on buffers in one of Andraž's earlier PRs, plus the docs for the `buffers` argument of `eqxi.while_loop`.
||||
reversible_save_index: Optional[IntScalarLike] | ||||
|
||||
|
||||
def _is_none(x: Any) -> bool: | ||||
|
@@ -293,6 +298,11 @@ def loop( | |||
dense_ts = dense_ts.at[0].set(t0) | ||||
init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) | ||||
|
||||
if init_state.reversible_ts is not None: | ||||
reversible_ts = init_state.reversible_ts | ||||
reversible_ts = reversible_ts.at[0].set(t0) | ||||
init_state = eqx.tree_at(lambda s: s.reversible_ts, init_state, reversible_ts) | ||||
|
||||
def save_t0(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: | ||||
if subsaveat.t0: | ||||
save_state = _save(t0, init_state.y, args, subsaveat.fn, save_state) | ||||
|
@@ -574,6 +584,15 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): | |||
result, | ||||
) | ||||
|
||||
reversible_ts = state.reversible_ts | ||||
reversible_save_index = state.reversible_save_index | ||||
|
||||
if state.reversible_ts is not None: | ||||
reversible_ts = maybe_inplace( | ||||
reversible_save_index + 1, tprev, reversible_ts | ||||
) | ||||
reversible_save_index = reversible_save_index + jnp.where(keep_step, 1, 0) | ||||
|
||||
new_state = State( | ||||
y=y, | ||||
tprev=tprev, | ||||
|
@@ -595,6 +614,8 @@ def _outer_cond_fn(cond_fn_i, old_event_value_i): | |||
event_dense_info=event_dense_info, | ||||
event_values=event_values, | ||||
event_mask=event_mask, | ||||
reversible_ts=reversible_ts, # pyright: ignore[reportArgumentType] | ||||
reversible_save_index=reversible_save_index, | ||||
) | ||||
|
||||
return ( | ||||
|
@@ -1320,6 +1341,18 @@ def _outer_cond_fn(cond_fn_i): | |||
) | ||||
del had_event, event_structure, event_mask_leaves, event_values__mask | ||||
|
||||
# Reversible info | ||||
if isinstance(adjoint, ReversibleAdjoint): | ||||
if max_steps is None: | ||||
raise ValueError( | ||||
"`max_steps=None` is incompatible with `ReversibleAdjoint`" | ||||
) | ||||
reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype) | ||||
reversible_save_index = 0 | ||||
else: | ||||
reversible_ts = None | ||||
reversible_save_index = None | ||||
Comment on lines
+1344
to
+1354
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think what might be simpler here is to have: if max_steps is None:
reversible_ts = None
reversible_save_index = None
else:
reversible_ts = jnp.full(...)
reversible_save_index = 0 so that we are always saving this information if possible. The benefit of this is that in principle someone else could write their own Finally the This shouldn't really impose any performance penalty (a very small compile-time one only) because for any other adjoint method it will just be DCE'd. |
||||
|
||||
# Initialise state | ||||
init_state = State( | ||||
y=y0, | ||||
|
@@ -1342,6 +1375,8 @@ def _outer_cond_fn(cond_fn_i): | |||
event_dense_info=event_dense_info, | ||||
event_values=event_values, | ||||
event_mask=event_mask, | ||||
reversible_ts=reversible_ts, | ||||
reversible_save_index=reversible_save_index, | ||||
) | ||||
|
||||
# | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, I'm trying to figure out how `solver_state` is handled here. I don't think it is correct? `solver_state` is some completely arbitrary information that is propagated forward step-by-step, internal to the solver. In particular we don't have an API for reconstructing this backwards in time reversibly.