From e04b41aa364b49c31dddbda96fdcc86c612daa14 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 13 May 2023 17:27:10 -0700 Subject: [PATCH 1/7] General RK tidy-ups --- diffrax/solver/runge_kutta.py | 33 ++++++++++++++++----------------- setup.py | 2 +- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index 02be7ea7..859edd23 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -1,4 +1,3 @@ -import abc from dataclasses import dataclass, field from typing import Optional, Tuple @@ -53,6 +52,10 @@ class ButcherTableau: ssal: bool = field(init=False) fsal: bool = field(init=False) + # Informational + implicit: bool = field(init=False) + num_stages: int = field(init=False) + def __post_init__(self): assert self.c.ndim == 1 for a_i in self.a_lower: @@ -101,6 +104,9 @@ def __post_init__(self): lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage, ) + object.__setattr__(self, "implicit", implicit) + object.__setattr__(self, "num_stages", len(self.b_sol)) + ButcherTableau.__init__.__doc__ = """**Arguments:** @@ -123,7 +129,7 @@ def __post_init__(self): - `a_predictor`: optional. Used in a similar way to `a_lower`; specifies the linear combination of previous stages to use as a predictor for the solution to the implicit problem at that stage. See - [the developer documentation](../../devdocs/predictor_dirk). U#sed for diagonal + [the developer documentation](../../devdocs/predictor_dirk). Used for diagonal implicit Runge--Kutta methods only. Whether the solver exhibits either the FSAL or SSAL properties is determined @@ -197,15 +203,8 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): term_structure = AbstractTerm - @property - @abc.abstractmethod - def tableau(self) -> ButcherTableau: - pass - - @property - @abc.abstractmethod - def calculate_jacobian(self) -> CalculateJacobian: - pass + tableau: eqxi.AbstractClassVar[ButcherTableau] + calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] def _first(self, terms, t0, t1, y0, args): vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) @@ -411,11 +410,11 @@ def step( num_stages = len(self.tableau.c) + 1 if use_fs: - fs = jtu.tree_map(lambda f: jnp.empty((num_stages,) + f.shape), f0_struct) + fs = jtu.tree_map(lambda f: jnp.zeros((num_stages,) + f.shape), f0_struct) ks = None else: fs = None - ks = jtu.tree_map(lambda k: jnp.empty((num_stages,) + jnp.shape(k)), y0) + ks = jtu.tree_map(lambda k: jnp.zeros((num_stages,) + jnp.shape(k)), y0) # # First stage. Defines `result`, `scan_first_stage`. Places `f0` and `k0` into @@ -556,9 +555,9 @@ def eval_stage(_carry, _input): assert _a_diagonal_i is not None # Predictor for where to start iterating from if _return_fi: - _f_pred = _vector_tree_dot(_a_predictor_i, fs, _i) # noqa: F821 + _f_pred = _vector_tree_dot(_a_predictor_i, _fs, _i) # noqa: F821 else: - _k_pred = _vector_tree_dot(_a_predictor_i, ks, _i) # noqa: F821 + _k_pred = _vector_tree_dot(_a_predictor_i, _ks, _i) # noqa: F821 # Determine Jacobian to use at this stage if self.calculate_jacobian == CalculateJacobian.every_stage: if _return_fi: @@ -827,7 +826,7 @@ class AbstractSDIRK(AbstractDIRK): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if cls.tableau is not None: # Abstract subclasses may not have a tableau. + if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau. diagonal = cls.tableau.a_diagonal[0] assert (cls.tableau.a_diagonal == diagonal).all() @@ -844,7 +843,7 @@ class AbstractESDIRK(AbstractDIRK): def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - if cls.tableau is not None: # Abstract subclasses may not have a tableau. + if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau. assert cls.tableau.a_diagonal[0] == 0 diagonal = cls.tableau.a_diagonal[1] assert (cls.tableau.a_diagonal[1:] == diagonal).all() diff --git a/setup.py b/setup.py index 2c8bad1f..b51839f6 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.8" -install_requires = ["jax>=0.4.3", "equinox>=0.10.0"] +install_requires = ["jax>=0.4.3", "equinox>=0.10.4"] setuptools.setup( name=name, From 3829d4dadeac2f517c66f858c6dc32e307d63563 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 14 May 2023 11:36:49 -0700 Subject: [PATCH 2/7] Initial step size now uses a scan trick to reduce compilation time --- diffrax/step_size_controller/adaptive.py | 37 +++++++++++++++--------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 5297050d..f794cfb6 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -21,33 +21,42 @@ def _select_initial_step( t0: Scalar, y0: PyTree, args: PyTree, - func: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[PyTree[AbstractTerm], Scalar, PyTree, PyTree], PyTree], error_order: Scalar, rtol: Scalar, atol: Scalar, norm: Callable[[PyTree], Scalar], ) -> Scalar: - f0 = func(terms, t0, y0, args) - scale = (atol + ω(y0).call(jnp.abs) * rtol).ω - d0 = norm((y0**ω / scale**ω).ω) - d1 = norm((f0**ω / scale**ω).ω) - - _cond = (d0 < 1e-5) | (d1 < 1e-5) - _d1 = jnp.where(_cond, 1, d1) - h0 = jnp.where(_cond, 1e-6, 0.01 * (d0 / _d1)) + def fn(carry): + t, y, _h0, _d1, _f, _ = carry + f = func(terms, t, y, args) + return t, y, _h0, _d1, _f, f + + def intermediate(carry): + _, _, _, _, _, f0 = carry + d0 = norm((y0**ω / scale**ω).ω) + d1 = norm((f0**ω / scale**ω).ω) + _cond = (d0 < 1e-5) | (d1 < 1e-5) + _d1 = jnp.where(_cond, 1, d1) + h0 = jnp.where(_cond, 1e-6, 0.01 * (d0 / _d1)) + t1 = t0 + h0 + y1 = (y0**ω + h0 * f0**ω).ω + return t1, y1, h0, d1, f0, f0 - t1 = t0 + h0 - y1 = (y0**ω + h0 * f0**ω).ω - f1 = func(terms, t1, y1, args) + scale = (atol + ω(y0).call(jnp.abs) * rtol).ω + dummy_h = t0 + dummy_d = eqxi.eval_empty(norm, y0) + dummy_f = eqxi.eval_empty(lambda: func(terms, t0, y0, args)) + _, _, h0, d1, f0, f1 = eqxi.scan_trick( + fn, [intermediate], (t0, y0, dummy_h, dummy_d, dummy_f, dummy_f) + ) d2 = norm(((f1**ω - f0**ω) / scale**ω).ω) / h0 - max_d = jnp.maximum(d1, d2) h1 = jnp.where( max_d <= 1e-15, jnp.maximum(1e-6, h0 * 1e-3), (0.01 / max_d) ** (1 / error_order), ) - return jnp.minimum(100 * h0, h1) From 6310880eb6ece8693577d18cc45773478b32e6c0 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 13 May 2023 19:35:26 -0700 Subject: [PATCH 3/7] Removed scan_stages in favour of unconditionally using eqxi.scan --- benchmarks/scan_stages.py | 69 -------- benchmarks/scan_stages_cnf.py | 96 ----------- diffrax/adjoint.py | 8 +- diffrax/solver/runge_kutta.py | 284 +++++++++++--------------------- docs/api/solvers/ode_solvers.md | 4 +- docs/api/solvers/sde_solvers.md | 2 - docs/further_details/faq.md | 3 +- test/helpers.py | 30 ++-- 8 files changed, 118 insertions(+), 378 deletions(-) delete mode 100644 benchmarks/scan_stages.py delete mode 100644 benchmarks/scan_stages_cnf.py diff --git a/benchmarks/scan_stages.py b/benchmarks/scan_stages.py deleted file mode 100644 index a1f443c0..00000000 --- a/benchmarks/scan_stages.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`. - -On my relatively beefy CPU-only machine: -``` -scan_stages=True -Compile+run time 1.8253102810122073 -Run time 0.00017526978626847267 - -scan_stages=False -Compile+run time 10.679616351146251 -Run time 0.00021236995235085487 -``` -""" - -import functools as ft -import timeit - -import diffrax as dfx -import equinox as eqx -import jax.numpy as jnp -import jax.random as jr - - -def _weight(in_, out, key): - return [[w_ij for w_ij in w_i] for w_i in jr.normal(key, (out, in_))] - - -class VectorField(eqx.Module): - weights: list - - def __init__(self, in_, out, width, depth, *, key): - keys = jr.split(key, depth + 1) - self.weights = [_weight(in_, width, keys[0])] - for i in range(1, depth): - self.weights.append(_weight(width, width, keys[i])) - self.weights.append(_weight(width, out, keys[depth])) - - def __call__(self, t, y, args): - # Inefficient computation graph to make a toy example more expensive. - y = [y_i for y_i in y] - for w in self.weights: - y = [sum(w_ij * y_j for w_ij, y_j in zip(w_i, y)) for w_i in w] - return jnp.stack(y) - - -def run(scan_stages): - vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0)) - term = dfx.ODETerm(vf) - solver = dfx.Dopri8(scan_stages=scan_stages) - stepsize_controller = dfx.PIDController(rtol=1e-3, atol=1e-6) - t0 = 0 - t1 = 1 - dt0 = None - - @eqx.filter_jit - def solve(y0): - return dfx.diffeqsolve( - term, solver, t0, t1, dt0, y0, stepsize_controller=stepsize_controller - ) - - solve_ = ft.partial(solve, jnp.array([1.0])) - print(f"scan_stages={scan_stages}") - print("Compile+run time", timeit.timeit(solve_, number=1)) - print("Run time", timeit.timeit(solve_, number=1)) - - -run(scan_stages=True) -print() -run(scan_stages=False) diff --git a/benchmarks/scan_stages_cnf.py b/benchmarks/scan_stages_cnf.py deleted file mode 100644 index 3b8bbfa9..00000000 --- a/benchmarks/scan_stages_cnf.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`. - -On my CPU-only machine: -``` -bash> python scan_stages_cnf.py --scan_stages=False --backsolve=False -Compile+run time 79.18114789901301 -Run time 0.16631506383419037 - -bash> python scan_stages_cnf.py --scan_stages=False --backsolve=True -Compile+run time 28.233896102989092 -Run time 0.021237157052382827 - -bash> python scan_stages_cnf.py --scan_stages=True --backsolve=False -Compile+run time 37.9795492868870 -Run time 0.16300765215419233 - -bash> python scan_stages_cnf.py --scan_stages=True --backsolve=True -Compile+run time 12.199542510090396 -Run time 0.024600893026217818 -``` - -(Not forgetting that --backsolve=True produces only approximate gradients, so the fact -that it obtains better compile time and run time doesn't mean it's always the best -choice.) -""" - -# This benchmark is adapted from -# https://github.com/patrick-kidger/diffrax/issues/94#issuecomment-1140527134 - -import functools as ft -import timeit - -import diffrax -import equinox as eqx -import jax -import jax.nn as jnn -import jax.numpy as jnp -import jax.random as jr -import jax.scipy as jsp - - -def vector_field_prob(t, input, model): - y, _ = input - f, vjp_fn = jax.vjp(model, y) - (size,) = y.shape - eye = jnp.eye(size) - (dfdy,) = jax.vmap(vjp_fn)(eye) - logp = jnp.trace(dfdy) - return f, logp - - -@eqx.filter_vmap(in_axes=(None, 0, None, None)) -def log_prob(model, y0, scan_stages, backsolve): - term = diffrax.ODETerm(vector_field_prob) - solver = diffrax.Dopri5(scan_stages=scan_stages) - stepsize_controller = diffrax.PIDController(rtol=1.4e-8, atol=1.4e-8) - if backsolve: - adjoint = diffrax.BacksolveAdjoint() - else: - adjoint = diffrax.RecursiveCheckpointAdjoint() - sol = diffrax.diffeqsolve( - term, - solver, - t0=0.0, - t1=0.5, - dt0=0.05, - y0=(y0, 0.0), - args=model, - stepsize_controller=stepsize_controller, - adjoint=adjoint, - ) - (y1,), (log_prob,) = sol.ys - return log_prob + jsp.stats.norm.logpdf(y1).sum(0) - - -@eqx.filter_jit -@eqx.filter_grad -def solve(model, inputs, scan_stages, backsolve): - return -log_prob(model, inputs, scan_stages, backsolve).mean() - - -def run(scan_stages, backsolve): - mkey, dkey = jr.split(jr.PRNGKey(0), 2) - model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey) - x = jr.normal(dkey, (256, 2)) - solve2 = ft.partial(solve, model, x, scan_stages, backsolve) - print(f"scan_stages={scan_stages}, backsolve={backsolve}") - print("Compile+run time", timeit.timeit(solve2, number=1)) - print("Run time", timeit.timeit(solve2, number=1)) - print() - - -run(scan_stages=False, backsolve=False) -run(scan_stages=False, backsolve=True) -run(scan_stages=True, backsolve=False) -run(scan_stages=True, backsolve=True) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index b0000454..c7f75db4 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -14,7 +14,7 @@ from .ad import implicit_jvp from .heuristics import is_sde, is_unsafe_sde from .saveat import save_y, SaveAt, SubSaveAt -from .solver import AbstractItoSolver, AbstractStratonovichSolver +from .solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver from .term import AbstractTerm, AdjointTerm @@ -332,6 +332,7 @@ class DirectAdjoint(AbstractAdjoint): def loop( self, *, + solver, max_steps, terms, throw, @@ -362,10 +363,15 @@ def loop( else: kind = "bounded" msg = None + # Support forward-mode autodiff. + # TODO: remove this hack once we can JVP through custom_vjps. + if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: + solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax") inner_while_loop = ft.partial(_inner_loop, kind=kind) outer_while_loop = ft.partial(_outer_loop, kind=kind) final_state = self._loop( **kwargs, + solver=solver, max_steps=max_steps, terms=terms, inner_while_loop=inner_while_loop, diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index 859edd23..5d058cfa 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional, Tuple +from typing import Literal, Optional, Tuple, Union import equinox as eqx import equinox.internal as eqxi @@ -16,22 +16,6 @@ from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot -def _scan(*sequences): - for x in sequences: - if x is not None: - length = len(x) - break - else: - raise ValueError("Must have at least one non-None iterable") - - def _check(_x): - assert len(_x) == length - return _x - - sequences = [[None] * length if x is None else _check(x) for x in sequences] - return zip(*sequences) - - # Entries must be np.arrays, and not jnp.arrays, so that we can index into them during # trace time. @dataclass(frozen=True) @@ -199,52 +183,22 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): instance of [`diffrax.CalculateJacobian`][]. """ - scan_stages: bool = False + scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None term_structure = AbstractTerm tableau: eqxi.AbstractClassVar[ButcherTableau] calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] - def _first(self, terms, t0, t1, y0, args): + def _common(self, terms, t0, t1, y0, args): vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) - implicit_first_stage = ( - self.tableau.a_diagonal is not None and self.tableau.a_diagonal[0] != 0 - ) - # The gamut of conditions under which we need to evaluate `f0` or `k0`. - # - # If we're computing the Jacobian at the start of the step, then we - # need this as a linearisation point. + # If the vector field is expensive then we want to use vf_prods instead. + # FSAL implies evaluating just the vector field, since we need to contract + # the same vector field evaluation against two different controls. # - # If the first stage is implicit, then we need this as a predictor for - # where to start iterating from. - # - # If we're not scanning stages then we're definitely not deferring this - # evaluation to the scan loop, so get it done now. - need_f0_or_k0 = ( - self.calculate_jacobian == CalculateJacobian.every_step - or implicit_first_stage - or not self.scan_stages - ) - fsal = self.tableau.fsal - if fsal and vf_expensive: - # If the vector field is expensive then we want to use vf_prods instead. - # FSAL implies evaluating just the vector field, since we need to contract - # the same vector field evaluation against two different controls. - # - # But "evaluating just the vector field" is, as just established, expensive. - fsal = False - if fsal and self.scan_stages and not need_f0_or_k0: - # If we're scanning stages then we'd like to disable FSAL. - # FSAL implies evaluating the vector field in `init` as well as in `step`. - # But `scan_stages` is a please-compile-faster flag, so we should avoid the - # extra tracing. - # - # However we disable-the-disabling if `need_f0_or_k0`, since in this case - # we evaluate `f0` or `k0` anyway, so it wouldn't help. So we might as well - # take advantage of the runtime benefits of FSAL. - fsal = False - return vf_expensive, implicit_first_stage, need_f0_or_k0, fsal + # But "evaluating just the vector field" is, as just established, expensive. + fsal = self.tableau.fsal and not vf_expensive + return vf_expensive, fsal def func( self, @@ -263,7 +217,7 @@ def init( y0: PyTree, args: PyTree, ) -> _SolverState: - _, _, _, fsal = self._first(terms, t0, t1, y0, args) + _, fsal = self._common(terms, t0, t1, y0, args) if fsal: return terms.vf(t0, y0, args) else: @@ -316,20 +270,24 @@ def step( # e.g. we need `ks` to perform dense interpolation if needed. # - _implicit_later_stages = self.tableau.a_diagonal is not None and any( - self.tableau.a_diagonal[1:] != 0 - ) - _vf_expensive, implicit_first_stage, need_f0_or_k0, fsal = self._first( - terms, t0, t1, y0, args + implicit_first_stage = self.tableau.implicit and self.tableau.a_diagonal[0] != 0 + # If we're computing the Jacobian at the start of the step, then we + # need this as a linearisation point. + # + # If the first stage is implicit, then we need this as a predictor for + # where to start iterating from. + need_f0_or_k0 = ( + self.calculate_jacobian == CalculateJacobian.every_step + or implicit_first_stage ) - ssal = self.tableau.ssal - if _implicit_later_stages and fsal: + vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + if self.tableau.implicit and fsal: use_fs = True - elif _vf_expensive: + elif vf_expensive: use_fs = False else: # Choice not as important here; we use ks for minor efficiency reasons. use_fs = False - del _vf_expensive, _implicit_later_stages + del vf_expensive control = terms.contr(t0, t1) dt = t1 - t0 @@ -399,7 +357,7 @@ def step( # Allocate `fs` or `ks` as a place to store the stage evaluations. # - if use_fs or (fsal and self.scan_stages): + if use_fs or fsal: if f0 is None: # Only perform this trace if we have to; tracing can actually be # a bit expensive. @@ -430,11 +388,11 @@ def step( scan_first_stage = False assert self.tableau.a_diagonal is not None diagonal0 = self.tableau.a_diagonal[0] - if self.tableau.diagonal[0] == 1: + if self.tableau.a_diagonal[0] == 1: # No floating point error t0_ = t1 else: - t0_ = t0 + self.tableau.diagonal[0] * dt + t0_ = t0 + self.tableau.a_diagonal[0] * dt if use_fs: if y0 is not None: assert jac_f is not None @@ -459,7 +417,7 @@ def step( result = nonlinear_sol.result del diagonal0, t0_, nonlinear_sol else: - scan_first_stage = self.scan_stages + scan_first_stage = True result = RESULTS.successful if scan_first_stage: @@ -482,82 +440,49 @@ def step( # `scan_first_stage`. # - if self.scan_stages: - - def _vector_tree_dot(_x, _y, _i): - del _i - return vector_tree_dot(_x, _y) - - else: - - def _vector_tree_dot(_x, _y, _i): - return vector_tree_dot(_x, ω(_y)[:_i].ω) - def eval_stage(_carry, _input): _, _, _fs, _ks, _result = _carry _i, _a_lower_i, _a_diagonal_i, _a_predictor_i, _c_i = _input + # Unwrap buffers. Take advantage of the fact that they're initialised at + # zero, so that we don't actually read from a location before its written to + _unsafe_fs_unwrapped = jtu.tree_map(lambda _, x: x[...], fs, _fs) + _unsafe_ks_unwrapped = jtu.tree_map(lambda _, x: x[...], ks, _ks) # # Evaluate the linear combination of previous stages # if use_fs: - _increment = _vector_tree_dot(_a_lower_i, _fs, _i) # noqa: F821 + _increment = vector_tree_dot(_a_lower_i, _unsafe_fs_unwrapped) _increment = terms.prod(_increment, control) else: - _increment = _vector_tree_dot(_a_lower_i, _ks, _i) # noqa: F821 + _increment = vector_tree_dot(_a_lower_i, _unsafe_ks_unwrapped) _yi_partial = (y0**ω + _increment**ω).ω - # - # Is this an implicit or explicit stage? - # - - if self.tableau.a_diagonal is None: - _implicit_stage = False - else: - if self.scan_stages: - if scan_first_stage: # noqa: F821 - _diagonal = self.tableau.a_diagonal - else: - _diagonal = self.tableau.a_diagonal[1:] - _implicit_stage = any(_diagonal != 0) - if _implicit_stage and any(_diagonal == 0): - assert False, ( - "Cannot have a mix of implicit and " - "explicit stages when scanning" - ) - del _diagonal - else: - _implicit_stage = _a_diagonal_i != 0 - # # Figure out if we're computing a vector field ("f") or a # vector-field-product ("k") # # Ask for fi if we're using fs; ask for ki if we're using ks. Makes sense! - # In addition, ask for fi if we're on the last stage and are using - # an FSAL scheme, as we'll be passing that on to the next step. If - # we're scanning the stages then every stage uses the same logic so - # override the last iteration check. + # In addition, ask for fi if we're using an FSAL scheme, as we'll be passing + # that on to the next step. # - _last_iteration = _i == num_stages - 1 - _return_fi = use_fs or (fsal and (self.scan_stages or _last_iteration)) + _return_fi = use_fs or fsal _return_ki = not use_fs - del _last_iteration # # Evaluate the stage # _ti = jnp.where(_c_i == 1, t1, t0 + _c_i * dt) # No floating point error - if _implicit_stage: + if self.tableau.implicit: assert _a_diagonal_i is not None # Predictor for where to start iterating from if _return_fi: - _f_pred = _vector_tree_dot(_a_predictor_i, _fs, _i) # noqa: F821 + _f_pred = vector_tree_dot(_a_predictor_i, _unsafe_fs_unwrapped) else: - _k_pred = _vector_tree_dot(_a_predictor_i, _ks, _i) # noqa: F821 + _k_pred = vector_tree_dot(_a_predictor_i, _unsafe_ks_unwrapped) # Determine Jacobian to use at this stage if self.calculate_jacobian == CalculateJacobian.every_stage: if _return_fi: @@ -660,10 +585,10 @@ def eval_stage(_carry, _input): # if use_fs: - _fs = ω(_fs).at[_i].set(ω(_fi)).ω + _fs = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _fi, _fs) else: - _ks = ω(_ks).at[_i].set(ω(_ki)).ω - if ssal: + _ks = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _ki, _ks) + if self.tableau.ssal: _yi_partial_out = _yi_partial else: _yi_partial_out = None @@ -673,83 +598,72 @@ def eval_stage(_carry, _input): _fi_out = None return (_yi_partial_out, _fi_out, _fs, _ks, _result), None - if self.scan_stages: - if scan_first_stage: - tableau_a_lower = np.zeros((num_stages, num_stages)) - for i, a_lower_i in enumerate(self.tableau.a_lower): - tableau_a_lower[i + 1, : i + 1] = a_lower_i - tableau_a_diagonal = self.tableau.a_diagonal - tableau_a_predictor = self.tableau.a_predictor - tableau_c = np.zeros(num_stages) - tableau_c[1:] = self.tableau.c - i_init = 0 - assert tableau_a_diagonal is None - assert tableau_a_predictor is None - else: - tableau_a_lower = np.zeros((num_stages - 1, num_stages)) - for i, a_lower_i in enumerate(self.tableau.a_lower): - tableau_a_lower[i, : i + 1] = a_lower_i - if self.tableau.a_diagonal is None: - tableau_a_diagonal = None - else: - tableau_a_diagonal = self.tableau.a_diagonal[1:] - if self.tableau.a_predictor is None: - tableau_a_predictor = None - else: - tableau_a_predictor = np.zeros((num_stages - 1, num_stages)) - for i, a_predictor_i in enumerate(self.tableau.a_predictor): - tableau_a_predictor[i, : i + 1] = a_predictor_i - tableau_c = self.tableau.c - i_init = 1 - if ssal: - y_dummy = y0 + # + # Iterate over stages + # + + if scan_first_stage: + tableau_a_lower = np.zeros((num_stages, num_stages)) + for i, a_lower_i in enumerate(self.tableau.a_lower): + tableau_a_lower[i + 1, : i + 1] = a_lower_i + tableau_a_diagonal = self.tableau.a_diagonal + tableau_a_predictor = self.tableau.a_predictor + tableau_c = np.zeros(num_stages) + tableau_c[1:] = self.tableau.c + i_init = 0 + assert tableau_a_diagonal is None + assert tableau_a_predictor is None + else: + tableau_a_lower = np.zeros((num_stages - 1, num_stages)) + for i, a_lower_i in enumerate(self.tableau.a_lower): + tableau_a_lower[i, : i + 1] = a_lower_i + if self.tableau.a_diagonal is None: + tableau_a_diagonal = None else: - y_dummy = None - if fsal: - f_dummy = jtu.tree_map( - lambda x: jnp.zeros(x.shape, dtype=x.dtype), f0_struct - ) + tableau_a_diagonal = self.tableau.a_diagonal[1:] + if self.tableau.a_predictor is None: + tableau_a_predictor = None else: - f_dummy = None - (y1_partial, f1, fs, ks, result), _ = lax.scan( - eval_stage, - (y_dummy, f_dummy, fs, ks, result), - ( - np.arange(i_init, num_stages), - tableau_a_lower, - tableau_a_diagonal, - tableau_a_predictor, - tableau_c, - ), + tableau_a_predictor = np.zeros((num_stages - 1, num_stages)) + for i, a_predictor_i in enumerate(self.tableau.a_predictor): + tableau_a_predictor[i, : i + 1] = a_predictor_i + tableau_c = self.tableau.c + i_init = 1 + if self.tableau.ssal: + y_dummy = y0 + else: + y_dummy = None + if fsal: + f_dummy = jtu.tree_map( + lambda x: jnp.zeros(x.shape, dtype=x.dtype), f0_struct ) - del y_dummy, f_dummy else: - assert not scan_first_stage - if self.tableau.a_diagonal is None: - a_diagonal = None - else: - a_diagonal = self.tableau.a_diagonal[1:] - for i, a_lower_i, a_diagonal_i, a_predictor_i, c_i in _scan( - range(1, num_stages), - self.tableau.a_lower, - a_diagonal, - self.tableau.a_predictor, - self.tableau.c, - ): - (yi_partial, fi, fs, ks, result), _ = eval_stage( - (None, None, fs, ks, result), - (i, a_lower_i, a_diagonal_i, a_predictor_i, c_i), - ) - y1_partial = yi_partial - f1 = fi - del a_diagonal, yi_partial, fi - del scan_first_stage, _vector_tree_dot + f_dummy = None + if self.scan_kind is None: + scan_kind = "checkpointed" + else: + scan_kind = self.scan_kind + (y1_partial, f1, fs, ks, result), _ = eqxi.scan( + eval_stage, + (y_dummy, f_dummy, fs, ks, result), + ( + np.arange(i_init, num_stages), + tableau_a_lower, + tableau_a_diagonal, + tableau_a_predictor, + tableau_c, + ), + buffers=lambda x: (x[2], x[3]), # fs and ks + kind=scan_kind, + checkpoints="all", + ) + del y_dummy, f_dummy, scan_first_stage # # Compute step output # - if ssal: + if self.tableau.ssal: y1 = y1_partial else: if use_fs: diff --git a/docs/api/solvers/ode_solvers.md b/docs/api/solvers/ode_solvers.md index 04087a4c..24c8e731 100644 --- a/docs/api/solvers/ode_solvers.md +++ b/docs/api/solvers/ode_solvers.md @@ -14,8 +14,6 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#ordinary These methods are suitable for most problems. -Each of these takes a `scan_stages` argument at initialisation, defaulting to `False`. Set to `True` to substantially improve compilation speed in return for a slight reduction in runtime speed. - ::: diffrax.Euler selection: members: false @@ -54,7 +52,7 @@ Each of these takes a `scan_stages` argument at initialisation, defaulting to `F These methods are suitable for stiff problems. -Each of these takes a `scan_stages` argument at initialisation, which [behaves the same as for the explicit Runge--Kutta methods](#explicit-runge-kutta-erk-methods). In addition, each of these takes a `nonlinear_solver` argument at initialisation, defaulting to a Newton solver, which is used to solve the implicit problem at each step. See the page on [nonlinear solvers](../nonlinear_solver.md). +Each of these takes a `nonlinear_solver` argument at initialisation, defaulting to a Newton solver, which is used to solve the implicit problem at each step. See the page on [nonlinear solvers](../nonlinear_solver.md). ::: diffrax.ImplicitEuler selection: diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index 1a3db677..849a6dec 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -22,8 +22,6 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast ### Explicit Runge--Kutta (ERK) methods -Each of these takes a `scan_stages` argument at initialisation, which [behaves the same as as the explicit Runge--Kutta methods for ODEs](./ode_solvers.md#explicit-runge-kutta-erk-methods). - ::: diffrax.Euler selection: members: false diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md index 146502dd..fe30a402 100644 --- a/docs/further_details/faq.md +++ b/docs/further_details/faq.md @@ -2,7 +2,6 @@ ### Compilation is taking a long time. -- Use `scan_stages=True`, e.g. `Tsit5(scan_stages=True)`. This is supported for all Runge--Kutta methods. This will substantially reduce compile time at the expense of a slightly slower run time. - Set `dt0=`, e.g. `diffeqsolve(..., dt0=0.01)`. In contrast `dt0=None` will determine the initial step size automatically, but will increase compilation time. - Prefer `SaveAt(t0=True, t1=True)` over `SaveAt(ts=[t0, t1])`, if possible. - It's an internal (subject-to-change) API, but you can also try adding `equinox.internal.noinline` to your vector field (s), e.g. `ODETerm(noinline(...))`. This stages the vector field out into a separate compilation graph. This can greatly decrease compilation time whilst greatly increasing runtime. @@ -18,7 +17,7 @@ The equivalent solver in Diffrax is: diffeqsolve( ..., dt0=None, - solver=Dopri5(scan_stages=True), + solver=Dopri5(), stepsize_controller=PIDController(rtol=1.4e-8, atol=1.4e-8), adjoint=BacksolveAdjoint(), max_steps=None, diff --git a/test/helpers.py b/test/helpers.py index 265ac94b..4a5fa749 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -9,30 +9,20 @@ all_ode_solvers = ( - diffrax.Bosh3(scan_stages=False), - diffrax.Bosh3(scan_stages=True), - diffrax.Dopri5(scan_stages=False), - diffrax.Dopri5(scan_stages=True), - diffrax.Dopri8(scan_stages=False), - diffrax.Dopri8(scan_stages=True), + diffrax.Bosh3(), + diffrax.Dopri5(), + diffrax.Dopri8(), diffrax.Euler(), - diffrax.Ralston(scan_stages=False), - diffrax.Ralston(scan_stages=True), - diffrax.Midpoint(scan_stages=False), - diffrax.Midpoint(scan_stages=True), - diffrax.Heun(scan_stages=False), - diffrax.Heun(scan_stages=True), + diffrax.Ralston(), + diffrax.Midpoint(), + diffrax.Heun(), diffrax.LeapfrogMidpoint(), diffrax.ReversibleHeun(), - diffrax.Tsit5(scan_stages=False), - diffrax.Tsit5(scan_stages=True), + diffrax.Tsit5(), diffrax.ImplicitEuler(), - diffrax.Kvaerno3(scan_stages=False), - diffrax.Kvaerno3(scan_stages=True), - diffrax.Kvaerno4(scan_stages=False), - diffrax.Kvaerno4(scan_stages=True), - diffrax.Kvaerno5(scan_stages=False), - diffrax.Kvaerno5(scan_stages=True), + diffrax.Kvaerno3(), + diffrax.Kvaerno4(), + diffrax.Kvaerno5(), ) From 3b28771de96bcb082fb7f38237da0361c4b10d99 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 14 May 2023 19:14:13 -0700 Subject: [PATCH 4/7] Added new step --- .flake8 | 2 +- diffrax/solver/base.py | 15 +- diffrax/solver/runge_kutta.py | 313 ++++++++++++++++++++++++++++++---- test/test_solver.py | 52 ++++++ 4 files changed, 340 insertions(+), 42 deletions(-) diff --git a/.flake8 b/.flake8 index 76ba4226..eb6bc5d3 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,4 @@ [flake8] max-line-length = 88 -ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731 +ignore = W291,W503,W504,E121,E123,E126,E203,E402,E701,E702,E731,F722 per-file-ignores = __init__.py: F401 diff --git a/diffrax/solver/base.py b/diffrax/solver/base.py index 8fff2c24..849d22ff 100644 --- a/diffrax/solver/base.py +++ b/diffrax/solver/base.py @@ -2,6 +2,7 @@ from typing import Callable, Optional, Tuple, Type, TypeVar import equinox as eqx +import equinox.internal as eqxi import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -41,16 +42,10 @@ class AbstractSolver(eqx.Module, metaclass=_MetaAbstractSolver): structure of `terms` in `diffeqsolve(terms, ...)`. """ - @property - @abc.abstractmethod - def term_structure(self) -> PyTree[Type[AbstractTerm]]: - """What PyTree structure `terms` should have when used with this solver.""" - - # On the type: frequently just Type[AbstractLocalInterpolation] - @property - @abc.abstractmethod - def interpolation_cls(self) -> Callable[..., AbstractLocalInterpolation]: - """How to interpolate the solution in between steps.""" + # What PyTree structure `terms` should have when used with this solver. + term_structure: eqxi.AbstractClassVar[PyTree[Type[AbstractTerm]]] + # How to interpolate the solution in between steps. + interpolation_cls: eqxi.AbstractClassVar[Callable[..., AbstractLocalInterpolation]] def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]: """Order of the solver for solving ODEs.""" diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index 5d058cfa..81c84703 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Literal, Optional, Tuple, Union +from typing import Literal, Optional, Union import equinox as eqx import equinox.internal as eqxi @@ -9,15 +9,15 @@ import jax.tree_util as jtu import numpy as np from equinox.internal import ω +from jaxtyping import Array, Bool, PyTree, Scalar -from ..custom_types import Bool, DenseInfo, PyTree, Scalar +from ..custom_types import DenseInfo from ..solution import is_okay, RESULTS, update_result -from ..term import AbstractTerm +from ..term import AbstractTerm, ODETerm, WrapTerm from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot -# Entries must be np.arrays, and not jnp.arrays, so that we can index into them during -# trace time. +# Not a pytree node! @dataclass(frozen=True) class ButcherTableau: """The Butcher tableau for an explicit or diagonal Runge--Kutta method.""" @@ -26,17 +26,15 @@ class ButcherTableau: c: np.ndarray b_sol: np.ndarray b_error: np.ndarray - a_lower: Tuple[np.ndarray, ...] + a_lower: tuple[np.ndarray, ...] # Implicit RK methods a_diagonal: Optional[np.ndarray] = None - a_predictor: Optional[Tuple[np.ndarray, ...]] = None + a_predictor: Optional[tuple[np.ndarray, ...]] = None - # Determine the use of fast-paths + # Properties implied by the above tableaus, e.g. used to define fast-paths. ssal: bool = field(init=False) fsal: bool = field(init=False) - - # Informational implicit: bool = field(init=False) num_stages: int = field(init=False) @@ -58,11 +56,8 @@ def __post_init__(self): if self.a_diagonal is None: assert self.a_predictor is None - implicit = False else: assert self.a_predictor is not None - implicit = True - if implicit: assert self.a_diagonal.ndim == 1 assert self.c.shape[0] + 1 == self.a_diagonal.shape[0] assert len(self.a_lower) == len(self.a_predictor) @@ -87,8 +82,7 @@ def __post_init__(self): "fsal", lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage, ) - - object.__setattr__(self, "implicit", implicit) + object.__setattr__(self, "implicit", self.a_diagonal is not None) object.__setattr__(self, "num_stages", len(self.b_sol)) @@ -139,7 +133,7 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta): every_stage = "every_stage" -_SolverState = Optional[PyTree] +_SolverState = Optional[tuple[Bool[Scalar, ""], PyTree[Array]]] # TODO: examine termination criterion for Newton iteration @@ -169,6 +163,28 @@ def _implicit_relation_k(ki, nonlinear_solve_args): return diff +_unused = eqxi.str2jax("unused") # Sentinel that can be passed into `while_loop` etc. + + +def _is_term(x): + return isinstance(x, AbstractTerm) + + +# Not a pytree +class _Leaf: + def __init__(self, value): + self.value = value + + +def _sum(*x): + assert len(x) > 0 + # Not sure if the builtin does the right thing with JAX tracers? + total = x[0] + for xi in x[1:]: + total = total + xi + return total + + class AbstractRungeKutta(AbstractAdaptiveSolver): """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.) @@ -185,33 +201,71 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None - term_structure = AbstractTerm - - tableau: eqxi.AbstractClassVar[ButcherTableau] + tableau: eqxi.AbstractClassVar[PyTree[ButcherTableau]] calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + seen_implicit = False + num_stages = None + + def _f(t: ButcherTableau): + nonlocal seen_implicit + nonlocal num_stages + if num_stages is None: + num_stages = t.num_stages + if t.num_stages != num_stages: + raise ValueError("Tableaus must all have the same number of stages") + if t.implicit: + if seen_implicit: + raise ValueError("May have at most one implicit tableau") + else: + seen_implicit = True + return AbstractTerm + + if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau + term_structure = jtu.tree_map(_f, cls.tableau) + # Allow subclasses to specify more specific term structures if desired, e.g. + # (ODETerm, ControlTerm) rather than (AbstractTerm, AbtstractTerm). + try: + term_structure2 = cls.term_structure + except AttributeError: + cls.term_structure = term_structure + else: + x = jtu.tree_structure(term_structure, is_leaf=_is_term) + x2 = jtu.tree_structure(term_structure2, is_leaf=_is_term) + if x != x2: + raise ValueError("Mismatched term structures") + def _common(self, terms, t0, t1, y0, args): - vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) + # For simplicity we share `vf_expensive` and `fsal` across all tableaus. + # TODO: could we make these work per-tableau? + vf_expensive = False + fsal = True + terms = jtu.tree_leaves(terms, is_leaf=_is_term) + tableaus = jtu.tree_leaves(self.tableau) + assert len(terms) == len(tableaus) + for term, tableau in zip(terms, tableaus): + vf_expensive = vf_expensive or term.is_vf_expensive(t0, t1, y0, args) + fsal = fsal and tableau.fsal # If the vector field is expensive then we want to use vf_prods instead. # FSAL implies evaluating just the vector field, since we need to contract # the same vector field evaluation against two different controls. - # - # But "evaluating just the vector field" is, as just established, expensive. - fsal = self.tableau.fsal and not vf_expensive + fsal = fsal and not vf_expensive return vf_expensive, fsal def func( self, - terms: AbstractTerm, + terms: PyTree[AbstractTerm], t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.vf(t0, y0, args) + return jtu.tree_map(lambda t: t.vf(t0, y0, args), terms, is_leaf=_is_term) def init( self, - terms: AbstractTerm, + terms: PyTree[AbstractTerm], t0: Scalar, t1: Scalar, y0: PyTree, @@ -219,21 +273,218 @@ def init( ) -> _SolverState: _, fsal = self._common(terms, t0, t1, y0, args) if fsal: - return terms.vf(t0, y0, args) + first_step = jnp.array(True) + if (type(terms) is WrapTerm) and (type(terms.term) is ODETerm): + # Privileged optimisation for the common case + f0 = jtu.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), y0) + else: + # Must be initialiased at zero as it is inserted into `ks` which must be + # initialised at zero. + f0 = eqxi.eval_zero(lambda: self.func(terms, t0, y0, args)) + return first_step, f0 else: return None def step( self, - terms: AbstractTerm, + terms: PyTree[AbstractTerm], t0: Scalar, t1: Scalar, y0: PyTree, args: PyTree, solver_state: _SolverState, made_jump: Bool, - ) -> Tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: + ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: + # + # Some Runge--Kutta methods have special structure that we can use to improve + # efficiency. + # + # The famous one is FSAL; "first same as last". That is, the final evaluation + # of the vector field on the previous step is the same as the first evaluation + # on the subsequent step. We can reuse it and save an evaluation. + # However note that this requires saving a vf evaluation, not a + # vf-control-product. (This comes up when we have a different control on the + # next step, e.g. as with adaptive step sizes, or with SDEs.) + # As such we disable FSAL if a vf is expensive and a vf-control-product is + # cheap. (The canonical example is the optimise-then-discretise adjoint SDE. + # For this SDE, the vf-control product is a vector-Jacobian product, which is + # notably cheaper than evaluating a full Jacobian.) + # + # Next we have SSAL; "solution same as last". That is, the output of the step + # has already been calculated during the internal stage calculations. We can + # reuse those and save a dot product. + # + # Finally we have a choice whether to save and work with vector field + # evaluations (fs), or to save and work with (vector field)-control products + # (ks). + # The former is needed for implicit FSAL solvers: they need to obtain the + # final f1 for the FSAL property, which means they need to do the implicit + # solve in vf-space rather than (vf-control-product)-space, which means they + # need to use `fs` to predict the initial point for the root finding operation. + # Meanwhile the latter is needed when solving optimise-then-discretise adjoint + # SDEs, for which vector field evaluations are prohibitively expensive, and we + # must necessarily work only with the (much cheaper) vf-control-products. (In + # this case this is the difference between computing a Jacobian and computing a + # vector-Jacobian product.) + # For other problems, we choose to use `ks`. This doesn't have a strong + # rationale although it does have some minor efficiency points in its favour, + # e.g. we need `ks` to perform dense interpolation if needed. + # + + assert jtu.tree_structure(terms, is_leaf=_is_term) == jtu.tree_structure( + self.tableau + ) + + # Structure of `terms` and `self.tableau`. + def t_map(fn, *trees): + def _fn(_, *_trees): + return fn(*_trees) + + return jtu.tree_map(_fn, self.tableau, *trees) + + def t_leaves(tree): + return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))] + + # Structure of `y` and `k`. + # (but not `f`, which can be arbitrary and different) + def s_map(fn, *trees): + def _fn(_, *_trees): + return fn(*_trees) + return jtu.tree_map(_fn, y0, *trees) + + def ts_map(fn, *trees): + return t_map(lambda *_trees: s_map(fn, *_trees), *trees) + + control = t_map(lambda term_i: term_i.contr(t0, t1), terms) + dt = t1 - t0 + + def vf(t, y): + _vf = lambda term_i, t_i: term_i.vf(t_i, y, args) + return t_map(_vf, terms, t) + + def vf_prod(t, y): + _vf = lambda term_i, t_i, control_i: term_i.vf_prod(t_i, y, args, control_i) + return t_map(_vf, terms, t, control) + + def prod(f): + _prod = lambda term_i, f_i, control_i: term_i.prod(f_i, control_i) + return t_map(_prod, terms, f, control) + + num_stages = jtu.tree_leaves(self.tableau)[0].num_stages + is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + if fsal: + assert solver_state is not None + first_step, f0 = solver_state + stage_index = jnp.where(first_step, 0, 1) + # `made_jump` can be a tracer, hence the `is`. + if made_jump is False: + # Fast-path for compilation in the common case. + k0 = prod(f0) + else: + _t0 = t_map(lambda _: t0) + k0 = lax.cond(made_jump, lambda: vf_prod(_t0, y0), lambda: prod(f0)) + del _t0 + else: + f0 = _unused + k0 = _unused + stage_index = 0 + del solver_state + + # Must be initialised at zero as we do matmuls against the partially-filled + # array. + ks = t_map( + lambda: s_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0), + ) + if fsal: + ks = ts_map(lambda x, xs: xs.at[0].set(x), k0, ks) + + def embed_a_lower(tableau): + tableau_a_lower = np.zeros((num_stages, num_stages)) + for i, a_lower_i in enumerate(tableau.a_lower): + tableau_a_lower[i + 1, : i + 1] = a_lower_i + return jnp.asarray(tableau_a_lower) + + def embed_c(tableau): + tableau_c = np.zeros(num_stages) + tableau_c[1:] = tableau.c + return jnp.asarray(tableau_c) + + tableau_a_lower = t_map(embed_a_lower, self.tableau) + tableau_c = t_map(embed_c, self.tableau) + + def cond_fun(val): + _stage_index, *_ = val + return _stage_index < num_stages + + def body_fun(val): + stage_index, _, _, _, ks = val + a_lower_i = t_map(lambda t: t[stage_index], tableau_a_lower) + c_i = t_map(lambda t: t[stage_index], tableau_c) + # Unwrap buffers. This is only valid (=correct under autodiff) because we + # follow a triangular pattern and don't read from a location before it's + # written to, or write to the same location twice. + # (The reads in the matmuls don't count, as we initialise at zero.) + unsafe_ks = ts_map(lambda x: x[...], ks) + increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks) + yi_partial = s_map(_sum, y0, *t_leaves(increment)) + # No floating point error + ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i) + if fsal: + assert not is_vf_expensive + fi = vf(ti, yi_partial) + ki = prod(fi) + else: + fi = _unused + ki = vf_prod(ti, yi_partial) + ks = ts_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks) + return stage_index + 1, yi_partial, increment, fi, ks + + def buffers(val): + _, _, _, _, ks = val + return ks + + init_val = (stage_index, y0, t_map(lambda: y0), f0, ks) + final_val = eqxi.while_loop( + cond_fun, + body_fun, + init_val, + max_steps=num_stages, + buffers=buffers, + kind="checkpointed" if self.scan_kind is None else self.scan_kind, + checkpoints=num_stages, + ) + _, y1_partial, increment, f1, ks = final_val + + if all(tableau.ssal for tableau in jtu.tree_leaves(self.tableau)): + y1 = y1_partial + else: + increment = t_map( + lambda t, k, i: i if t.ssal else vector_tree_dot(t.b_sol, k), + self.tableau, + ks, + increment, + ) + y1 = s_map(_sum, y0, *t_leaves(increment)) + y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), self.tableau, ks) + dense_info = dict(y0=y0, y1=y1, k=ks) + if fsal: + new_solver_state = False, f1 + else: + new_solver_state = None + result = RESULTS.successful + return y1, y_error, dense_info, new_solver_state, result + + def old_step( + self, + terms: AbstractTerm, + t0: Scalar, + t1: Scalar, + y0: PyTree, + args: PyTree, + solver_state: _SolverState, + made_jump: Bool, + ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: # # Some Runge--Kutta methods have special structure that we can use to improve # efficiency. @@ -366,7 +617,7 @@ def step( f0_struct = jax.eval_shape(lambda: f0) # noqa: F821 # else f0_struct deliberately left undefined, and is unused. - num_stages = len(self.tableau.c) + 1 + num_stages = self.tableau.num_stages if use_fs: fs = jtu.tree_map(lambda f: jnp.zeros((num_stages,) + f.shape), f0_struct) ks = None @@ -444,7 +695,7 @@ def eval_stage(_carry, _input): _, _, _fs, _ks, _result = _carry _i, _a_lower_i, _a_diagonal_i, _a_predictor_i, _c_i = _input # Unwrap buffers. Take advantage of the fact that they're initialised at - # zero, so that we don't actually read from a location before its written to + # zero, so that we don't really read from a location before its written to. _unsafe_fs_unwrapped = jtu.tree_map(lambda _, x: x[...], fs, _fs) _unsafe_ks_unwrapped = jtu.tree_map(lambda _, x: x[...], ks, _ks) diff --git a/test/test_solver.py b/test/test_solver.py index 6de5fcbe..ae0275e3 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -1,4 +1,8 @@ import diffrax +import equinox as eqx +import jax.numpy as jnp +import jax.random as jr +import pytest def test_half_solver(): @@ -43,3 +47,51 @@ def test_implicit_euler_adaptive(): ) assert out1.result == diffrax.RESULTS.implicit_nonconvergence assert out2.result == diffrax.RESULTS.successful + + +def test_multiple_tableau1(): + class DoubleDopri5(diffrax.AbstractRungeKutta): + tableau = (diffrax.Dopri5.tableau, diffrax.Dopri5.tableau) + calculate_jacobian = diffrax.CalculateJacobian.never + + def interpolation_cls(self, *, k, **kwargs): + return diffrax.LocalLinearInterpolation(**kwargs) + + mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0)) + mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1)) + + term1 = diffrax.ODETerm(lambda t, y, args: mlp1(y)) + term2 = diffrax.ODETerm(lambda t, y, args: mlp2(y)) + t0 = 0 + t1 = 1 + dt0 = 0.1 + y0 = jnp.array([1.0, 2.0]) + out_a = diffrax.diffeqsolve( + diffrax.MultiTerm(term1, term2), + diffrax.Dopri5(), + t0, + t1, + dt0, + y0, + ) + out_b = diffrax.diffeqsolve( + (term1, term2), + DoubleDopri5(), + t0, + t1, + dt0, + y0, + ) + assert jnp.allclose(out_a.ys, out_b.ys, rtol=1e-8, atol=1e-8) + + +def test_multiple_tableau2(): + # Different number of stages + with pytest.raises(ValueError): + + class Dopri5Tsit5(diffrax.AbstractRungeKutta): + tableau = (diffrax.Dopri5.tableau, diffrax.Bosh3.tableau) + calculate_jacobian = diffrax.CalculateJacobian.never + + def interpolation_cls(self, *, k, **kwargs): + return diffrax.LocalLinearInterpolation(**kwargs) From a94c55b0b000d1da1aa173241d8e78cfe56e1f4c Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 15 May 2023 19:43:42 -0700 Subject: [PATCH 5/7] Tidied up how term_structure works, to allow it to specify MultiTerms. --- diffrax/__init__.py | 1 + diffrax/adjoint.py | 7 ++ diffrax/integrate.py | 68 ++++++++++--- diffrax/solver/__init__.py | 1 + diffrax/solver/base.py | 3 +- diffrax/solver/euler_heun.py | 14 +-- diffrax/solver/milstein.py | 22 ++--- diffrax/solver/runge_kutta.py | 143 ++++++++++++++++----------- diffrax/term.py | 29 +++++- docs/api/solvers/abstract_solvers.md | 5 + docs/api/solvers/sde_solvers.md | 5 +- docs/api/terms.md | 22 +++-- test/test_solver.py | 43 +++++++- 13 files changed, 255 insertions(+), 108 deletions(-) diff --git a/diffrax/__init__.py b/diffrax/__init__.py index dec7a5f6..ef4c06a3 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -65,6 +65,7 @@ Kvaerno5, LeapfrogMidpoint, Midpoint, + MultiButcherTableau, Ralston, ReversibleHeun, SemiImplicitEuler, diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index c7f75db4..ffd56b9f 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -541,6 +541,8 @@ def _loop_backsolve_bwd( zeros_like_diff_args = jtu.tree_map(jnp.zeros_like, diff_args) zeros_like_diff_terms = jtu.tree_map(jnp.zeros_like, diff_terms) del diff_args, diff_terms + # TODO: have this look inside MultiTerms? Need to think about the math. i.e.: + # is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm) adjoint_terms = jtu.tree_map( AdjointTerm, terms, is_leaf=lambda x: isinstance(x, AbstractTerm) ) @@ -768,6 +770,11 @@ def loop( "`BacksolveAdjoint` will only produce the correct solution for " "Stratonovich SDEs." ) + if jtu.tree_structure(solver.term_structure) != jtu.tree_structure(0): + raise NotImplementedError( + "`diffrax.BacksolveAdjoint` is only compatible with solvers that take " + "a single term." + ) y = init_state.y init_state = eqx.tree_at(lambda s: s.y, init_state, object()) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index c7e181d8..8d81e623 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -1,7 +1,7 @@ import functools as ft import typing import warnings -from typing import Any, Callable, Optional +from typing import Any, Callable, get_args, get_origin, Optional, Tuple import equinox as eqx import equinox.internal as eqxi @@ -16,7 +16,15 @@ from .heuristics import is_sde, is_unsafe_sde from .saveat import SaveAt, SubSaveAt from .solution import is_okay, is_successful, RESULTS, Solution -from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler +from .solver import ( + AbstractItoSolver, + AbstractSolver, + AbstractStratonovichSolver, + Euler, + EulerHeun, + ItoMilstein, + StratonovichMilstein, +) from .step_size_controller import ( AbstractAdaptiveStepSizeController, AbstractStepSizeController, @@ -24,7 +32,7 @@ PIDController, StepTo, ) -from .term import AbstractTerm, WrapTerm +from .term import AbstractTerm, MultiTerm, ODETerm, WrapTerm class SaveState(eqx.Module): @@ -57,6 +65,28 @@ def _is_none(x): return x is None +def _term_compatible(terms, term_structure): + def _check(term_cls, term): + if get_origin(term_cls) is MultiTerm: + if isinstance(term, MultiTerm): + [_tmp] = get_args(term_cls) + assert get_origin(_tmp) in (tuple, Tuple), "Malformed term_structure" + if not _term_compatible(term.terms, get_args(_tmp)): + raise ValueError + else: + raise ValueError + else: + if not isinstance(term, term_cls): + raise ValueError + + try: + jtu.tree_map(_check, term_structure, terms) + except ValueError: + # ValueError may also arise from mismatched tree structures + return False + return True + + def _is_subsaveat(x: Any) -> bool: return isinstance(x, SubSaveAt) @@ -541,19 +571,25 @@ def diffeqsolve( pred = (t1 - t0) * dt0 < 0 dt0 = eqxi.error_if(dt0, pred, msg) + # Backward compatibility + if isinstance( + solver, (EulerHeun, ItoMilstein, StratonovichMilstein) + ) and _term_compatible(terms, (ODETerm, AbstractTerm)): + warnings.warn( + "Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to " + f"{solver.__class__.__name__} is deprecated in favour of " + "`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that " + "the same terms can now be passed used for both general and SDE-specific " + "solvers!" + ) + terms = MultiTerm(*terms) + # Error checking - term_leaves, term_structure = jtu.tree_flatten( - terms, is_leaf=lambda x: isinstance(x, AbstractTerm) - ) - term_leaves2, term_structure2 = jtu.tree_flatten(solver.term_structure) - if term_structure != term_structure2 or any( - not isinstance(x, y) for x, y in zip(term_leaves, term_leaves2) - ): + if not _term_compatible(terms, solver.term_structure): raise ValueError( "`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with " f"structure {solver.term_structure}" ) - del term_leaves, term_structure, term_leaves2, term_structure2 if is_sde(terms): if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): @@ -627,10 +663,16 @@ def _promote(yi): _get_subsaveat_ts, saveat, replace_fn=lambda ts: ts * direction ) stepsize_controller = stepsize_controller.wrap(direction) + + def _wrap(term): + assert isinstance(term, AbstractTerm) + assert not isinstance(term, MultiTerm) + return WrapTerm(term, direction) + terms = jtu.tree_map( - lambda t: WrapTerm(t, direction), + _wrap, terms, - is_leaf=lambda x: isinstance(x, AbstractTerm), + is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm), ) # Stepsize controller gets an opportunity to modify the solver. diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index 30964682..f3a108c0 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -30,6 +30,7 @@ AbstractSDIRK, ButcherTableau, CalculateJacobian, + MultiButcherTableau, ) from .semi_implicit_euler import SemiImplicitEuler from .tsit5 import Tsit5 diff --git a/diffrax/solver/base.py b/diffrax/solver/base.py index 849d22ff..bcfa7232 100644 --- a/diffrax/solver/base.py +++ b/diffrax/solver/base.py @@ -6,8 +6,9 @@ import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu +from jaxtyping import PyTree -from ..custom_types import Bool, DenseInfo, PyTree, Scalar +from ..custom_types import Bool, DenseInfo, Scalar from ..heuristics import is_sde from ..local_interpolation import AbstractLocalInterpolation from ..nonlinear_solver import AbstractNonlinearSolver, NewtonNonlinearSolver diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index 05809d6b..26b2d234 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -5,7 +5,7 @@ from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation from ..solution import RESULTS -from ..term import AbstractTerm, ODETerm +from ..term import AbstractTerm, MultiTerm, ODETerm from .base import AbstractStratonovichSolver @@ -19,7 +19,7 @@ class EulerHeun(AbstractStratonovichSolver): Used to solve SDEs, and converges to the Stratonovich solution. """ - term_structure = (ODETerm, AbstractTerm) + term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]] interpolation_cls = LocalLinearInterpolation def order(self, terms): @@ -30,7 +30,7 @@ def strong_order(self, terms): def init( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -40,7 +40,7 @@ def init( def step( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -50,7 +50,7 @@ def step( ) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del solver_state, made_jump - drift, diffusion = terms + drift, diffusion = terms.terms dt = drift.contr(t0, t1) dW = diffusion.contr(t0, t1) @@ -67,10 +67,10 @@ def step( def func( self, - terms: Tuple[AbstractTerm, AbstractTerm], + terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]], t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - drift, diffusion = terms + drift, diffusion = terms.terms return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index d3753e46..17bdf59b 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -8,7 +8,7 @@ from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..local_interpolation import LocalLinearInterpolation from ..solution import RESULTS -from ..term import AbstractTerm, ODETerm +from ..term import AbstractTerm, MultiTerm, ODETerm from .base import AbstractItoSolver, AbstractStratonovichSolver @@ -36,7 +36,7 @@ class StratonovichMilstein(AbstractStratonovichSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure = (ODETerm, AbstractTerm) + term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]] interpolation_cls = LocalLinearInterpolation def order(self, terms): @@ -47,7 +47,7 @@ def strong_order(self, terms): def init( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -57,7 +57,7 @@ def init( def step( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -66,7 +66,7 @@ def step( made_jump: Bool, ) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del solver_state, made_jump - drift, diffusion = terms + drift, diffusion = terms.terms dt = drift.contr(t0, t1) dw = diffusion.contr(t0, t1) @@ -84,12 +84,12 @@ def _to_jvp(_y0): def func( self, - terms: Tuple[AbstractTerm, AbstractTerm], + terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]], t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - drift, diffusion = terms + drift, diffusion = terms.terms return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) @@ -104,7 +104,7 @@ class ItoMilstein(AbstractItoSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure = (ODETerm, AbstractTerm) + term_structure = MultiTerm[Tuple[ODETerm, AbstractTerm]] interpolation_cls = LocalLinearInterpolation def order(self, terms): @@ -115,7 +115,7 @@ def strong_order(self, terms): def init( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -125,7 +125,7 @@ def init( def step( self, - terms: Tuple[ODETerm, AbstractTerm], + terms: MultiTerm[Tuple[ODETerm, AbstractTerm]], t0: Scalar, t1: Scalar, y0: PyTree, @@ -346,7 +346,7 @@ def _dot(_, _v0): def func( self, - terms: Tuple[AbstractTerm, AbstractTerm], + terms: MultiTerm[Tuple[AbstractTerm, AbstractTerm]], t0: Scalar, y0: PyTree, args: PyTree, diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index 81c84703..d9bf408f 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Literal, Optional, Union +from typing import get_args, get_origin, Literal, Optional, Tuple, Union import equinox as eqx import equinox.internal as eqxi @@ -13,7 +13,7 @@ from ..custom_types import DenseInfo from ..solution import is_okay, RESULTS, update_result -from ..term import AbstractTerm, ODETerm, WrapTerm +from ..term import AbstractTerm, MultiTerm, ODETerm, WrapTerm from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot @@ -115,6 +115,23 @@ def __post_init__(self): """ +class MultiButcherTableau(eqx.Module): + """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau + solvers, like stochastic Runge--Kutta methods or IMEX methods. + """ + + tableaus: Tuple[ButcherTableau, ...] + + def __init__(self, *tableaus: ButcherTableau): + self.tableaus = tableaus + + +MultiButcherTableau.__init__.__doc__ = """**Arguments:** + +- `*tableaus`: the tableaus to wrap together. +""" + + class CalculateJacobian(metaclass=eqxi.ContainerMeta): """An enumeration of possible ways a Runga--Kutta method may wish to calculate a Jacobian. @@ -201,53 +218,43 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None - tableau: eqxi.AbstractClassVar[PyTree[ButcherTableau]] + tableau: eqxi.AbstractClassVar[Union[ButcherTableau, MultiButcherTableau]] calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - seen_implicit = False - num_stages = None - - def _f(t: ButcherTableau): - nonlocal seen_implicit - nonlocal num_stages - if num_stages is None: - num_stages = t.num_stages - if t.num_stages != num_stages: - raise ValueError("Tableaus must all have the same number of stages") - if t.implicit: - if seen_implicit: + if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau + if isinstance(cls.tableau, ButcherTableau): + if hasattr(cls, "term_structure"): + assert issubclass(cls.term_structure, AbstractTerm) + else: + cls.term_structure = AbstractTerm + elif isinstance(cls.tableau, MultiButcherTableau): + if len({tab.num_stages for tab in cls.tableau.tableaus}) > 1: + raise ValueError("Tableaus must all have the same number of stages") + if len([tab for tab in cls.tableau.tableaus if tab.implicit]) > 1: raise ValueError("May have at most one implicit tableau") + if hasattr(cls, "term_structure"): + assert get_origin(cls.term_structure) is MultiTerm + [_tmp] = get_args(cls.term_structure) + assert get_origin(_tmp) in (tuple, Tuple) + assert all(issubclass(x, AbstractTerm) for x in get_args(_tmp)) else: - seen_implicit = True - return AbstractTerm - - if hasattr(cls, "tableau"): # Abstract subclasses may not have a tableau - term_structure = jtu.tree_map(_f, cls.tableau) - # Allow subclasses to specify more specific term structures if desired, e.g. - # (ODETerm, ControlTerm) rather than (AbstractTerm, AbtstractTerm). - try: - term_structure2 = cls.term_structure - except AttributeError: - cls.term_structure = term_structure + terms = tuple( + AbstractTerm for _ in range(len(cls.tableau.tableaus)) + ) + cls.term_structure = MultiTerm[Tuple[terms]] else: - x = jtu.tree_structure(term_structure, is_leaf=_is_term) - x2 = jtu.tree_structure(term_structure2, is_leaf=_is_term) - if x != x2: - raise ValueError("Mismatched term structures") + assert False def _common(self, terms, t0, t1, y0, args): # For simplicity we share `vf_expensive` and `fsal` across all tableaus. # TODO: could we make these work per-tableau? - vf_expensive = False - fsal = True - terms = jtu.tree_leaves(terms, is_leaf=_is_term) - tableaus = jtu.tree_leaves(self.tableau) - assert len(terms) == len(tableaus) - for term, tableau in zip(terms, tableaus): - vf_expensive = vf_expensive or term.is_vf_expensive(t0, t1, y0, args) - fsal = fsal and tableau.fsal + vf_expensive = terms.is_vf_expensive(t0, t1, y0, args) + if isinstance(self.tableau, MultiButcherTableau): + fsal = all(tab.fsal for tab in self.tableau.tableaus) + else: + fsal = self.tableau.fsal # If the vector field is expensive then we want to use vf_prods instead. # FSAL implies evaluating just the vector field, since we need to contract # the same vector field evaluation against two different controls. @@ -256,16 +263,16 @@ def _common(self, terms, t0, t1, y0, args): def func( self, - terms: PyTree[AbstractTerm], + terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return jtu.tree_map(lambda t: t.vf(t0, y0, args), terms, is_leaf=_is_term) + return terms.vf(t0, y0, args) def init( self, - terms: PyTree[AbstractTerm], + terms: AbstractTerm, t0: Scalar, t1: Scalar, y0: PyTree, @@ -274,20 +281,29 @@ def init( _, fsal = self._common(terms, t0, t1, y0, args) if fsal: first_step = jnp.array(True) - if (type(terms) is WrapTerm) and (type(terms.term) is ODETerm): - # Privileged optimisation for the common case - f0 = jtu.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), y0) - else: + f0 = sentinel = object() + if type(terms) is WrapTerm: + # Privileged optimisations for some common cases + _terms = terms.term + if type(_terms) is ODETerm: + f0 = jtu.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), y0) + elif type(_terms) is MultiTerm: + if all(type(x) is ODETerm for x in _terms.terms): + f0 = tuple( + jtu.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), y0) + for _ in range(len(_terms.terms)) + ) + if f0 is sentinel: # Must be initialiased at zero as it is inserted into `ks` which must be # initialised at zero. - f0 = eqxi.eval_zero(lambda: self.func(terms, t0, y0, args)) + f0 = eqxi.eval_zero(self.func, terms, t0, y0, args) return first_step, f0 else: return None def step( self, - terms: PyTree[AbstractTerm], + terms: AbstractTerm, t0: Scalar, t1: Scalar, y0: PyTree, @@ -331,16 +347,30 @@ def step( # e.g. we need `ks` to perform dense interpolation if needed. # + is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + + # The code below is actually quite generic: it handles a pytree of Butcher + # tableaus and a pytree of terms. + # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive. + # Here we just unpack from one to the other. + if isinstance(self.tableau, ButcherTableau): + assert isinstance(terms, AbstractTerm) + tableaus = self.tableau + else: + assert isinstance(terms, MultiTerm) + tableaus = self.tableau.tableaus + terms = terms.terms + assert jtu.tree_structure(terms, is_leaf=_is_term) == jtu.tree_structure( - self.tableau + tableaus ) - # Structure of `terms` and `self.tableau`. + # Structure of `terms` and `tableaus`. def t_map(fn, *trees): def _fn(_, *_trees): return fn(*_trees) - return jtu.tree_map(_fn, self.tableau, *trees) + return jtu.tree_map(_fn, tableaus, *trees) def t_leaves(tree): return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))] @@ -371,8 +401,7 @@ def prod(f): _prod = lambda term_i, f_i, control_i: term_i.prod(f_i, control_i) return t_map(_prod, terms, f, control) - num_stages = jtu.tree_leaves(self.tableau)[0].num_stages - is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + num_stages = jtu.tree_leaves(tableaus)[0].num_stages if fsal: assert solver_state is not None first_step, f0 = solver_state @@ -410,8 +439,8 @@ def embed_c(tableau): tableau_c[1:] = tableau.c return jnp.asarray(tableau_c) - tableau_a_lower = t_map(embed_a_lower, self.tableau) - tableau_c = t_map(embed_c, self.tableau) + tableau_a_lower = t_map(embed_a_lower, tableaus) + tableau_c = t_map(embed_c, tableaus) def cond_fun(val): _stage_index, *_ = val @@ -456,17 +485,17 @@ def buffers(val): ) _, y1_partial, increment, f1, ks = final_val - if all(tableau.ssal for tableau in jtu.tree_leaves(self.tableau)): + if all(tableau.ssal for tableau in jtu.tree_leaves(tableaus)): y1 = y1_partial else: increment = t_map( lambda t, k, i: i if t.ssal else vector_tree_dot(t.b_sol, k), - self.tableau, + tableaus, ks, increment, ) y1 = s_map(_sum, y0, *t_leaves(increment)) - y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), self.tableau, ks) + y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), tableaus, ks) dense_info = dict(y0=y0, y1=y1, k=ks) if fsal: new_solver_state = False, f1 diff --git a/diffrax/term.py b/diffrax/term.py index 7b562e06..9c6006ad 100644 --- a/diffrax/term.py +++ b/diffrax/term.py @@ -1,6 +1,6 @@ import abc import operator -from typing import Callable, Tuple +from typing import Callable, Generic, Tuple, TypeVar import equinox as eqx import jax @@ -313,7 +313,10 @@ def _sum(*x): return sum(x[1:], x[0]) -class MultiTerm(AbstractTerm): +_Terms = TypeVar("_Terms", bound=Tuple[AbstractTerm, ...]) + + +class MultiTerm(AbstractTerm, Generic[_Terms]): r"""Accumulates multiple terms into a single term. Consider the SDE @@ -332,9 +335,9 @@ class MultiTerm(AbstractTerm): transform is a necessary part of e.g. solving an SDE with both drift and diffusion. """ - terms: Tuple[AbstractTerm, ...] + terms: _Terms - def __init__(self, *terms): + def __init__(self, *terms: AbstractTerm): """**Arguments:** - `*terms`: Any number of [`diffrax.AbstractTerm`][]s to combine. @@ -363,6 +366,15 @@ def vf_prod( ] return jtu.tree_map(_sum, *out) + def is_vf_expensive( + self, + t0: Scalar, + t1: Scalar, + y: Tuple[PyTree, PyTree, PyTree, PyTree], + args: PyTree, + ) -> bool: + return any(term.is_vf_expensive(t0, t1, y, args) for term in self.terms) + class WrapTerm(AbstractTerm): term: AbstractTerm @@ -384,6 +396,15 @@ def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree t = t * self.direction return self.term.vf_prod(t, y, args, control) + def is_vf_expensive( + self, + t0: Scalar, + t1: Scalar, + y: Tuple[PyTree, PyTree, PyTree, PyTree], + args: PyTree, + ) -> bool: + return self.term.is_vf_expensive(t0, t1, y, args) + class AdjointTerm(AbstractTerm): term: AbstractTerm diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md index 2942c989..23775db7 100644 --- a/docs/api/solvers/abstract_solvers.md +++ b/docs/api/solvers/abstract_solvers.md @@ -81,6 +81,11 @@ In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use members: - __init__ +::: diffrax.MultiButcherTableau + selection: + members: + - __init__ + ::: diffrax.CalculateJacobian selection: members: false diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md index 849a6dec..7c823352 100644 --- a/docs/api/solvers/sde_solvers.md +++ b/docs/api/solvers/sde_solvers.md @@ -14,7 +14,7 @@ See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochast diffeqsolve(terms, solver=Euler(), ...) ``` - Some solvers are SDE-specific. For these, such as for example [`diffrax.StratonovichMilstein`][], then `terms` should be a 2-tuple `(ODETerm, AbstractTerm)`, representing the drift and diffusion separately. + Some solvers are SDE-specific. For these, such as for example [`diffrax.StratonovichMilstein`][], then `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically. For those SDE-specific solvers then this is documented below, and the term structure is available programmatically under `.term_structure`. @@ -58,7 +58,8 @@ These are reversible in the same way as when applied to ODEs. [See here.](./ode_ !!! info "Term structure" - For these SDE-specific solvers, the terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple `(ODETerm, AbstractTerm)`, representing the drift and diffusion respectively. Typically that means `(ODETerm(...), ControlTerm(..., ...))`. + These solvers are SDE-specific. For these, `terms` must specifically be of the form `MultiTerm(ODETerm(...), SomeOtherTerm(...))` (Typically `SomeOTherTerm` will be a `ControlTerm` or `WeaklyDiagonalControlTerm`) representing the drift and diffusion specifically. + ::: diffrax.EulerHeun selection: diff --git a/docs/api/terms.md b/docs/api/terms.md index 73316cbb..b78e5241 100644 --- a/docs/api/terms.md +++ b/docs/api/terms.md @@ -4,13 +4,21 @@ One of the advanced features of Diffrax is its *term* system. When we write down $\mathrm{d}y(t) = f(t, y(t))\mathrm{d}t + g(t, y(t))\mathrm{d}w(t)$ -then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a *vector field* ($f$ or $g$) and a *control* ($\mathrm{d}t$ or $\mathrm{d}w(t)$). There is also an implicit assumption about how vector field and control interact: $f$ and $\mathrm{d}t$ interact as a vector-scalar product. $g$ and $\mathrm{d}w(t)$ interact as a matrix-vector product. (This interaction is always linear.) +then we have two "terms": a drift and a diffusion. Each of these terms has two parts: a *vector field* ($f$ or $g$) and a *control* ($\mathrm{d}t$ or $\mathrm{d}w(t)$). There is also an implicit assumption about how the vector field and control interact: $f$ and $\mathrm{d}t$ interact as a vector-scalar product. $g$ and $\mathrm{d}w(t)$ interact as a matrix-vector product. (This interaction is always linear.) -"Terms" are thus the building blocks of differential equations. In Diffrax, the above SDE has its drift described by [`diffrax.ODETerm`][] and the diffusion described by a [`diffrax.ControlTerm`][]. +"Terms" are thus the building blocks of differential equations. !!! example - As a simpler example, consider the ODE $\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))$. Then this has vector field $f$, control $\mathrm{d}t$, and their interaction is a vector-scalar product. This can be described as a single [`diffrax.ODETerm`][]. + Consider the ODE $\frac{\mathrm{d}{y}}{\mathrm{d}t} = f(t, y(t))$. Then this has vector field $f$, control $\mathrm{d}t$, and their interaction is a vector-scalar product. This can be described as a single [`diffrax.ODETerm`][]. + +If multiple terms affect the same evolving state, then they should be grouped into a single [`diffrax.MultiTerm`][]. + +!!! example + + An SDE would have its drift described by [`diffrax.ODETerm`][] and the diffusion described by a [`diffrax.ControlTerm`][]. As these affect the same evolving state variable, they should be passed to the solver as `MultiTerm(ODETerm(...), ControlTerm(...))`. + +If terms affect different pieces of the state, then they should be placed in some PyTree structure. (The exact structure will depend on what the solver accepts.) !!! example @@ -18,13 +26,9 @@ then we have two "terms": a drift and a diffusion. Each of these terms has two p $\frac{\mathrm{d}x}{\mathrm{d}t}(t) = f(t, y(t)),\qquad\frac{\mathrm{d}y}{\mathrm{d}t}(t) = g(t, x(t))$ - These can be described as a 2-tuple of [`diffrax.ODETerm`][]`s. - -The very first argument to [`diffrax.diffeqsolve`][] should be some PyTree of terms. This is interpreted by the solver in the appropriate way. + These would be passed to the solver as the 2-tuple of `(ODETerm(...), ODETerm(...))`. -- For example [`diffrax.Euler`][] expects a single term: it solves an ODE represented via `ODETerm(...)`, or an SDE represented via `MultiTerm(ODETerm(...), ControlTerm(...))`. -- Meanwhile [`diffrax.SemiImplicitEuler`][] solves the paired (Hamiltonian) system given in the example above, and expects a 2-tuple of terms representing each piece. -- Some SDE-specific solvers (e.g. [`diffrax.StratonovichMilstein`][] need to be able to see the distinction between the drift and diffusion, and expect a 2-tuple of terms representing the drift and diffusion respectively. +Each solver is capable of handling certain classes of problems, as described by their `solver.term_structure`. ??? abstract "`diffrax.AbstractTerm`" diff --git a/test/test_solver.py b/test/test_solver.py index ae0275e3..36d3b10e 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -51,7 +51,9 @@ def test_implicit_euler_adaptive(): def test_multiple_tableau1(): class DoubleDopri5(diffrax.AbstractRungeKutta): - tableau = (diffrax.Dopri5.tableau, diffrax.Dopri5.tableau) + tableau = diffrax.MultiButcherTableau( + diffrax.Dopri5.tableau, diffrax.Dopri5.tableau + ) calculate_jacobian = diffrax.CalculateJacobian.never def interpolation_cls(self, *, k, **kwargs): @@ -75,7 +77,7 @@ def interpolation_cls(self, *, k, **kwargs): y0, ) out_b = diffrax.diffeqsolve( - (term1, term2), + diffrax.MultiTerm(term1, term2), DoubleDopri5(), t0, t1, @@ -84,14 +86,47 @@ def interpolation_cls(self, *, k, **kwargs): ) assert jnp.allclose(out_a.ys, out_b.ys, rtol=1e-8, atol=1e-8) + with pytest.raises(ValueError): + diffrax.diffeqsolve( + (term1, term2), + DoubleDopri5(), + t0, + t1, + dt0, + y0, + ) + def test_multiple_tableau2(): # Different number of stages with pytest.raises(ValueError): - class Dopri5Tsit5(diffrax.AbstractRungeKutta): - tableau = (diffrax.Dopri5.tableau, diffrax.Bosh3.tableau) + class X(diffrax.AbstractRungeKutta): + tableau = diffrax.MultiButcherTableau( + diffrax.Dopri5.tableau, diffrax.Bosh3.tableau + ) calculate_jacobian = diffrax.CalculateJacobian.never def interpolation_cls(self, *, k, **kwargs): return diffrax.LocalLinearInterpolation(**kwargs) + + # Multiple implicit + with pytest.raises(ValueError): + + class Y(diffrax.AbstractRungeKutta): + tableau = diffrax.MultiButcherTableau( + diffrax.Kvaerno3.tableau, diffrax.Kvaerno3.tableau + ) + calculate_jacobian = diffrax.CalculateJacobian.never + + def interpolation_cls(self, *, k, **kwargs): + return diffrax.LocalLinearInterpolation(**kwargs) + + class Z(diffrax.AbstractRungeKutta): + tableau = diffrax.MultiButcherTableau( + diffrax.Bosh3.tableau, diffrax.Kvaerno3.tableau + ) + calculate_jacobian = diffrax.CalculateJacobian.never + + def interpolation_cls(self, *, k, **kwargs): + return diffrax.LocalLinearInterpolation(**kwargs) From 2d34bac2d296af8002908cc8af8999d445e94c03 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 16 May 2023 22:52:32 -0700 Subject: [PATCH 6/7] Added Sil3, KenCarp{3,4,5} and support for IMEX methods. --- diffrax/__init__.py | 5 + diffrax/adjoint.py | 2 +- diffrax/custom_types.py | 4 +- diffrax/nonlinear_solver/__init__.py | 1 + diffrax/nonlinear_solver/affine.py | 34 + diffrax/solver/__init__.py | 4 + diffrax/solver/bosh3.py | 3 +- diffrax/solver/dopri5.py | 2 +- diffrax/solver/dopri8.py | 2 +- diffrax/solver/euler.py | 3 +- diffrax/solver/euler_heun.py | 5 + diffrax/solver/heun.py | 3 +- diffrax/solver/implicit_euler.py | 5 +- diffrax/solver/kencarp3.py | 151 +++ diffrax/solver/kencarp4.py | 164 ++++ diffrax/solver/kencarp5.py | 231 +++++ diffrax/solver/kvaerno3.py | 3 +- diffrax/solver/kvaerno4.py | 3 +- diffrax/solver/kvaerno5.py | 3 +- diffrax/solver/leapfrog_midpoint.py | 3 +- diffrax/solver/midpoint.py | 3 +- diffrax/solver/milstein.py | 14 +- diffrax/solver/ralston.py | 2 +- diffrax/solver/reversible_heun.py | 2 +- diffrax/solver/runge_kutta.py | 1228 +++++++++++++------------ diffrax/solver/semi_implicit_euler.py | 3 +- diffrax/solver/sil3.py | 86 ++ diffrax/solver/tsit5.py | 10 +- docs/api/solvers/abstract_solvers.md | 5 - docs/api/solvers/ode_solvers.md | 26 + docs/usage/how-to-choose-a-solver.md | 4 + test/helpers.py | 7 + test/test_global_interpolation.py | 19 +- test/test_integrate.py | 24 +- test/test_interpolation.py | 20 +- test/test_solver.py | 328 ++++++- 36 files changed, 1778 insertions(+), 634 deletions(-) create mode 100644 diffrax/nonlinear_solver/affine.py create mode 100644 diffrax/solver/kencarp3.py create mode 100644 diffrax/solver/kencarp4.py create mode 100644 diffrax/solver/kencarp5.py create mode 100644 diffrax/solver/sil3.py diff --git a/diffrax/__init__.py b/diffrax/__init__.py index ef4c06a3..381eec2a 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -31,6 +31,7 @@ from .misc import adjoint_rms_seminorm from .nonlinear_solver import ( AbstractNonlinearSolver, + AffineNonlinearSolver, NewtonNonlinearSolver, NonlinearSolution, ) @@ -60,6 +61,9 @@ Heun, ImplicitEuler, ItoMilstein, + KenCarp3, + KenCarp4, + KenCarp5, Kvaerno3, Kvaerno4, Kvaerno5, @@ -69,6 +73,7 @@ Ralston, ReversibleHeun, SemiImplicitEuler, + Sil3, StratonovichMilstein, Tsit5, ) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index ffd56b9f..01035baf 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -366,7 +366,7 @@ def loop( # Support forward-mode autodiff. # TODO: remove this hack once we can JVP through custom_vjps. if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: - solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax") + solver = eqx.tree_at(lambda s: s.scan_kind, solver, "bounded") inner_while_loop = ft.partial(_inner_loop, kind=kind) outer_while_loop = ft.partial(_outer_loop, kind=kind) final_state = self._loop( diff --git a/diffrax/custom_types.py b/diffrax/custom_types.py index 624f47f4..93e818b5 100644 --- a/diffrax/custom_types.py +++ b/diffrax/custom_types.py @@ -1,7 +1,8 @@ import inspect import typing -from typing import Dict, Generic, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Tuple, TypeVar, Union +import equinox.internal as eqxi import jax.tree_util as jtu @@ -129,3 +130,4 @@ def __class_getitem__(cls, item): DenseInfo = Dict[str, PyTree[Array]] DenseInfos = Dict[str, PyTree[Array["times", ...]]] # noqa: F821 +sentinel: Any = eqxi.doc_repr(object(), "sentinel") diff --git a/diffrax/nonlinear_solver/__init__.py b/diffrax/nonlinear_solver/__init__.py index 4e66f0ef..309691ff 100644 --- a/diffrax/nonlinear_solver/__init__.py +++ b/diffrax/nonlinear_solver/__init__.py @@ -1,2 +1,3 @@ +from .affine import AffineNonlinearSolver from .base import AbstractNonlinearSolver, NonlinearSolution from .newton import NewtonNonlinearSolver diff --git a/diffrax/nonlinear_solver/affine.py b/diffrax/nonlinear_solver/affine.py new file mode 100644 index 00000000..1b5badbd --- /dev/null +++ b/diffrax/nonlinear_solver/affine.py @@ -0,0 +1,34 @@ +import equinox as eqx +import jax +import jax.flatten_util as jfu +import jax.numpy as jnp + +from ..solution import RESULTS +from .base import AbstractNonlinearSolver, NonlinearSolution + + +class AffineNonlinearSolver(AbstractNonlinearSolver): + """Finds the fixed point of f(x)=0, where f(x) = Ax + b is affine. + + !!! Warning + + This solver only exists temporarily. It is deliberately undocumented and will be + removed shortly, in favour of a more comprehensive approach to performing linear + and nonlinear solves. + """ + + def _solve(self, fn, x, jac, nondiff_args, diff_args): + del jac + args = eqx.combine(nondiff_args, diff_args) + flat, unflatten = jfu.ravel_pytree(x) + zero = jnp.zeros_like(flat) + flat_fn = lambda z: jfu.ravel_pytree(fn(unflatten(z), args))[0] + b = flat_fn(zero) + A = jax.jacfwd(flat_fn)(zero) + out = -jnp.linalg.solve(A, b) + out = unflatten(out) + return NonlinearSolution(root=out, num_steps=0, result=RESULTS.successful) + + @staticmethod + def jac(fn, x, args): + return None diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index f3a108c0..ace213c4 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -14,6 +14,9 @@ from .euler_heun import EulerHeun from .heun import Heun from .implicit_euler import ImplicitEuler +from .kencarp3 import KenCarp3 +from .kencarp4 import KenCarp4 +from .kencarp5 import KenCarp5 from .kvaerno3 import Kvaerno3 from .kvaerno4 import Kvaerno4 from .kvaerno5 import Kvaerno5 @@ -33,4 +36,5 @@ MultiButcherTableau, ) from .semi_implicit_euler import SemiImplicitEuler +from .sil3 import Sil3 from .tsit5 import Tsit5 diff --git a/diffrax/solver/bosh3.py b/diffrax/solver/bosh3.py index cf68bd8e..8d27fe8b 100644 --- a/diffrax/solver/bosh3.py +++ b/diffrax/solver/bosh3.py @@ -20,7 +20,8 @@ class Bosh3(AbstractERK): """Bogacki--Shampine's 3/2 method. 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for - adaptive step sizing. + adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. Also sometimes known as "Ralston's third order method". """ diff --git a/diffrax/solver/dopri5.py b/diffrax/solver/dopri5.py index 2ba617df..ed0f035f 100644 --- a/diffrax/solver/dopri5.py +++ b/diffrax/solver/dopri5.py @@ -51,7 +51,7 @@ class Dopri5(AbstractERK): r"""Dormand-Prince's 5/4 method. 5th order Runge--Kutta method. Has an embedded 4th order method for adaptive step - sizing. + sizing. Uses 7 stages with FSAL. Uses 5th order interpolation for dense/ts output. ??? cite "Reference" diff --git a/diffrax/solver/dopri8.py b/diffrax/solver/dopri8.py index 77ba0ab8..1b9b6551 100644 --- a/diffrax/solver/dopri8.py +++ b/diffrax/solver/dopri8.py @@ -295,7 +295,7 @@ class Dopri8(AbstractERK): """Dormand--Prince's 8/7 method. 8th order Runge--Kutta method. Has an embedded 7th order method for adaptive step - sizing. + sizing. Uses 14 stages with FSAL. Uses 8th order interpolation for dense/ts output. ??? cite "References" diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py index 5ddcda87..c7043eef 100644 --- a/diffrax/solver/euler.py +++ b/diffrax/solver/euler.py @@ -16,7 +16,8 @@ class Euler(AbstractItoSolver): """Euler's method. - 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. + 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. Uses + 1 stage. Uses 1st order local linear interpolation for dense/ts output. When used to solve SDEs, converges to the Itô solution. """ diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index 26b2d234..9b5e3527 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -16,6 +16,11 @@ class EulerHeun(AbstractStratonovichSolver): """Euler-Heun method. + Uses a 1st order local linear interpolation scheme for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. + Used to solve SDEs, and converges to the Stratonovich solution. """ diff --git a/diffrax/solver/heun.py b/diffrax/solver/heun.py index eb35dd36..464d038d 100644 --- a/diffrax/solver/heun.py +++ b/diffrax/solver/heun.py @@ -17,7 +17,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver): """Heun's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. Uses 2nd-order Hermite interpolation for dense/ts + output. Also sometimes known as either the "improved Euler method", "modified Euler method" or "explicit trapezoidal rule". diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py index 55f69dae..b0cd1def 100644 --- a/diffrax/solver/implicit_euler.py +++ b/diffrax/solver/implicit_euler.py @@ -22,8 +22,9 @@ def _implicit_relation(z1, nonlinear_solve_args): class ImplicitEuler(AbstractImplicitSolver): r"""Implicit Euler method. - A-B-L stable 1st order SDIRK method. Has an embedded 2nd order method for adaptive - step sizing. + A-B-L stable 1st order SDIRK method. Has an embedded 2nd order Heun method for + adaptive step sizing. Uses 1 stage. Uses a 1st order local linear interpolation for + dense/ts output. """ term_structure = AbstractTerm diff --git a/diffrax/solver/kencarp3.py b/diffrax/solver/kencarp3.py new file mode 100644 index 00000000..9a088db0 --- /dev/null +++ b/diffrax/solver/kencarp3.py @@ -0,0 +1,151 @@ +from typing import Optional, Tuple + +import equinox.internal as eqxi +import jax +import jax.numpy as jnp +import numpy as np +from equinox.internal import ω + +from ..custom_types import Array, PyTree, Scalar +from ..local_interpolation import AbstractLocalInterpolation +from ..misc import linear_rescale +from .base import AbstractImplicitSolver, vector_tree_dot +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 1767732205903 / 4055673282236 +_b_sol = np.array( + [ + 1471266399579 / 7840856788654, + -4482444167858 / 7529755066697, + 11266239266428 / 11593286722821, + _γ, + ] +) +_b_sol_embedded = np.array( + [ + 2756255671327 / 12835298489170, + -10771552573575 / 22201958757719, + 9247589265047 / 10645013368117, + 2193209047091 / 5459859503100, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array([2 * _γ, 3 / 5, 1.0]) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([2 * _γ]), + np.array([5535828885825 / 10492691773637, 788022342437 / 10882634858940]), + np.array( + [ + 6485989280629 / 16251701735622, + -4246266847089 / 9704473918619, + 10755448449292 / 10357097424841, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([2746238789719 / 10658868560708, -640167445237 / 6845629431997]), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + a_diagonal=np.array([0, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. + a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + ), +) + + +class KenCarpInterpolation(AbstractLocalInterpolation): + y0: PyTree[Array[...]] + k: Tuple[PyTree[Array["order", ...]], PyTree[Array["order", ...]]] # noqa: F821 + + coeffs: eqxi.AbstractClassVar[np.ndarray] + + def __init__(self, *, y0, y1, k, **kwargs): + del y1 # exists for API compatibility + super().__init__(**kwargs) + self.y0 = y0 + self.k = k + + def evaluate( + self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True + ) -> PyTree: + del left + if t1 is not None: + return self.evaluate(t1) - self.evaluate(t0) + + t = linear_rescale(self.t0, t0, self.t1) + explicit_k, implicit_k = self.k + k = (explicit_k**ω + implicit_k**ω).ω + coeffs = t * jax.vmap(lambda row: jnp.polyval(row, t))(self.coeffs) + return (self.y0**ω + vector_tree_dot(coeffs, k) ** ω).ω + + +class _KenCarp3Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [-215264564351 / 13552729205753, 4655552711362 / 22874653954995], + [17870216137069 / 13817060693119, -18682724506714 / 9892148508045], + [-28141676662227 / 17317692491321, 34259539580243 / 13192909600954], + [2508943948391 / 7218656332882, 584795268549 / 6622622206610], + ] + ) + + +class KenCarp3(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 3/2 IMEX method. + + 3rd order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 2nd order method for adaptive step sizing. + Uses 4 stages. Uses 2nd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp3Interpolation + + def order(self, terms): + return 3 diff --git a/diffrax/solver/kencarp4.py b/diffrax/solver/kencarp4.py new file mode 100644 index 00000000..f9d83317 --- /dev/null +++ b/diffrax/solver/kencarp4.py @@ -0,0 +1,164 @@ +import numpy as np + +from .base import AbstractImplicitSolver +from .kencarp3 import KenCarpInterpolation +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 0.25 +_b_sol = np.array([82889 / 524892, 0, 15625 / 83664, 69875 / 102672, -2260 / 8211, _γ]) +_b_sol_embedded = np.array( + [ + 4586570599 / 29645900160, + 0, + 178811875 / 945068544, + 814220225 / 1159782912, + -3700637 / 11593932, + 61727 / 225920, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array([0.5, 83 / 250, 31 / 50, 17 / 20, 1.0]) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] +_c_ratio3 = _c[3] / _c[2] +_c_ratio4 = _c[4] / _c[3] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([0.5]), + np.array([13861 / 62500, 6889 / 62500]), + np.array( + [ + -116923316275 / 2393684061468, + -2731218467317 / 15368042101831, + 9408046702089 / 11113171139209, + ] + ), + np.array( + [ + -451086348788 / 2902428689909, + -2682348792572 / 7519795681897, + 12662868775082 / 11960479115383, + 3355817975965 / 11060851509271, + ] + ), + np.array( + [ + 647845179188 / 3216320057751, + 73281519250 / 8382639484533, + 552539513391 / 3454668386233, + 3354512671639 / 8306763924573, + 4040 / 17871, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([8611 / 62500, -1743 / 31250]), + np.array([5012029 / 34652500, -654441 / 2922500, 174375 / 388108]), + np.array( + [ + 15267082809 / 155376265600, + -71443401 / 120774400, + 730878875 / 902184768, + 2285395 / 8070912, + ] + ), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. + a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + np.array([1 - _c_ratio3, 0, 0, _c_ratio3]), # arbitrarily use linear interp. + np.array([1 - _c_ratio4, 0, 0, 0, _c_ratio4]), # also arbitrary linear interp. + ), +) + + +class _KenCarp4Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [ + 6818779379841 / 7100303317025, + -54480133 / 30881146, + 6943876665148 / 7220017795957, + ], + [0.0, 0.0, 0.0], + [ + 2173542590792 / 12501825683035, + -11436875 / 14766696, + 7640104374378 / 9702883013639, + ], + [ + -31592104683404 / 5083833661969, + 174696575 / 18121608, + -20649996744609 / 7521556579894, + ], + [ + 61146701046299 / 7138195549469, + -12120380 / 966161, + 8854892464581 / 2390941311638, + ], + [ + -17219254887155 / 4939391667607, + 3843 / 706, + -11397109935349 / 6675773540249, + ], + ] + ) + + +class KenCarp4(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 4/3 IMEX method. + + 4th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 3rd order method for adaptive step sizing. + Uses 6 stages. Uses 3rd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp4Interpolation + + def order(self, terms): + return 4 diff --git a/diffrax/solver/kencarp5.py b/diffrax/solver/kencarp5.py new file mode 100644 index 00000000..63e94780 --- /dev/null +++ b/diffrax/solver/kencarp5.py @@ -0,0 +1,231 @@ +import numpy as np + +from .base import AbstractImplicitSolver +from .kencarp3 import KenCarpInterpolation +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +_γ = 41 / 200 +_b_sol = np.array( + [ + -872700587467 / 9133579230613, + 0, + 0, + 22348218063261 / 9555858737531, + -1143369518992 / 8141816002931, + -39379526789629 / 19018526304540, + 32727382324388 / 42900044865799, + _γ, + ] +) +_b_sol_embedded = np.array( + [ + -975461918565 / 9796059967033, + 0, + 0, + 78070527104295 / 32432590147079, + -548382580838 / 3424219808633, + -33438840321285 / 15594753105479, + 3629800801594 / 4656183773603, + 4035322873751 / 18575991585200, + ] +) +_b_error = _b_sol - _b_sol_embedded +_c = np.array( + [ + 41 / 100, + 2935347310677 / 11292855782101, + 1426016391358 / 7196633302097, + 92 / 100, + 24 / 100, + 3 / 5, + 1.0, + ] +) +_c_ratio = _c[1] / _c[0] +_c_ratio2 = _c[2] / _c[0] +_c_ratio3 = _c[3] / _c[0] +_c_ratio4 = _c[4] / _c[1] +_c_ratio5 = _c[5] / _c[3] +_c_ratio6 = _c[6] / _c[3] + +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([41 / 100]), + np.array([367902744464 / 2072280473677, 677623207551 / 8224143866563]), + np.array([1268023523408 / 10340822734521, 0, 1029933939417 / 13636558850479]), + np.array( + [ + 14463281900351 / 6315353703477, + 0, + 66114435211212 / 5879490589093, + -54053170152839 / 4284798021562, + ] + ), + np.array( + [ + 14090043504691 / 34967701212078, + 0, + 15191511035443 / 11219624916014, + -18461159152457 / 12425892160975, + -281667163811 / 9011619295870, + ] + ), + np.array( + [ + 19230459214898 / 13134317526959, + 0, + 21275331358303 / 2942455364971, + -38145345988419 / 4862620318723, + -1 / 8, + -1 / 8, + ] + ), + np.array( + [ + -19977161125411 / 11928030595625, + 0, + -40795976796054 / 6384907823539, + 177454434618887 / 12078138498510, + 782672205425 / 8267701900261, + -69563011059811 / 9646580694205, + 7356628210526 / 4942186776405, + ] + ), + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, +) + +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([_γ]), + np.array([41 / 400, -567603406766 / 11931857230679]), + np.array([683785636431 / 9252920307686, 0, -110385047103 / 1367015193373]), + np.array( + [ + 3016520224154 / 10081342136671, + 0, + 30586259806659 / 12414158314087, + -22760509404356 / 11113319521817, + ] + ), + np.array( + [ + 218866479029 / 1489978393911, + 0, + 638256894668 / 5436446318841, + -1179710474555 / 5321154724896, + -60928119172 / 8023461067671, + ] + ), + np.array( + [ + 1020004230633 / 5715676835656, + 0, + 25762820946817 / 25263940353407, + -2161375909145 / 9755907335909, + -211217309593 / 5846859502534, + -4269925059573 / 7827059040719, + ] + ), + _b_sol[:-1], + ), + b_sol=_b_sol, + b_error=_b_error, + c=_c, + a_diagonal=np.array([0, _γ, _γ, _γ, _γ, _γ, _γ, _γ]), + # See + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + # for the construction of the a_predictor tableau, which is new here. + # They do also discuss this a little bit in Sections 2.1.7 and 3.2.2, but don't + # really pick any particular answer. + a_predictor=( + np.array([1.0]), + np.array([1 - _c_ratio, _c_ratio]), + np.array([1 - _c_ratio2, _c_ratio2, 0]), # c3 < c2 so use first two stages + np.array([1 - _c_ratio3, _c_ratio3, 0, 0]), # c4 < c2 also + np.array([1 - _c_ratio4, 0, _c_ratio4, 0, 0]), # c3≈c6 so use that + np.array([1 - _c_ratio5, 0, 0, 0, _c_ratio5, 0]), # arbitrary linear interp + np.array([1 - _c_ratio6, 0, 0, 0, _c_ratio6, 0, 0]), # arbitrary linear interp + ), +) + + +class _KenCarp5Interpolation(KenCarpInterpolation): + coeffs = np.array( + [ + [ + -9257016797708 / 5021505065439, + 43486358583215 / 12773830924787, + -17674230611817 / 10670229744614, + ], + [0, 0, 0], + [0, 0, 0], + [ + 26096422576131 / 11239449250142, + -91478233927265 / 11067650958493, + 65168852399939 / 7868540260826, + ], + [ + 92396832856987 / 20362823103730, + -79368583304911 / 10890268929626, + 15494834004392 / 5936557850923, + ], + [ + 30029262896817 / 10175596800299, + -12239297817655 / 9152339842473, + -99329723586156 / 26959484932159, + ], + [ + -26136350496073 / 3983972220547, + 115839755401235 / 10719374521269, + -19024464361622 / 5461577185407, + ], + [ + -5289405421727 / 3760307252460, + 5843115559534 / 2180450260947, + -6511271360970 / 6095937251113, + ], + ] + ) + + +class KenCarp5(AbstractRungeKutta, AbstractImplicitSolver): + """Kennedy--Carpenter's 5/4 IMEX method. + + 5th order ERK-ESDIRK implicit-explicit (IMEX) method. The implicit part is stiffly + accurate and A-L stable. Has an embedded 4th order method for adaptive step sizing. + Uses 8 stages. Uses 3rd order interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{kennedy2003additive, + title={Additive Runge--Kutta schemes for convection-diffusion-reaction + equations}, + author={Kennedy, Christopher A and Carpenter, Mark H}, + journal={Applied numerical mathematics}, + volume={44}, + number={1-2}, + pages={139--181}, + year={2003}, + publisher={Elsevier} + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.second_stage + interpolation_cls = _KenCarp5Interpolation + + def order(self, terms): + return 5 diff --git a/diffrax/solver/kvaerno3.py b/diffrax/solver/kvaerno3.py index 096a4939..cd767251 100644 --- a/diffrax/solver/kvaerno3.py +++ b/diffrax/solver/kvaerno3.py @@ -40,7 +40,8 @@ class Kvaerno3(AbstractESDIRK): r"""Kvaerno's 3/2 method. A-L stable stiffly accurate 3rd order ESDIRK method. Has an embedded 2nd order - method for adaptive step sizing. Uses 4 stages. + method for adaptive step sizing. Uses 4 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. ??? cite "Reference" diff --git a/diffrax/solver/kvaerno4.py b/diffrax/solver/kvaerno4.py index f5b15da7..e28088c6 100644 --- a/diffrax/solver/kvaerno4.py +++ b/diffrax/solver/kvaerno4.py @@ -78,7 +78,8 @@ class Kvaerno4(AbstractESDIRK): r"""Kvaerno's 4/3 method. A-L stable stiffly accurate 4th order ESDIRK method. Has an embedded 3rd order - method for adaptive step sizing. Uses 5 stages. + method for adaptive step sizing. Uses 5 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. When solving an ODE over the interval $[t_0, t_1]$, note that this method will make some evaluations slightly past $t_1$. diff --git a/diffrax/solver/kvaerno5.py b/diffrax/solver/kvaerno5.py index e8574613..0be7daab 100644 --- a/diffrax/solver/kvaerno5.py +++ b/diffrax/solver/kvaerno5.py @@ -84,7 +84,8 @@ class Kvaerno5(AbstractESDIRK): r"""Kvaerno's 5/4 method. A-L stable stiffly accurate 5th order ESDIRK method. Has an embedded 4th order - method for adaptive step sizing. Uses 7 stages. + method for adaptive step sizing. Uses 7 stages with FSAL. Uses 3rd order Hermite + interpolation for dense/ts output. When solving an ODE over the interval $[t_0, t_1]$, note that this method will make some evaluations slightly past $t_1$. diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py index ad6e99e1..b0f152d2 100644 --- a/diffrax/solver/leapfrog_midpoint.py +++ b/diffrax/solver/leapfrog_midpoint.py @@ -17,7 +17,8 @@ class LeapfrogMidpoint(AbstractSolver): r"""Leapfrog/midpoint method. - 2nd order linear multistep method. + 2nd order linear multistep method. Uses 1st order local linear interpolation for + dense/ts output. Note that this is referred to as the "leapfrog/midpoint method" as this is the name used by Shampine in the reference below. It should not be confused with any of the diff --git a/diffrax/solver/midpoint.py b/diffrax/solver/midpoint.py index 8a8b50fe..0da0b666 100644 --- a/diffrax/solver/midpoint.py +++ b/diffrax/solver/midpoint.py @@ -17,7 +17,8 @@ class Midpoint(AbstractERK, AbstractStratonovichSolver): """Midpoint method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. Uses 2nd order Hermite interpolation for dense/ts + output. Also sometimes known as the "modified Euler method". diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index 17bdf59b..e1daea85 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -28,7 +28,11 @@ class StratonovichMilstein(AbstractStratonovichSolver): r"""Milstein's method; Stratonovich version. - Used to solve SDEs, and converges to the Stratonovich solution. + Used to solve SDEs, and converges to the Stratonovich solution. Uses local linear + interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. !!! warning @@ -96,7 +100,11 @@ def func( class ItoMilstein(AbstractItoSolver): r"""Milstein's method; Itô version. - Used to solve SDEs, and converges to the Itô solution. + Used to solve SDEs, and converges to the Itô solution. Uses local linear + interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(drift_term, diffusion_term)`, where the + drift is an `ODETerm`. !!! warning @@ -134,7 +142,7 @@ def step( made_jump: Bool, ) -> Tuple[PyTree, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: del solver_state, made_jump - drift, diffusion = terms + drift, diffusion = terms.terms Δt = drift.contr(t0, t1) Δw = diffusion.contr(t0, t1) diff --git a/diffrax/solver/ralston.py b/diffrax/solver/ralston.py index be3321b9..dda31d6d 100644 --- a/diffrax/solver/ralston.py +++ b/diffrax/solver/ralston.py @@ -26,7 +26,7 @@ class Ralston(AbstractERK, AbstractStratonovichSolver): """Ralston's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive - step sizing. + step sizing. Uses 2 stages. Uses 2nd order Hermite interpolation for dense output. When used to solve SDEs, converges to the Stratonovich solution. """ diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py index cb337af8..d0d3d2d1 100644 --- a/diffrax/solver/reversible_heun.py +++ b/diffrax/solver/reversible_heun.py @@ -17,7 +17,7 @@ class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for - adaptive step sizing. + adaptive step sizing. Uses 1st order local linear interpolation for dense/ts output. When used to solve SDEs, converges to the Stratonovich solution. diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index d9bf408f..a1011a53 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -1,17 +1,18 @@ +import functools as ft from dataclasses import dataclass, field from typing import get_args, get_origin, Literal, Optional, Tuple, Union import equinox as eqx import equinox.internal as eqxi import jax +import jax.flatten_util as jfu import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu import numpy as np from equinox.internal import ω -from jaxtyping import Array, Bool, PyTree, Scalar -from ..custom_types import DenseInfo +from ..custom_types import Array, DenseInfo, PyTree, Scalar, sentinel from ..solution import is_okay, RESULTS, update_result from ..term import AbstractTerm, MultiTerm, ODETerm, WrapTerm from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot @@ -31,6 +32,7 @@ class ButcherTableau: # Implicit RK methods a_diagonal: Optional[np.ndarray] = None a_predictor: Optional[tuple[np.ndarray, ...]] = None + c1: float = 0.0 # Properties implied by the above tableaus, e.g. used to define fast-paths. ssal: bool = field(init=False) @@ -38,6 +40,53 @@ class ButcherTableau: implicit: bool = field(init=False) num_stages: int = field(init=False) + # Example! + # + # Consider a Butcher tableau: + # + # c1 | a11 a12 a13 a14 + # c2 | a21 a22 a23 a24 + # c3 | a31 a32 a33 a34 + # c4 | a41 a42 a43 a44 + # ---+---------------- + # | b1 b2 b3 b4 + # | β1 β2 β3 β4 + # + # Let y0 be the input to the step, and let y1 denote the output of the step. + # + # Then the output is computed via + # y1 = y0 + Σ_i bi ki + # where ki = fi dt (in the case of an ODE -- it is "fi dW" etc. for an SDE) + # and fi = f(ci, zi) + # and zi = y0 + Σ_j aij kj + # + # Note that "stage" may be used to refer to any of ki, fi, or zi. + # + # The error estimate is given by + # err = Σ_i βi ki + # (I.e. it is compute directly -- *not* as the difference of two solutions.) + # + # --- + # + # To encoder the above tableau in Diffrax, you would take: + # c = np.array([c2, c3, c4]) + # b_sol = np.array([b1, b2, b3, b4]) + # b_error = np.array([β1, β2, β3, β3]) + # a_lower = ( + # np.array([a21]), + # np.array([a31, a32]), + # np.array([a41, a42, a43]), + # ) + # a_diagonal = np.array([a11, a22, a33, a44]) # Optional if all zero + # c1 = c1 # Optional if zero + # + # Noting that a_diagonal and c1 are only used for implicit solvers, hence their + # optionality. + # + # In addition we support an additional `a_predictor` tableau for implicit solvers. + # This seems to be semi-new here; see + # https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ + def __post_init__(self): assert self.c.ndim == 1 for a_i in self.a_lower: @@ -70,17 +119,17 @@ def __post_init__(self): diagonal_b_sol_equal = self.b_sol[-1] == last_diagonal explicit_first_stage = self.a_diagonal is None or (self.a_diagonal[0] == 0) explicit_last_stage = self.a_diagonal is None or (self.a_diagonal[-1] == 0) - # Solution y1 is the same as the last stage + # (vector field)-control product `k1` is the same across first/last stages. object.__setattr__( self, - "ssal", - lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage, + "fsal", + lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage, ) - # Vector field - control product k1 is the same across first/last stages. + # Solution `y1` is the same as the last stage object.__setattr__( self, - "fsal", - lower_b_sol_equal and diagonal_b_sol_equal and explicit_first_stage, + "ssal", + lower_b_sol_equal and diagonal_b_sol_equal and explicit_last_stage, ) object.__setattr__(self, "implicit", self.a_diagonal is not None) object.__setattr__(self, "num_stages", len(self.b_sol)) @@ -117,7 +166,12 @@ def __post_init__(self): class MultiButcherTableau(eqx.Module): """Wraps multiple [`diffrax.ButcherTableau`][]s together. Used in some multi-tableau - solvers, like stochastic Runge--Kutta methods or IMEX methods. + solvers, like IMEX methods. + + !!! important + + This API is not stable, and deliberately undocumented. (The reason is that we + might yet adapt this to implement Stochastic Runge--Kutta methods.) """ tableaus: Tuple[ButcherTableau, ...] @@ -138,19 +192,24 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta): `never`: used for explicit Runga--Kutta methods. - `every_step`: the Jacobian is calculated once per step; in particular it is - calculated at the start of the step and re-used for every stage in the step. - Used for SDIRK and ESDIRK methods. - `every_stage`: the Jacobian is calculated once per stage. Used for DIRK methods. + + `first_stage`: the Jacobian is calculated once per step; in particular it is + calculated in the first stage and re-used for every subsequent stage in the + step. Used for SDIRK methods. + + `second_stage`: the Jacobian is calculated once per step; in particular it is + calculated in the second stage and re-used for every subsequent stage in the + step. Used for ESDIRK methods. """ never = "never" - every_step = "every_step" every_stage = "every_stage" + first_stage = "first_stage" + second_stage = "second_stage" -_SolverState = Optional[tuple[Bool[Scalar, ""], PyTree[Array]]] +_SolverState = Optional[tuple[Scalar, PyTree[Array]]] # TODO: examine termination criterion for Newton iteration @@ -161,8 +220,8 @@ class CalculateJacobian(metaclass=eqxi.ContainerMeta): def _implicit_relation_f(fi, nonlinear_solve_args): diagonal, vf, prod, ti, yi_partial, args, control = nonlinear_solve_args diff = ( - vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω - - fi**ω + fi**ω + - vf(ti, (yi_partial**ω + diagonal * prod(fi, control) ** ω).ω, args) ** ω ).ω return diff @@ -174,8 +233,8 @@ def _implicit_relation_k(ki, nonlinear_solve_args): # (Bearing in mind that our ki is dt times smaller than theirs.) diagonal, vf_prod, ti, yi_partial, args, control = nonlinear_solve_args diff = ( - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω - - ki**ω + ki**ω + - vf_prod(ti, (yi_partial**ω + diagonal * ki**ω).ω, args, control) ** ω ).ω return diff @@ -202,6 +261,19 @@ def _sum(*x): return total +def _filter_stop_gradient(x): + dynamic, static = eqx.partition(x, eqx.is_inexact_array) + dynamic = lax.stop_gradient(dynamic) + return eqx.combine(dynamic, static) + + +def _assert_same_structure(x, y): + x = jax.eval_shape(lambda: x) + y = jax.eval_shape(lambda: y) + x, y = jtu.tree_map(lambda a: (a.shape, a.dtype), (x, y)) + return eqx.tree_equal(x, y) is True + + class AbstractRungeKutta(AbstractAdaptiveSolver): """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit Runge--Kutta methods, which have a different computational structure.) @@ -216,7 +288,7 @@ class AbstractRungeKutta(AbstractAdaptiveSolver): instance of [`diffrax.CalculateJacobian`][]. """ - scan_kind: Union[None, Literal["lax"], Literal["checkpointed"]] = None + scan_kind: Union[None, Literal["lax", "checkpointed", "bounded"]] = None tableau: eqxi.AbstractClassVar[Union[ButcherTableau, MultiButcherTableau]] calculate_jacobian: eqxi.AbstractClassVar[CalculateJacobian] @@ -281,7 +353,7 @@ def init( _, fsal = self._common(terms, t0, t1, y0, args) if fsal: first_step = jnp.array(True) - f0 = sentinel = object() + f0 = sentinel if type(terms) is WrapTerm: # Privileged optimisations for some common cases _terms = terms.term @@ -309,686 +381,650 @@ def step( y0: PyTree, args: PyTree, solver_state: _SolverState, - made_jump: Bool, + made_jump: bool, ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: # - # Some Runge--Kutta methods have special structure that we can use to improve - # efficiency. + # Alright, settle in for what is probably the most advanced Runge-Kutta + # implementation on the planet. + # + # This is capable of handling all of: + # - Explicit Runge--Kutta methods (ERK) + # - Diagonal Implicit Runge--Kutta methods (DIRK) + # - Singular Diagonal Implicit Runge--Kutta methods (SDIRK) + # - Explicit Singular Diagonal Implicit Runge--Kutta methods (ESDIRK) + # - Implicit-Explicit Runge--Kutta methods (IMEX) + # + # In all cases it can handle applications to both ODEs and SDEs. + # Several of these are implicit methods. The latter two are multi-tableau + # methods. + # + # Both ODEs and SDEs: this is the usual innovation with Diffrax. We treat + # everything as a CDE against an arbitrary control. This also means we have a + # distinction between f-space (vector field values) and k-space + # ((vector field)-control products). + # + # Implicit methods: these all involve computing a Jacobian somewhere, and doing + # a root find. Any root finder can be used, although in practice the chord + # method is typical. Indeed it is common (SDIRK; ESDIRK) to reuse the Jacobian + # between stages. # - # The famous one is FSAL; "first same as last". That is, the final evaluation - # of the vector field on the previous step is the same as the first evaluation - # on the subsequent step. We can reuse it and save an evaluation. - # However note that this requires saving a vf evaluation, not a - # vf-control-product. (This comes up when we have a different control on the - # next step, e.g. as with adaptive step sizes, or with SDEs.) - # As such we disable FSAL if a vf is expensive and a vf-control-product is - # cheap. (The canonical example is the optimise-then-discretise adjoint SDE. - # For this SDE, the vf-control product is a vector-Jacobian product, which is - # notably cheaper than evaluating a full Jacobian.) + # Multi-tableau methods: these are cases where each term has a different + # tableau, and their stages are interleaved. This means that the y-value at + # which we evaluate each stage depends on the previous stages of all tableaus. + # Note that these shouldn't be confused with splitting methods, where typically + # we solve one term using one solver, and then another term using another + # solver, without interleaving the stages. (Splitting methods instead interleave + # steps.) # - # Next we have SSAL; "solution same as last". That is, the output of the step - # has already been calculated during the internal stage calculations. We can - # reuse those and save a dot product. + # The other main innovation here (besides the unification of all these different + # solvers) is a JAX-specific thing: getting all of these to compile efficiently, + # with some tricks to trace through the vector field as few times as possible. # - # Finally we have a choice whether to save and work with vector field - # evaluations (fs), or to save and work with (vector field)-control products - # (ks). - # The former is needed for implicit FSAL solvers: they need to obtain the - # final f1 for the FSAL property, which means they need to do the implicit - # solve in vf-space rather than (vf-control-product)-space, which means they - # need to use `fs` to predict the initial point for the root finding operation. - # Meanwhile the latter is needed when solving optimise-then-discretise adjoint - # SDEs, for which vector field evaluations are prohibitively expensive, and we - # must necessarily work only with the (much cheaper) vf-control-products. (In - # this case this is the difference between computing a Jacobian and computing a - # vector-Jacobian product.) - # For other problems, we choose to use `ks`. This doesn't have a strong - # rationale although it does have some minor efficiency points in its favour, - # e.g. we need `ks` to perform dense interpolation if needed. + # As usual with JAX (and with a sprinkle of Equinox innovations), everything is + # also autovectorisable and autodifferentiable. + # + # This *doesn't* handle Fully Implicit Runge--Kutta methods (FIRK), as those + # have a different computational structure (they're just one big nonlinear + # solve). + # + # This also doesn't (yet) handle Stochastic Runge--Kutta methods (SRK), as those + # still require a bit more infrastructure: generating space-time Levy areas, or + # even space-space Levy areas. # - is_vf_expensive, fsal = self._common(terms, t0, t1, y0, args) + vf_expensive, fsal = self._common(terms, t0, t1, y0, args) - # The code below is actually quite generic: it handles a pytree of Butcher - # tableaus and a pytree of terms. - # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive. - # Here we just unpack from one to the other. + # The code below is actually quite generic: it handles a PyTree of Butcher + # tableaus and a PyTree of terms. (Which must match each other.) + # Our MultiTerm/MultiButcherTableau interface is slightly more restrictive, in + # that it only admits PyTree structures of `*` or `(*, ...)`. if isinstance(self.tableau, ButcherTableau): assert isinstance(terms, AbstractTerm) tableaus = self.tableau + implicit_tableau = self.tableau if self.tableau.implicit else None + implicit_term = terms if self.tableau.implicit else None else: assert isinstance(terms, MultiTerm) tableaus = self.tableau.tableaus terms = terms.terms - + assert len(tableaus) == len(terms) + for tab, term in zip(tableaus, terms): + if tab.implicit: + implicit_tableau = tab + implicit_term = term + break + else: + implicit_tableau = None + implicit_term = None assert jtu.tree_structure(terms, is_leaf=_is_term) == jtu.tree_structure( tableaus ) + # + # We have a choice whether to evaluate `vf` to get vector field evaluations + # ("values in f-space"), or to evaluate `vf_prod` to get (vector field)-control + # products ("values in k-space"). + # + # In addition we have a choice whether to *store* fs or ks. If we evaluate + # `vf_prod` then we must store ks, as we can't (cheaply) reconstruct fs from ks. + # If we evaluate `vf` then we can store either, as we can just do an + # `fs`-control product prior to storing them. + # + # The first most important case is if evaluating the vector field is expensive. + # The canonical example is solving optimise-then-discretise adjoint SDEs, for + # which the diffusion term takes the form (dg/dy)dW, which is a vjp against the + # control. This can be done most efficiently by never materialising the full + # diffusion matrix (the Jacobian dg/dy): don't call `vf`, and instead work + # directly with `vf_prod`. + # Cases of this nature are communicated via the `vf_expensive` flag. (Which + # in Diffrax by default is applied to all AdjointTerms with vector controls.) + # - Verdict: eval_fs=False, store_fs=False + # + # If we don't hit the above case, we consider FSAL. + # For any FSAL solver, we must evaluate `vf`: we need the final `f1` to pass to + # the next step. (The control changes from step-to-step, so we cannot simply + # pass `k1`.) + # In addition if the solver has an implicit tableau, then we must store `fs`. + # This is because to get the final f1, we need to do the implicit solve in + # f-space, which means we need to store fs to predict the initial point for the + # root finding operation. + # - Verdict: eval_fs=True, store_fs=True. + # If the solver is explicit-only, then we can store either. We choose to store + # ks instead, as this is perhaps slightly more efficient: other downstream tasks + # like error estimates and dense information use ks rather than fs. + # - Verdict: eval_fs=True, store_fs=False + # + # For all other cases, we don't have any hard restrictions. It *may* be the case + # that a user-provided term has an overloaded `vf_prod` to be more efficient. + # (The canonical example is if `vf` is the product of two matrices and the + # control is a vector: it's usually cheaper to do `A @ (B @ dx)` rather than + # `(A @ B) @ dx`.) Moreover downstream tasks like error estimatess and dense + # information still use ks rather than fs. So we also use ks in this case. + # - Verdict: eval_fs=False, store_fs=False + # + if vf_expensive: + eval_fs = False + store_fs = False + assert not fsal # fsal is disabled in this case + elif fsal: + if implicit_tableau is None: + eval_fs = True + store_fs = False + else: + eval_fs = True + store_fs = True + else: + eval_fs = False + store_fs = False + if not eval_fs: + assert not store_fs + + # + # We have a lot of PyTrees of various structures floating around. Here are some + # helpers to map over each structure. + # + # Structure of `terms` and `tableaus`. - def t_map(fn, *trees): - def _fn(_, *_trees): - return fn(*_trees) + def t_map(fn, *trees, implicit_val=sentinel): + def _fn(tableau, *_trees): + if tableau.implicit and implicit_val is not sentinel: + return implicit_val + else: + return fn(*_trees) return jtu.tree_map(_fn, tableaus, *trees) - def t_leaves(tree): - return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))] - # Structure of `y` and `k`. - # (but not `f`, which can be arbitrary and different) - def s_map(fn, *trees): + def y_map(fn, *trees): def _fn(_, *_trees): return fn(*_trees) return jtu.tree_map(_fn, y0, *trees) - def ts_map(fn, *trees): - return t_map(lambda *_trees: s_map(fn, *_trees), *trees) + # Structure of `f`. Note that this is a suffix of `t_map`. + def f_map(fn, *trees): + def _fn(_, *_trees): + return fn(*_trees) + + assert f0 is not _unused + return jtu.tree_map(_fn, f0, *trees) + + def t_leaves(tree): + return [x.value for x in jtu.tree_leaves(t_map(_Leaf, tree))] + + def ty_map(fn, *trees): + return t_map(lambda *_trees: y_map(fn, *_trees), *trees) + + def get_implicit(xs): + def _get_implicit_impl(term, x): + nonlocal value + if term is implicit_term: + if value is sentinel: + value = x + else: + assert False + + value = sentinel + t_map(_get_implicit_impl, terms, xs) + assert value is not sentinel + return value - control = t_map(lambda term_i: term_i.contr(t0, t1), terms) dt = t1 - t0 + control = t_map(lambda term_i: term_i.contr(t0, t1), terms) + if implicit_tableau is None: + implicit_control = _unused + else: + implicit_control = get_implicit(control) - def vf(t, y): + def vf(t, y, *, implicit_val): + _assert_same_structure(y, y0) _vf = lambda term_i, t_i: term_i.vf(t_i, y, args) - return t_map(_vf, terms, t) + out = t_map(_vf, terms, t, implicit_val=implicit_val) + if f0 is not _unused: + _assert_same_structure(out, f0) + return out - def vf_prod(t, y): + def vf_prod(t, y, *, implicit_val): + _assert_same_structure(y, y0) _vf = lambda term_i, t_i, control_i: term_i.vf_prod(t_i, y, args, control_i) - return t_map(_vf, terms, t, control) + out = t_map(_vf, terms, t, control, implicit_val=implicit_val) + t_map(ft.partial(_assert_same_structure, y0), out) + return out def prod(f): + if f0 is not _unused: + _assert_same_structure(f, f0) _prod = lambda term_i, f_i, control_i: term_i.prod(f_i, control_i) - return t_map(_prod, terms, f, control) + out = t_map(_prod, terms, f, control) + t_map(ft.partial(_assert_same_structure, y0), out) + return out - num_stages = jtu.tree_leaves(tableaus)[0].num_stages + # + # Now get `f0` from an FSAL condition if possible. + # FSAL = first-same-as-last. It essentially refers to the last stage of the + # previous step only being used in error estimates, but not in advancing the + # solution. This means that it is also the value `vf(t0, y0)` in the this step. + # So provided our first stage is explicit (=necessarily just `vf(t0, y0)`) then + # we can skip evaluating our first stage. + # + # The only exception is on the very first step, or after a jump, in which case + # our stored value is invalid and must be (re-)computed. + # if fsal: assert solver_state is not None first_step, f0 = solver_state - stage_index = jnp.where(first_step, 0, 1) - # `made_jump` can be a tracer, hence the `is`. - if made_jump is False: - # Fast-path for compilation in the common case. - k0 = prod(f0) + eval_first_stage = eqxi.unvmap_any(first_step | made_jump) + init_stage_index = jnp.where(eval_first_stage, 0, 1) + # We do `fs.at[0].set(f0)` below. If we're actually going to evaluate the + # first stage, then zero out `f0` so that that is a no-op. + f0 = jtu.tree_map(lambda x: jnp.where(eval_first_stage, 0, x), f0) + if store_fs: + k0 = _unused else: - _t0 = t_map(lambda _: t0) - k0 = lax.cond(made_jump, lambda: vf_prod(_t0, y0), lambda: prod(f0)) - del _t0 + k0 = prod(f0) else: + # Non-FSAL solvers just iterate over all stages. f0 = _unused k0 = _unused - stage_index = 0 + init_stage_index = 0 del solver_state - # Must be initialised at zero as we do matmuls against the partially-filled - # array. - ks = t_map( - lambda: s_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0), - ) - if fsal: - ks = ts_map(lambda x, xs: xs.at[0].set(x), k0, ks) - - def embed_a_lower(tableau): - tableau_a_lower = np.zeros((num_stages, num_stages)) - for i, a_lower_i in enumerate(tableau.a_lower): - tableau_a_lower[i + 1, : i + 1] = a_lower_i - return jnp.asarray(tableau_a_lower) - - def embed_c(tableau): - tableau_c = np.zeros(num_stages) - tableau_c[1:] = tableau.c - return jnp.asarray(tableau_c) - - tableau_a_lower = t_map(embed_a_lower, tableaus) - tableau_c = t_map(embed_c, tableaus) - - def cond_fun(val): - _stage_index, *_ = val - return _stage_index < num_stages - - def body_fun(val): - stage_index, _, _, _, ks = val - a_lower_i = t_map(lambda t: t[stage_index], tableau_a_lower) - c_i = t_map(lambda t: t[stage_index], tableau_c) - # Unwrap buffers. This is only valid (=correct under autodiff) because we - # follow a triangular pattern and don't read from a location before it's - # written to, or write to the same location twice. - # (The reads in the matmuls don't count, as we initialise at zero.) - unsafe_ks = ts_map(lambda x: x[...], ks) - increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks) - yi_partial = s_map(_sum, y0, *t_leaves(increment)) - # No floating point error - ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i) - if fsal: - assert not is_vf_expensive - fi = vf(ti, yi_partial) - ki = prod(fi) - else: - fi = _unused - ki = vf_prod(ti, yi_partial) - ks = ts_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks) - return stage_index + 1, yi_partial, increment, fi, ks - - def buffers(val): - _, _, _, _, ks = val - return ks - - init_val = (stage_index, y0, t_map(lambda: y0), f0, ks) - final_val = eqxi.while_loop( - cond_fun, - body_fun, - init_val, - max_steps=num_stages, - buffers=buffers, - kind="checkpointed" if self.scan_kind is None else self.scan_kind, - checkpoints=num_stages, - ) - _, y1_partial, increment, f1, ks = final_val - - if all(tableau.ssal for tableau in jtu.tree_leaves(tableaus)): - y1 = y1_partial - else: - increment = t_map( - lambda t, k, i: i if t.ssal else vector_tree_dot(t.b_sol, k), - tableaus, - ks, - increment, - ) - y1 = s_map(_sum, y0, *t_leaves(increment)) - y_error = t_map(lambda t, k: vector_tree_dot(t.b_error, k), tableaus, ks) - dense_info = dict(y0=y0, y1=y1, k=ks) - if fsal: - new_solver_state = False, f1 - else: - new_solver_state = None - result = RESULTS.successful - return y1, y_error, dense_info, new_solver_state, result - - def old_step( - self, - terms: AbstractTerm, - t0: Scalar, - t1: Scalar, - y0: PyTree, - args: PyTree, - solver_state: _SolverState, - made_jump: Bool, - ) -> tuple[PyTree, PyTree, DenseInfo, _SolverState, RESULTS]: - # - # Some Runge--Kutta methods have special structure that we can use to improve - # efficiency. - # - # The famous one is FSAL; "first same as last". That is, the final evaluation - # of the vector field on the previous step is the same as the first evaluation - # on the subsequent step. We can reuse it and save an evaluation. - # However note that this requires saving a vf evaluation, not a - # vf-control-product. (This comes up when we have a different control on the - # next step, e.g. as with adaptive step sizes, or with SDEs.) - # As such we disable FSAL if a vf is expensive and a vf-control-product is - # cheap. (The canonical example is the optimise-then-discretise adjoint SDE. - # For this SDE, the vf-control product is a vector-Jacobian product, which is - # notably cheaper than evaluating a full Jacobian.) - # - # Next we have SSAL; "solution same as last". That is, the output of the step - # has already been calculated during the internal stage calculations. We can - # reuse those and save a dot product. - # - # Finally we have a choice whether to save and work with vector field - # evaluations (fs), or to save and work with (vector field)-control products - # (ks). - # The former is needed for implicit FSAL solvers: they need to obtain the - # final f1 for the FSAL property, which means they need to do the implicit - # solve in vf-space rather than (vf-control-product)-space, which means they - # need to use `fs` to predict the initial point for the root finding operation. - # Meanwhile the latter is needed when solving optimise-then-discretise adjoint - # SDEs, for which vector field evaluations are prohibitively expensive, and we - # must necessarily work only with the (much cheaper) vf-control-products. (In - # this case this is the difference between computing a Jacobian and computing a - # vector-Jacobian product.) - # For other problems, we choose to use `ks`. This doesn't have a strong - # rationale although it does have some minor efficiency points in its favour, - # e.g. we need `ks` to perform dense interpolation if needed. - # - - implicit_first_stage = self.tableau.implicit and self.tableau.a_diagonal[0] != 0 - # If we're computing the Jacobian at the start of the step, then we - # need this as a linearisation point. # - # If the first stage is implicit, then we need this as a predictor for - # where to start iterating from. - need_f0_or_k0 = ( - self.calculate_jacobian == CalculateJacobian.every_step - or implicit_first_stage - ) - vf_expensive, fsal = self._common(terms, t0, t1, y0, args) - if self.tableau.implicit and fsal: - use_fs = True - elif vf_expensive: - use_fs = False - else: # Choice not as important here; we use ks for minor efficiency reasons. - use_fs = False - del vf_expensive - - control = terms.contr(t0, t1) - dt = t1 - t0 - + # If using a DIRK or SDIRK implicit solver: we need to pick the location (in + # f-space or k-space) at which to compute our first Jacobian. + # See: https://docs.kidger.site/diffrax/devdocs/predictor_dirk/#first-stage # - # Calculate `f0` and `k0`. If this is just a first explicit stage then we'll - # sort that out later. But we might need these values for something else too - # (as a predictor for implicit stages; as a linearisation point for a Jacobian). - # - - f0 = None - k0 = None - if fsal: - f0 = solver_state - if not use_fs: - # `made_jump` can be a tracer, hence the `is`. - if made_jump is False: - # Fast-path for compilation in the common case. - k0 = terms.prod(f0, control) - else: - k0 = lax.cond( - made_jump, - lambda: terms.vf_prod(t0, y0, args, control), - lambda: terms.prod(f0, control), # noqa: F821 - ) + if self.calculate_jacobian == CalculateJacobian.never: # Typically ERK methods + f0_for_jac = _unused + k0_for_jac = _unused else: - if need_f0_or_k0: - if use_fs: - f0 = terms.vf(t0, y0, args) + if fsal: # Typically ESDIRK methods. + f0_for_jac = _unused + k0_for_jac = _unused + else: # Typically DIRK or SDIRK methods. + # Sadness. The extra evaluation increases compilation time, as we must + # trace our vector field again. + if eval_fs: + f0_for_jac = implicit_term.vf(t0, y0, args) + k0_for_jac = _unused else: - k0 = terms.vf_prod(t0, y0, args, control) + f0_for_jac = _unused + k0_for_jac = implicit_term.vf_prod(t0, y0, args, implicit_control) + # ( + # Possible sneaky sadness-ameliorating ideas which we don't do here: + # 1. Construct a candidate f0 or k0 by combining the stages of the + # previous step. I don't know of any theory for this but it sounds + # reasonable. As above the exact value here isn't that important. + # 2. Add an extra explicit stage at the end of the previous step, to do + # the above `vf` or `vf_prod` evaluation for us (FSAL-like, although + # this would actually end up being SSAL). Note that if we implemented + # that as `lax.cond(implicit, nonlinear_solve, explict_step)` then we + # would get no compile-time speedup (the goal here) as both branches + # involve tracing the vector field. So we would have to + # unconditionally run the nonlinear solver -- which is bad for + # runtime performance. So we don't do this. + # ) # - # Calculate `jac_f` and `jac_k` (maybe). That is to say, the Jacobian for use - # throughout an implicit method. In practice this is for SDIRK and ESDIRK - # methods, which use the same Jacobian throughout every stage. + # Create the buffers we'll populate with our f- or k-evaluations. # - jac_f = None - jac_k = None - if self.calculate_jacobian == CalculateJacobian.every_step: - assert self.tableau.a_diagonal is not None - # Skipping the first element to account for ESDIRK methods. - assert all( - x == self.tableau.a_diagonal[1] for x in self.tableau.a_diagonal[2:] + num_stages = jtu.tree_leaves(tableaus)[0].num_stages + # Must be initialised at zero as we later do matmuls against the + # partially-filled arrays. + if store_fs: + assert f0 is not _unused + fs = f_map(lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), f0) + ks = _unused + else: + fs = _unused + ks = t_map( + lambda: y_map( + lambda x: jnp.zeros((num_stages,) + x.shape, x.dtype), y0 + ), ) - diagonal0 = self.tableau.a_diagonal[1] - if use_fs: - if y0 is not None: - assert f0 is not None - jac_f = self.nonlinear_solver.jac( - _implicit_relation_f, - f0, - (diagonal0, terms.vf, terms.prod, t0, y0, args, control), - ) - else: - if y0 is not None: - assert k0 is not None - jac_k = self.nonlinear_solver.jac( - _implicit_relation_k, - k0, - (diagonal0, terms.vf_prod, t0, y0, args, control), - ) - del diagonal0 - - # - # Allocate `fs` or `ks` as a place to store the stage evaluations. - # - - if use_fs or fsal: - if f0 is None: - # Only perform this trace if we have to; tracing can actually be - # a bit expensive. - f0_struct = eqx.filter_eval_shape(terms.vf, t0, y0, args) + if fsal: + # !!! This is only valid because: + # - On the very first step, or if we have a jump, then `f0` and `k0` are + # zero and this is a no-op; + # - On later steps we have `init_stage_index=1` and thus don't write to + # index 0. + # We recall that the `buffers` of + # `eqxi.while_loop(..., kind="checkpointed", buffers=...)` + # must not have the same location written to multiple times, as otherwise + # we will get incorrect gradients. + # Either way we are correctly following the principle of "only write once". + if store_fs: + fs = f_map(lambda x, xs: xs.at[0].set(x), f0, fs) else: - f0_struct = jax.eval_shape(lambda: f0) # noqa: F821 - # else f0_struct deliberately left undefined, and is unused. - - num_stages = self.tableau.num_stages - if use_fs: - fs = jtu.tree_map(lambda f: jnp.zeros((num_stages,) + f.shape), f0_struct) - ks = None - else: - fs = None - ks = jtu.tree_map(lambda k: jnp.zeros((num_stages,) + jnp.shape(k)), y0) + ks = ty_map(lambda x, xs: xs.at[0].set(x), k0, ks) # - # First stage. Defines `result`, `scan_first_stage`. Places `f0` and `k0` into - # `fs` and `ks`. (+Redefines them if it's an implicit first stage.) Consumes - # `f0` and `k0`. + # Transform our tableaus into full square tableaus. (Rather than just the + # triangular ones in which they're stored.) This is needed so that we can do + # matvecs against them, which can't be of variable length. + # (We could maybe implement a variable-length matvec by using a while loop -- + # not clear that that would necessarily get good performance though. Not + # benchmarked.) # - if fsal: - scan_first_stage = False - result = RESULTS.successful - else: - if implicit_first_stage: - scan_first_stage = False - assert self.tableau.a_diagonal is not None - diagonal0 = self.tableau.a_diagonal[0] - if self.tableau.a_diagonal[0] == 1: - # No floating point error - t0_ = t1 - else: - t0_ = t0 + self.tableau.a_diagonal[0] * dt - if use_fs: - if y0 is not None: - assert jac_f is not None - nonlinear_sol = self.nonlinear_solver( - _implicit_relation_f, - f0, - (diagonal0, terms.vf, terms.prod, t0_, y0, args, control), - jac_f, - ) - f0 = nonlinear_sol.root - result = nonlinear_sol.result - else: - if y0 is not None: - assert jac_k is not None - nonlinear_sol = self.nonlinear_solver( - _implicit_relation_k, - k0, - (diagonal0, terms.vf_prod, t0_, y0, args, control), - jac_k, - ) - k0 = nonlinear_sol.root - result = nonlinear_sol.result - del diagonal0, t0_, nonlinear_sol - else: - scan_first_stage = True - result = RESULTS.successful - - if scan_first_stage: - assert f0 is None - assert k0 is None - else: - if use_fs: - if y0 is not None: - assert f0 is not None - fs = ω(fs).at[0].set(ω(f0)).ω - else: - if y0 is not None: - assert k0 is not None - ks = ω(ks).at[0].set(ω(k0)).ω - - del f0, k0 + def embed_a_lower(tab): + tab_a_lower = np.zeros((num_stages, num_stages)) + for i, a_lower_i in enumerate(tab.a_lower): + tab_a_lower[i + 1, : i + 1] = a_lower_i + return jnp.asarray(tab_a_lower) + + def embed_c(tab): + tab_c = np.zeros(num_stages) + if tab.c1 is not None: + tab_c[0] = tab.c1 + tab_c[1:] = tab.c + return jnp.asarray(tab_c) + + tableaus_a_lower = t_map(embed_a_lower, tableaus) + tableaus_c = t_map(embed_c, tableaus) + + if implicit_tableau is not None: + implicit_diagonal = jnp.asarray(implicit_tableau.a_diagonal) + implicit_predictor = np.zeros((num_stages, num_stages)) + for i, a_predictor_i in enumerate(implicit_tableau.a_predictor): + implicit_predictor[i + 1, : i + 1] = a_predictor_i + implicit_predictor = jnp.asarray(implicit_predictor) + implicit_c = get_implicit(tableaus_c) # - # Iterate through the stages. Fills in `fs` and `ks`. Consumes - # `scan_first_stage`. + # Run the loop over stages. (This is what you signed up for, and it's taken us + # several hundred lines of code just to get this far!) # - def eval_stage(_carry, _input): - _, _, _fs, _ks, _result = _carry - _i, _a_lower_i, _a_diagonal_i, _a_predictor_i, _c_i = _input - # Unwrap buffers. Take advantage of the fact that they're initialised at - # zero, so that we don't really read from a location before its written to. - _unsafe_fs_unwrapped = jtu.tree_map(lambda _, x: x[...], fs, _fs) - _unsafe_ks_unwrapped = jtu.tree_map(lambda _, x: x[...], ks, _ks) + def cond_stage(val): + stage_index, *_ = val + return stage_index < num_stages + def rk_stage(val): + stage_index, _, _, jac_f, jac_k, fs, ks, result = val # - # Evaluate the linear combination of previous stages + # Start by getting the linear combination of previous stages. # - - if use_fs: - _increment = vector_tree_dot(_a_lower_i, _unsafe_fs_unwrapped) - _increment = terms.prod(_increment, control) + a_lower_i = t_map(lambda tab: tab[stage_index], tableaus_a_lower) + c_i = t_map(lambda tab: tab[stage_index], tableaus_c) + # Unwrap buffers. This is only valid (=correct under autodiff) because we + # follow a triangular pattern and don't read from a location before it is + # written to, or write to the same location twice. + # (The reads in the vector_tree_dots don't count, as the operands are zero.) + if store_fs: + assert fs is not _unused + unsafe_fs = f_map(lambda x: x[...], fs) + unsafe_ks = _unused + increment = prod(t_map(vector_tree_dot, a_lower_i, unsafe_fs)) else: - _increment = vector_tree_dot(_a_lower_i, _unsafe_ks_unwrapped) - _yi_partial = (y0**ω + _increment**ω).ω - + assert ks is not _unused + unsafe_fs = _unused + unsafe_ks = ty_map(lambda x: x[...], ks) + increment = t_map(vector_tree_dot, a_lower_i, unsafe_ks) + yi_partial = y_map(_sum, y0, *t_leaves(increment)) # - # Figure out if we're computing a vector field ("f") or a - # vector-field-product ("k") - # - # Ask for fi if we're using fs; ask for ki if we're using ks. Makes sense! - # In addition, ask for fi if we're using an FSAL scheme, as we'll be passing - # that on to the next step. + # Find the y value at which to evaluate this stage. + # If we have only explicit tableaus, then this is just the linear + # combination we found above. + # If we have an implicit tableau, then perform the implicit solve. + # Note that we perform the solve in f-space or k-space; not y-space. # + if implicit_tableau is None: + implicit_fi = sentinel + implicit_ki = sentinel + yi = yi_partial + else: + implicit_diagonal_i = implicit_diagonal[stage_index] + implicit_predictor_i = implicit_predictor[stage_index] + implicit_c_i = implicit_c[stage_index] + # No floating point error + implicit_ti = jnp.where(implicit_c_i == 1, t1, t0 + implicit_c_i * dt) + if_first_stage = ft.partial(jnp.where, stage_index == 0) + if eval_fs: + f_pred = get_implicit( + vector_tree_dot(implicit_predictor_i, unsafe_fs) + ) + if not fsal: + # FSAL => explicit first stage so the choice of predictor + # doesn't matter. + f_pred = jtu.tree_map(if_first_stage, f0_for_jac, f_pred) + f_implicit_args = ( + implicit_diagonal_i, + implicit_term.vf, + implicit_term.prod, + implicit_ti, + yi_partial, + args, + implicit_control, + ) + k_pred = _unused + k_implicit_args = _unused + else: + f_pred = _unused + f_implicit_args = _unused + k_pred = vector_tree_dot( + implicit_predictor_i, get_implicit(unsafe_ks) + ) + if not fsal: + # FSAL => explicit first stage so the choice of predictor + # doesn't matter. + k_pred = jtu.tree_map(if_first_stage, k0_for_jac, k_pred) + k_implicit_args = ( + implicit_diagonal_i, + implicit_term.vf_prod, + implicit_ti, + yi_partial, + args, + implicit_control, + ) - _return_fi = use_fs or fsal - _return_ki = not use_fs + def eval_f_jac(): + return self.nonlinear_solver.jac( + _implicit_relation_f, + lax.stop_gradient(f_pred), + _filter_stop_gradient(f_implicit_args), + ) - # - # Evaluate the stage - # + def eval_k_jac(): + return self.nonlinear_solver.jac( + _implicit_relation_k, + lax.stop_gradient(k_pred), + _filter_stop_gradient(k_implicit_args), + ) - _ti = jnp.where(_c_i == 1, t1, t0 + _c_i * dt) # No floating point error - if self.tableau.implicit: - assert _a_diagonal_i is not None - # Predictor for where to start iterating from - if _return_fi: - _f_pred = vector_tree_dot(_a_predictor_i, _unsafe_fs_unwrapped) - else: - _k_pred = vector_tree_dot(_a_predictor_i, _unsafe_ks_unwrapped) - # Determine Jacobian to use at this stage if self.calculate_jacobian == CalculateJacobian.every_stage: - if _return_fi: - _jac_f = self.nonlinear_solver.jac( - _implicit_relation_f, - _f_pred, - ( - _a_diagonal_i, - terms.vf, - terms.prod, - _ti, - _yi_partial, - args, - control, - ), - ) - _jac_k = None + if eval_fs: + jac_f = eval_f_jac() + jac_k = _unused else: - _jac_f = None - _jac_k = self.nonlinear_solver.jac( - _implicit_relation_k, - _k_pred, - ( - _a_diagonal_i, - terms.vf, - terms.prod, - _ti, - _yi_partial, - args, - control, - ), - ) + jac_f = _unused + jac_k = eval_k_jac() else: - assert self.calculate_jacobian == CalculateJacobian.every_step - _jac_f = jac_f - _jac_k = jac_k - # Solve nonlinear problem - if _return_fi: - if y0 is not None: - assert _jac_f is not None - _nonlinear_sol = self.nonlinear_solver( - _implicit_relation_f, - _f_pred, - ( - _a_diagonal_i, - terms.vf, - terms.prod, - _ti, - _yi_partial, - args, - control, - ), - _jac_f, - ) - _fi = _nonlinear_sol.root - if _return_ki: - _ki = terms.prod(_fi, control) + if self.calculate_jacobian == CalculateJacobian.first_stage: + assert len(set(implicit_tableau.a_diagonal)) == 1 + jac_stage_index = 0 else: - _ki = None - else: - if _return_ki: - if y0 is not None: - assert _jac_k is not None - _nonlinear_sol = self.nonlinear_solver( - _implicit_relation_k, - _k_pred, - ( - _a_diagonal_i, - terms.vf_prod, - _ti, - _yi_partial, - args, - control, - ), - _jac_k, + assert self.calculate_jacobian == CalculateJacobian.second_stage + assert implicit_tableau.a_diagonal[0] == 0 + assert len(set(implicit_tableau.a_diagonal[1:])) == 1 + jac_stage_index = 1 + stage_index = eqxi.nonbatchable(stage_index) + # These `stop_gradients` are needed to work around the lack of + # symbolic zeros in `custom_vjp`s. + if eval_fs: + jac_f = lax.stop_gradient(jac_f) + jac_f = lax.cond( + stage_index == jac_stage_index, eval_f_jac, lambda: jac_f ) - _fi = None - _ki = _nonlinear_sol.root + jac_k = _unused else: - assert False - _result = update_result(_result, _nonlinear_sol.result) - del _nonlinear_sol - else: - # Explicit stage - if _return_fi: - _fi = terms.vf(_ti, _yi_partial, args) - if _return_ki: - _ki = terms.prod(_fi, control) - else: - _ki = None + jac_f = _unused + jac_k = lax.stop_gradient(jac_k) + jac_k = lax.cond( + stage_index == jac_stage_index, eval_k_jac, lambda: jac_k + ) + if eval_fs: + jac_f = eqxi.nondifferentiable(jac_f, name="jac_f") + nonlinear_sol = self.nonlinear_solver( + _implicit_relation_f, f_pred, f_implicit_args, jac_f + ) + implicit_fi = nonlinear_sol.root + implicit_ki = _unused + implicit_inc = implicit_term.prod(implicit_fi, implicit_control) else: - _fi = None - if _return_ki: - _ki = terms.vf_prod(_ti, _yi_partial, args, control) - else: - assert False - + assert not fsal + jac_k = eqxi.nondifferentiable(jac_k, name="jac_k") + nonlinear_sol = self.nonlinear_solver( + _implicit_relation_k, k_pred, k_implicit_args, jac_k + ) + implicit_fi = _unused + implicit_ki = implicit_inc = nonlinear_sol.root + yi = y_map( + lambda a, b: a + implicit_diagonal_i * b, yi_partial, implicit_inc + ) + result = update_result(result, nonlinear_sol.result) # - # Store output + # Now evaluate our vector field at the value yi. + # If we had an implicit tableau then we can skip evaluating the vector field + # for that tableau, as we did the solve in f-space or k-space and already + # have its value. # - - if use_fs: - _fs = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _fi, _fs) - else: - _ks = jtu.tree_map(lambda x, xs: xs.at[_i].set(x), _ki, _ks) - if self.tableau.ssal: - _yi_partial_out = _yi_partial + # No floating point error + ti = t_map(lambda _c_i: jnp.where(_c_i == 1, t1, t0 + _c_i * dt), c_i) + if eval_fs: + assert not vf_expensive + assert implicit_fi is not _unused + fi = vf(ti, yi, implicit_val=implicit_fi) + if store_fs: + ki = _unused + else: + ki = prod(fi) else: - _yi_partial_out = None + assert implicit_ki is not _unused + assert not store_fs + fi = _unused + ki = vf_prod(ti, yi, implicit_val=implicit_ki) + # + # Update our outputs + # if fsal: - _fi_out = _fi + assert fi is not _unused + f1_for_fsal = fi else: - _fi_out = None - return (_yi_partial_out, _fi_out, _fs, _ks, _result), None + f1_for_fsal = _unused + if store_fs: + assert fi is not _unused + assert fs is not _unused + fs = f_map(lambda x, xs: xs.at[stage_index].set(x), fi, fs) + else: + assert ki is not _unused + assert ks is not _unused + ks = ty_map(lambda x, xs: xs.at[stage_index].set(x), ki, ks) + return ( + stage_index + 1, + yi, + f1_for_fsal, + jac_f, + jac_k, + fs, + ks, + result, + ) - # - # Iterate over stages - # + def buffers(val): + *_, fs, ks, _ = val + return fs, ks - if scan_first_stage: - tableau_a_lower = np.zeros((num_stages, num_stages)) - for i, a_lower_i in enumerate(self.tableau.a_lower): - tableau_a_lower[i + 1, : i + 1] = a_lower_i - tableau_a_diagonal = self.tableau.a_diagonal - tableau_a_predictor = self.tableau.a_predictor - tableau_c = np.zeros(num_stages) - tableau_c[1:] = self.tableau.c - i_init = 0 - assert tableau_a_diagonal is None - assert tableau_a_predictor is None - else: - tableau_a_lower = np.zeros((num_stages - 1, num_stages)) - for i, a_lower_i in enumerate(self.tableau.a_lower): - tableau_a_lower[i, : i + 1] = a_lower_i - if self.tableau.a_diagonal is None: - tableau_a_diagonal = None - else: - tableau_a_diagonal = self.tableau.a_diagonal[1:] - if self.tableau.a_predictor is None: - tableau_a_predictor = None - else: - tableau_a_predictor = np.zeros((num_stages - 1, num_stages)) - for i, a_predictor_i in enumerate(self.tableau.a_predictor): - tableau_a_predictor[i, : i + 1] = a_predictor_i - tableau_c = self.tableau.c - i_init = 1 - if self.tableau.ssal: - y_dummy = y0 - else: - y_dummy = None if fsal: - f_dummy = jtu.tree_map( - lambda x: jnp.zeros(x.shape, dtype=x.dtype), f0_struct - ) + assert f0 is not _unused + dummy_f = f0 else: - f_dummy = None - if self.scan_kind is None: - scan_kind = "checkpointed" + dummy_f = _unused + if self.calculate_jacobian == CalculateJacobian.never: + jac_f = _unused + jac_k = _unused else: - scan_kind = self.scan_kind - (y1_partial, f1, fs, ks, result), _ = eqxi.scan( - eval_stage, - (y_dummy, f_dummy, fs, ks, result), - ( - np.arange(i_init, num_stages), - tableau_a_lower, - tableau_a_diagonal, - tableau_a_predictor, - tableau_c, - ), - buffers=lambda x: (x[2], x[3]), # fs and ks - kind=scan_kind, - checkpoints="all", + # Set the initial Jacobian to be the identity matrix. + # For DIRK and SDIRK methods then the choice here doesn't matter; we compute + # the Jacobian straight away. + # For ESDIRK methods, this is the Jacobian of an explicit step. + # + # TODO: fix once we have more advanced nonlinear solvers. + # Mildly hacky hardcoding for now. + if eval_fs: + assert f0 is not _unused + struct = jax.eval_shape(lambda: jfu.ravel_pytree(get_implicit(f0))[0]) + jac_f = ( + jnp.eye(struct.size, dtype=struct.dtype), + jnp.arange(struct.size, dtype=jnp.int32), + ) + jac_k = _unused + else: + struct = jax.eval_shape(lambda: jfu.ravel_pytree(y0)[0]) + jac_f = _unused + jac_k = ( + jnp.eye(struct.size, dtype=struct.dtype), + jnp.arange(struct.size, dtype=jnp.int32), + ) + init_val = ( + init_stage_index, + y0, + dummy_f, + jac_f, + jac_k, + fs, + ks, + RESULTS.successful, + ) + # Needs to be an `eqxi.while_loop` as: + # (a) we may have variable length: e.g. an FSAL explicit RK scheme will have one + # more stage on the first step. + # (b) to work around a limitation of JAX's autodiff being unable to express + # "triangular computations" (every stage depends on all previous stages) + # without spurious copies. + final_val = eqxi.while_loop( + cond_stage, + rk_stage, + init_val, + max_steps=num_stages, + buffers=buffers, + kind="checkpointed" if self.scan_kind is None else self.scan_kind, + checkpoints=num_stages, + base=num_stages, ) - del y_dummy, f_dummy, scan_first_stage + _, y1, f1_for_fsal, _, _, fs, ks, result = final_val # - # Compute step output + # Calculate outputs: the final `y1` from our step, any dense information, etc. # - if self.tableau.ssal: - y1 = y1_partial - else: - if use_fs: - increment = vector_tree_dot(self.tableau.b_sol, fs) - increment = terms.prod(increment, control) + if store_fs: + assert ks == _unused + if fs is None: + # Handle edge-case of y0=None + ks = None else: - increment = vector_tree_dot(self.tableau.b_sol, ks) - y1 = (y0**ω + increment**ω).ω + ks = jax.vmap(prod)(fs) + if any(not tableau.ssal for tableau in jtu.tree_leaves(tableaus)): - # - # Compute error estimate - # + def _increment(tab_i, k_i): + return vector_tree_dot(tab_i.b_sol, k_i) - if use_fs: - y_error = vector_tree_dot(self.tableau.b_error, fs) - y_error = terms.prod(y_error, control) - else: - y_error = vector_tree_dot(self.tableau.b_error, ks) + increment = t_map(_increment, tableaus, ks) + y1 = y_map(_sum, y0, *t_leaves(increment)) + y_error = t_map(lambda tab, k: vector_tree_dot(tab.b_error, k), tableaus, ks) + y_error = y_map(_sum, *t_leaves(y_error)) y_error = jtu.tree_map( lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf), y_error, ) # i.e. an implicit step failed to converge - - # - # Compute dense info - # - - if use_fs: - if fs is None: - # Edge case for diffeqsolve(y0=None) - ks = None - else: - ks = jax.vmap(lambda f: terms.prod(f, control))(fs) dense_info = dict(y0=y0, y1=y1, k=ks) - - # - # Compute next solver state - # - if fsal: - solver_state = f1 + new_solver_state = False, f1_for_fsal else: - solver_state = None - - return y1, y_error, dense_info, solver_state, result + new_solver_state = None + return y1, y_error, dense_info, new_solver_state, result class AbstractERK(AbstractRungeKutta): @@ -1024,7 +1060,7 @@ def __init_subclass__(cls, **kwargs): diagonal = cls.tableau.a_diagonal[0] assert (cls.tableau.a_diagonal == diagonal).all() - calculate_jacobian = CalculateJacobian.every_step + calculate_jacobian = CalculateJacobian.second_stage class AbstractESDIRK(AbstractDIRK): @@ -1042,4 +1078,4 @@ def __init_subclass__(cls, **kwargs): diagonal = cls.tableau.a_diagonal[1] assert (cls.tableau.a_diagonal[1:] == diagonal).all() - calculate_jacobian = CalculateJacobian.every_step + calculate_jacobian = CalculateJacobian.second_stage diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py index e5eaa499..e0267b0d 100644 --- a/diffrax/solver/semi_implicit_euler.py +++ b/diffrax/solver/semi_implicit_euler.py @@ -16,7 +16,8 @@ class SemiImplicitEuler(AbstractSolver): """Semi-implicit Euler's method. - Symplectic method. Does not support adaptive step sizing. + Symplectic method. Does not support adaptive step sizing. Uses 1st order local + linear interpolation for dense/ts output. """ term_structure = (AbstractTerm, AbstractTerm) diff --git a/diffrax/solver/sil3.py b/diffrax/solver/sil3.py new file mode 100644 index 00000000..86f80993 --- /dev/null +++ b/diffrax/solver/sil3.py @@ -0,0 +1,86 @@ +import numpy as np +from equinox.internal import ω + +from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation +from .base import AbstractImplicitSolver +from .runge_kutta import ( + AbstractRungeKutta, + ButcherTableau, + CalculateJacobian, + MultiButcherTableau, +) + + +# See +# https://docs.kidger.site/diffrax/devdocs/predictor_dirk/ +# for the construction of the a_predictor tableau, which is new here. +_implicit_tableau = ButcherTableau( + a_lower=( + np.array([1 / 6]), + np.array([1 / 3, 0]), + np.array([3 / 8, 0, 3 / 8]), + ), + b_sol=np.array([3 / 8, 0, 3 / 8, 1 / 4]), + b_error=np.array( + [1 / 8, 0, -3 / 8, 1 / 4] + ), # just Heun; could maybe do something else + c=np.array([1 / 3, 2 / 3, 1]), + a_diagonal=np.array([0, 1 / 6, 1 / 3, 1 / 4]), + a_predictor=( + np.array([1.0]), + np.array([-1.0, 2.0]), + np.array([-1.0, 2.0, 0.0]), # arbitrary choice for this one + ), +) +_explicit_tableau = ButcherTableau( + a_lower=( + np.array([1 / 3]), + np.array([1 / 6, 0.5]), + np.array([0.5, -0.5, 1]), + ), + b_sol=np.array([0.5, -0.5, 1, 0]), + b_error=np.array([0, 0.5, -1, 0.5]), # just Heun; could maybe do something else + c=np.array([1 / 3, 2 / 3, 1]), +) + + +class Sil3(AbstractRungeKutta, AbstractImplicitSolver): + """Whitaker--Kar's fast-slow IMEX method. + + 3rd order in the explicit (ERK) term; 2nd order in the implicit (EDIRK) term. Uses + a 2nd-order embedded Heun method for adaptive step sizing. Uses 4 stages with FSAL. + Uses 2nd order Hermite interpolation for dense/ts output. + + This should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + + ??? Reference + + ```bibtex + @article{whitaker2013implicit, + author={Jeffrey S. Whitaker and Sajal K. Kar}, + title={Implicit–Explicit Runge–Kutta Methods for Fast–Slow Wave Problems}, + journal={Monthly Weather Review}, + year={2013}, + publisher={American Meteorological Society}, + volume={141}, + number={10}, + doi={https://doi.org/10.1175/MWR-D-13-00132.1}, + pages={3426--3434}, + } + ``` + """ + + tableau = MultiButcherTableau(_explicit_tableau, _implicit_tableau) + calculate_jacobian = CalculateJacobian.every_stage + + @staticmethod + def interpolation_cls(t0, t1, y0, y1, k): + k_explicit, k_implicit = k + k0 = (ω(k_explicit)[0] + ω(k_implicit)[0]).ω + k1 = (ω(k_explicit)[-1] + ω(k_implicit)[-1]).ω + return ThirdOrderHermitePolynomialInterpolation( + t0=t0, t1=t1, y0=y0, y1=y1, k0=k0, k1=k1 + ) + + def order(self, terms): + return 2 diff --git a/diffrax/solver/tsit5.py b/diffrax/solver/tsit5.py index 8322aad7..63fff33a 100644 --- a/diffrax/solver/tsit5.py +++ b/diffrax/solver/tsit5.py @@ -98,9 +98,14 @@ class _Tsit5Interpolation(AbstractLocalInterpolation): y0: PyTree[Array[...]] - y1: PyTree[Array[...]] # Unused, just here for API compatibility k: PyTree[Array["order":7, ...]] # noqa: F821 + def __init__(self, *, y0, y1, k, **kwargs): + del y1 # exists for API compatibility + super().__init__(**kwargs) + self.y0 = y0 + self.k = k + def evaluate( self, t0: Scalar, t1: Optional[Scalar] = None, left: bool = True ) -> PyTree: # noqa: F821 @@ -147,7 +152,8 @@ class Tsit5(AbstractERK): r"""Tsitouras' 5/4 method. 5th order explicit Runge--Kutta method. Has an embedded 4th order method for - adaptive step sizing. + adaptive step sizing. Uses 7 stages with FSAL. Uses 5th order interpolation + for dense/ts output. ??? cite "Reference" diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md index 23775db7..2942c989 100644 --- a/docs/api/solvers/abstract_solvers.md +++ b/docs/api/solvers/abstract_solvers.md @@ -81,11 +81,6 @@ In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use members: - __init__ -::: diffrax.MultiButcherTableau - selection: - members: - - __init__ - ::: diffrax.CalculateJacobian selection: members: false diff --git a/docs/api/solvers/ode_solvers.md b/docs/api/solvers/ode_solvers.md index 24c8e731..1e128694 100644 --- a/docs/api/solvers/ode_solvers.md +++ b/docs/api/solvers/ode_solvers.md @@ -72,6 +72,32 @@ Each of these takes a `nonlinear_solver` argument at initialisation, defaulting --- +### IMEX methods + +These "implicit-explicit" methods are suitable for problems of the form $\frac{\mathrm{d}y}{\mathrm{d}t} = f(t, y(t)) + g(t, y(t))$, where $f$ is the non-stiff part (explicit integration) and $g$ is the stiff part (implicit integration). + +??? info "Term structure" + + These methods should be called with `terms=MultiTerm(explicit_term, implicit_term)`. + +::: diffrax.Sil3 + selection: + members: false + +::: diffrax.KenCarp3 + selection: + members: false + +::: diffrax.KenCarp4 + selection: + members: false + +::: diffrax.KenCarp5 + selection: + members: false + +--- + ### Symplectic methods These methods are suitable for problems with symplectic structure; that is to say those ODEs of the form diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index 73aed4ce..713b4cc8 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -34,6 +34,10 @@ See also the [Stiff ODE example](../examples/stiff_ode.ipynb). - Taking many more solver steps than necessary (e.g. 8 steps -> 800 steps); - Wrapping with `jax.value_and_grad` or `jax.grad` actually changing the result of the primal (forward) computation. +### Split problems + +For "split stiffness" problems, with one term that is stiff and another term that is non-stiff, then IMEX methods are appropriate: [`diffrax.KenCarp4`][] is recommended. In addition you should almost always use an adaptive step size controller such as [`diffrax.PIDController`][]. + --- ## Stochastic differential equations diff --git a/test/helpers.py b/test/helpers.py index 4a5fa749..b4764ffe 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -25,6 +25,13 @@ diffrax.Kvaerno5(), ) +all_split_solvers = ( + diffrax.Sil3(), + diffrax.KenCarp3(), + diffrax.KenCarp4(), + diffrax.KenCarp5(), +) + def implicit_tol(solver): if isinstance(solver, diffrax.AbstractImplicitSolver): diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index 6c70d827..cfcce19f 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -1,5 +1,6 @@ import functools as ft import operator +from typing import Tuple import diffrax import jax @@ -8,7 +9,7 @@ import jax.tree_util as jtu import pytest -from .helpers import all_ode_solvers, implicit_tol, shaped_allclose +from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose @pytest.mark.parametrize("mode", ["linear", "linear2", "cubic"]) @@ -315,8 +316,18 @@ def _test(firstderiv, derivs, y0, y1): def _test_dense_interpolation(solver, key, t1): y0 = jrandom.uniform(key, (), minval=0.4, maxval=2) dt0 = t1 / 1e3 + if ( + solver.term_structure + == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]] + ): + term = diffrax.MultiTerm( + diffrax.ODETerm(lambda t, y, args: -0.7 * y), + diffrax.ODETerm(lambda t, y, args: -0.3 * y), + ) + else: + term = diffrax.ODETerm(lambda t, y, args: -y) sol = diffrax.diffeqsolve( - diffrax.ODETerm(lambda t, y, args: -y), + term, solver=solver, t0=0, t1=t1, @@ -334,7 +345,7 @@ def _test_dense_interpolation(solver, key, t1): return vals, true_vals, derivs, true_derivs -@pytest.mark.parametrize("solver", all_ode_solvers) +@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers) def test_dense_interpolation(solver, getkey): solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) @@ -360,7 +371,7 @@ def test_dense_interpolation(solver, getkey): # When vmap'ing then it can happen that some batch elements take more steps to solve # than others. This means some padding is used to make things line up; here we test # that all of this works as intended. -@pytest.mark.parametrize("solver", all_ode_solvers) +@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers) def test_dense_interpolation_vmap(solver, getkey): solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) diff --git a/test/test_integrate.py b/test/test_integrate.py index 30d7e74b..b55ee318 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -1,5 +1,6 @@ import math import operator +from typing import Tuple import diffrax import equinox as eqx @@ -13,6 +14,7 @@ from .helpers import ( all_ode_solvers, + all_split_solvers, implicit_tol, random_pytree, shaped_allclose, @@ -115,7 +117,7 @@ def f(t, y, args): assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2) -@pytest.mark.parametrize("solver", all_ode_solvers) +@pytest.mark.parametrize("solver", all_ode_solvers + all_split_solvers) def test_ode_order(solver): solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) @@ -123,10 +125,24 @@ def test_ode_order(solver): A = jrandom.normal(akey, (10, 10), dtype=jnp.float64) * 0.5 - def f(t, y, args): - return A @ y + if ( + solver.term_structure + == diffrax.MultiTerm[Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm]] + ): + + def f1(t, y, args): + return 0.3 * A @ y + + def f2(t, y, args): + return 0.7 * A @ y + + term = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2)) + else: + + def f(t, y, args): + return A @ y - term = diffrax.ODETerm(f) + term = diffrax.ODETerm(f) t0 = 0 t1 = 4 y0 = jrandom.normal(ykey, (10,), dtype=jnp.float64) diff --git a/test/test_interpolation.py b/test/test_interpolation.py index 2c280579..03113b3a 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import jax.random as jrandom -from .helpers import all_ode_solvers, implicit_tol, shaped_allclose +from .helpers import all_ode_solvers, all_split_solvers, implicit_tol, shaped_allclose def _test_path_derivative(path, name): @@ -69,6 +69,24 @@ def test_derivative(getkey): y1 = solution.ys[-1] paths.append((solution, type(solver).__name__, y0, y1)) + for solver in all_split_solvers: + solver = implicit_tol(solver) + y0 = jrandom.normal(getkey(), (3,)) + solution = diffrax.diffeqsolve( + diffrax.MultiTerm( + diffrax.ODETerm(lambda t, y, p: -0.7 * y), + diffrax.ODETerm(lambda t, y, p: -0.3 * y), + ), + solver, + 0, + 1, + 0.01, + y0, + saveat=diffrax.SaveAt(dense=True, t1=True), + ) + y1 = solution.ys[-1] + paths.append((solution, type(solver).__name__, y0, y1)) + # actually do tests for path, name, y0, y1 in paths: diff --git a/test/test_solver.py b/test/test_solver.py index 36d3b10e..ea161a09 100644 --- a/test/test_solver.py +++ b/test/test_solver.py @@ -1,9 +1,13 @@ +from typing import Tuple + import diffrax import equinox as eqx import jax.numpy as jnp import jax.random as jr import pytest +from .helpers import shaped_allclose + def test_half_solver(): term = diffrax.ODETerm(lambda t, y, args: -y) @@ -49,16 +53,60 @@ def test_implicit_euler_adaptive(): assert out2.result == diffrax.RESULTS.successful -def test_multiple_tableau1(): +@pytest.mark.parametrize("vf_expensive", (False, True)) +def test_multiple_tableau_single_step(vf_expensive): class DoubleDopri5(diffrax.AbstractRungeKutta): tableau = diffrax.MultiButcherTableau( diffrax.Dopri5.tableau, diffrax.Dopri5.tableau ) + interpolation_cls = None calculate_jacobian = diffrax.CalculateJacobian.never - def interpolation_cls(self, *, k, **kwargs): + mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0)) + mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1)) + term1 = diffrax.ODETerm(lambda t, y, args: mlp1(y)) + term2 = diffrax.ODETerm(lambda t, y, args: mlp2(y)) + terms = diffrax.MultiTerm(term1, term2) + solver1 = diffrax.Dopri5() + solver2 = DoubleDopri5() + t0 = 0.3 + t1 = 0.7 + y0 = jnp.array([1.0, 2.0]) + if vf_expensive: + # Huge hack, do this via subclassing AbstractTerm if you're going to do this + # properly! + object.__setattr__(terms, "is_vf_expensive", lambda t0, t1, y, args: True) + solver_state1 = None + solver_state2 = None + else: + solver_state1 = solver1.init(terms, t0, t1, y0, None) + solver_state2 = solver2.init(terms, t0, t1, y0, None) + out1 = solver1.step( + terms, t0, t1, y0, None, solver_state=solver_state1, made_jump=False + ) + out2 = solver2.step( + terms, t0, t1, y0, None, solver_state=solver_state2, made_jump=False + ) + out2[2]["k"] = out2[2]["k"][0] + out2[2]["k"][1] + assert shaped_allclose(out1, out2) + + +@pytest.mark.parametrize("adaptive", (True, False)) +def test_multiple_tableau1(adaptive): + class DoubleDopri5(diffrax.AbstractRungeKutta): + tableau = diffrax.MultiButcherTableau( + diffrax.Dopri5.tableau, diffrax.Dopri5.tableau + ) + calculate_jacobian = diffrax.CalculateJacobian.never + + @staticmethod + def interpolation_cls(**kwargs): + kwargs.pop("k") return diffrax.LocalLinearInterpolation(**kwargs) + def order(self, terms): + return 5 + mlp1 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(0)) mlp2 = eqx.nn.MLP(2, 2, 32, 1, key=jr.PRNGKey(1)) @@ -68,6 +116,10 @@ def interpolation_cls(self, *, k, **kwargs): t1 = 1 dt0 = 0.1 y0 = jnp.array([1.0, 2.0]) + if adaptive: + stepsize_controller = diffrax.PIDController(rtol=1e-3, atol=1e-6) + else: + stepsize_controller = diffrax.ConstantStepSize() out_a = diffrax.diffeqsolve( diffrax.MultiTerm(term1, term2), diffrax.Dopri5(), @@ -75,6 +127,7 @@ def interpolation_cls(self, *, k, **kwargs): t1, dt0, y0, + stepsize_controller=stepsize_controller, ) out_b = diffrax.diffeqsolve( diffrax.MultiTerm(term1, term2), @@ -83,6 +136,7 @@ def interpolation_cls(self, *, k, **kwargs): t1, dt0, y0, + stepsize_controller=stepsize_controller, ) assert jnp.allclose(out_a.ys, out_b.ys, rtol=1e-8, atol=1e-8) @@ -94,6 +148,7 @@ def interpolation_cls(self, *, k, **kwargs): t1, dt0, y0, + stepsize_controller=stepsize_controller, ) @@ -130,3 +185,272 @@ class Z(diffrax.AbstractRungeKutta): def interpolation_cls(self, *, k, **kwargs): return diffrax.LocalLinearInterpolation(**kwargs) + + +@pytest.mark.parametrize("implicit", (True, False)) +@pytest.mark.parametrize("vf_expensive", (True, False)) +@pytest.mark.parametrize("adaptive", (True, False)) +def test_everything_pytree(implicit, vf_expensive, adaptive): + class Term(diffrax.AbstractTerm): + coeff: float + + def vf(self, t, y, args): + return {"f": -self.coeff * y["y"]} + + def contr(self, t0, t1): + return {"t": t1 - t0} + + def prod(self, vf, control): + return {"y": vf["f"] * control["t"]} + + def is_vf_expensive(self, t0, t1, y, args): + return vf_expensive + + term = diffrax.MultiTerm(Term(0.3), Term(0.7)) + + if implicit: + tableau_ = diffrax.Kvaerno5.tableau + calculate_jacobian_ = diffrax.CalculateJacobian.second_stage + else: + tableau_ = diffrax.Dopri5.tableau + calculate_jacobian_ = diffrax.CalculateJacobian.never + + class DoubleSolver(diffrax.AbstractRungeKutta): + tableau = diffrax.MultiButcherTableau(diffrax.Dopri5.tableau, tableau_) + calculate_jacobian = calculate_jacobian_ + if implicit: + nonlinear_solver = diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-3) + + @staticmethod + def interpolation_cls(*, t0, t1, y0, y1, k): + k_left, k_right = k + k = {"y": k_left["y"] + k_right["y"]} + return diffrax.solver.dopri5._Dopri5Interpolation( + t0=t0, t1=t1, y0=y0, y1=y1, k=k + ) + + def order(self, terms): + return 5 + + solver = DoubleSolver() + t0 = 0.4 + t1 = 0.9 + dt0 = 0.0007 + y0 = {"y": jnp.array([[1.0, 2.0], [3.0, 4.0]])} + saveat = diffrax.SaveAt(ts=jnp.linspace(t0, t1, 23)) + if adaptive: + stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10) + else: + stepsize_controller = diffrax.ConstantStepSize() + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + saveat=saveat, + stepsize_controller=stepsize_controller, + ) + true_sol = diffrax.diffeqsolve( + diffrax.ODETerm(lambda t, y, args: {"y": -y["y"]}), + diffrax.Dopri5(), + t0, + t1, + dt0, + y0, + saveat=saveat, + stepsize_controller=stepsize_controller, + ) + if implicit: + tol = 1e-4 # same ODE but different solver + else: + tol = 1e-8 # should be exact same numerics, up to floating point weirdness + assert shaped_allclose(sol.ys, true_sol.ys, rtol=tol, atol=tol) + + +# Essentially used as a check that our general IMEX implementation is correct. +def test_sil3(): + class ReferenceSil3(diffrax.AbstractImplicitSolver): + term_structure = diffrax.MultiTerm[ + Tuple[diffrax.AbstractTerm, diffrax.AbstractTerm] + ] + interpolation_cls = diffrax.LocalLinearInterpolation + + def order(self, terms): + return 2 + + def init(self, terms, t0, t1, y0, args): + return None + + def func(self, terms, t, y, args): + assert False + + def step(self, terms, t0, t1, y0, args, solver_state, made_jump): + del solver_state, made_jump + explicit, implicit = terms.terms + dt = t1 - t0 + ex_vf_prod = lambda t, y: explicit.vf(t, y, args) * dt + im_vf_prod = lambda t, y: implicit.vf(t, y, args) * dt + fs = [] + gs = [] + + # first stage is explicit + fs.append(ex_vf_prod(t0, y0)) + gs.append(im_vf_prod(t0, y0)) + + def _second_stage(ya, _): + [f0] = fs + [g0] = gs + g1 = im_vf_prod(ta, ya) + return ya - (y0 + (1 / 3) * f0 + (1 / 6) * g0 + (1 / 6) * g1) + + ta = t0 + (1 / 3) * dt + ya = self.nonlinear_solver(_second_stage, y0, None).root + fs.append(ex_vf_prod(ta, ya)) + gs.append(im_vf_prod(ta, ya)) + + def _third_stage(yb, _): + [f0, f1] = fs + [g0, g1] = gs + g2 = im_vf_prod(tb, yb) + return yb - ( + y0 + (1 / 6) * f0 + (1 / 2) * f1 + (1 / 3) * g0 + (1 / 3) * g2 + ) + + tb = t0 + (2 / 3) * dt + yb = self.nonlinear_solver(_third_stage, ya, None).root + fs.append(ex_vf_prod(tb, yb)) + gs.append(im_vf_prod(tb, yb)) + + def _fourth_stage(yc, _): + [f0, f1, f2] = fs + [g0, g1, g2] = gs + g3 = im_vf_prod(tc, yc) + return yc - ( + y0 + + (1 / 2) * f0 + + (-1 / 2) * f1 + + f2 + + (3 / 8) * g0 + + (3 / 8) * g2 + + (1 / 4) * g3 + ) + + tc = t1 + yc = self.nonlinear_solver(_fourth_stage, yb, None).root + fs.append(ex_vf_prod(tc, yc)) + gs.append(im_vf_prod(tc, yc)) + + [f0, f1, f2, f3] = fs + [g0, g1, g2, g3] = gs + y1 = ( + y0 + + (1 / 2) * f0 + - (1 / 2) * f1 + + f2 + + (3 / 8) * g0 + + (3 / 8) * g2 + + (1 / 4) * g3 + ) + + # Use Heun as the embedded method. + y_error = y0 + 0.5 * (f0 + g0 + f3 + g3) - y1 + ks = (jnp.stack(fs), jnp.stack(gs)) + dense_info = dict(y0=y0, y1=y1, k=ks) + state = (False, (f3 / dt, g3 / dt)) + return y1, y_error, dense_info, state, jnp.array(diffrax.RESULTS.successful) + + reference_solver = ReferenceSil3( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8) + ) + solver = diffrax.Sil3( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-8, atol=1e-8) + ) + + key = jr.PRNGKey(5678) + mlpkey1, mlpkey2, ykey = jr.split(key, 3) + + mlp1 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey1) + mlp2 = eqx.nn.MLP(3, 2, 8, 1, key=mlpkey2) + + def f1(t, y, args): + y = jnp.concatenate([t[None], y]) + return mlp1(y) + + def f2(t, y, args): + y = jnp.concatenate([t[None], y]) + return mlp2(y) + + terms = diffrax.MultiTerm(diffrax.ODETerm(f1), diffrax.ODETerm(f2)) + t0 = jnp.array(0.3) + t1 = jnp.array(1.5) + y0 = jr.normal(ykey, (2,), dtype=jnp.float64) + args = None + + state = solver.init(terms, t0, t1, y0, args) + out = solver.step(terms, t0, t1, y0, args, solver_state=state, made_jump=False) + reference_out = reference_solver.step( + terms, t0, t1, y0, args, solver_state=None, made_jump=False + ) + assert shaped_allclose(out, reference_out) + + +# Honestly not sure how meaningful this test is -- Rober isn't *that* stiff. +# In fact, even Heun will get the correct answer with the tolerances we specify! +@pytest.mark.parametrize( + "solver", + ( + diffrax.Kvaerno3(), + diffrax.Kvaerno4(), + diffrax.Kvaerno5(), + diffrax.KenCarp3(), + diffrax.KenCarp4(), + diffrax.KenCarp5(), + ), +) +def test_rober(solver): + def rober(t, y, args): + y0, y1, y2 = y + k1 = 0.04 + k2 = 3e7 + k3 = 1e4 + f0 = -k1 * y0 + k3 * y1 * y2 + f1 = k1 * y0 - k2 * y1**2 - k3 * y1 * y2 + f2 = k2 * y1**2 + return jnp.stack([f0, f1, f2]) + + term = diffrax.ODETerm(rober) + if solver.__class__.__name__.startswith("KenCarp"): + term = diffrax.MultiTerm(diffrax.ODETerm(lambda t, y, args: 0), term) + t0 = 0 + t1 = 100 + y0 = jnp.array([1.0, 0, 0]) + dt0 = 0.0002 + saveat = diffrax.SaveAt(ts=jnp.array([0.0, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2])) + stepsize_controller = diffrax.PIDController(rtol=1e-10, atol=1e-10) + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + saveat=saveat, + stepsize_controller=stepsize_controller, + max_steps=None, + ) + # Obtained using Kvaerno5 with rtol,atol=1e-20 + true_ys = jnp.array( + [ + [1.0000000000000000e00, 0.0000000000000000e00, 0.0000000000000000e00], + [9.9999600000801137e-01, 3.9840684637775332e-06, 1.5923523513217297e-08], + [9.9996000156321818e-01, 2.9169034944881154e-05, 1.0829401837965007e-05], + [9.9960068268829505e-01, 3.6450478878442643e-05, 3.6286683282835678e-04], + [9.9607774744245892e-01, 3.5804372350422432e-05, 3.8864481851928275e-03], + [9.6645973733301294e-01, 3.0746265785786866e-05, 3.3509516401211095e-02], + [8.4136992384147014e-01, 1.6233909379904643e-05, 1.5861384224914774e-01], + [6.1723488239606716e-01, 6.1535912746388841e-06, 3.8275896401264059e-01], + ] + ) + assert jnp.allclose(sol.ys, true_ys, rtol=1e-3, atol=1e-8) From 77b1a6024921ce8f4bce2248ac884f5024a7e592 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 21 May 2023 22:44:59 -0700 Subject: [PATCH 7/7] Bump minimum version of dependencies --- .github/workflows/build_docs.yml | 2 +- .github/workflows/release.yml | 2 +- .github/workflows/run_tests.yml | 2 +- README.md | 2 +- docs/index.md | 2 +- setup.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/build_docs.yml b/.github/workflows/build_docs.yml index b20b4864..8fdf2492 100644 --- a/.github/workflows/build_docs.yml +++ b/.github/workflows/build_docs.yml @@ -9,7 +9,7 @@ jobs: build: strategy: matrix: - python-version: [ 3.8 ] + python-version: [ 3.11 ] os: [ ubuntu-latest ] runs-on: ${{ matrix.os }} steps: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 82573314..c1cbe8cf 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -12,7 +12,7 @@ jobs: - name: Release uses: patrick-kidger/action_update_python_project@v1 with: - python-version: "3.8" + python-version: "3.11" test-script: | python -m pip install pytest psutil jax jaxlib equinox scipy optax cp -r ${{ github.workspace }}/test ./test diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 935e17f6..5184ed84 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -7,7 +7,7 @@ jobs: run-tests: strategy: matrix: - python-version: [ 3.8, 3.9 ] + python-version: [ 3.9, 3.11 ] os: [ ubuntu-latest ] fail-fast: false runs-on: ${{ matrix.os }} diff --git a/README.md b/README.md index 48fcb2ca..b2797403 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+. +Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+. ## Documentation diff --git a/docs/index.md b/docs/index.md index 52dd17af..8639d695 100644 --- a/docs/index.md +++ b/docs/index.md @@ -20,7 +20,7 @@ _From a technical point of view, the internal structure of the library is pretty pip install diffrax ``` -Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+. +Requires Python 3.9+, JAX 0.4.4+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.4+. ## Quick example diff --git a/setup.py b/setup.py index b51839f6..4bc09613 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ "Topic :: Scientific/Engineering :: Mathematics", ] -python_requires = "~=3.8" +python_requires = "~=3.9" install_requires = ["jax>=0.4.3", "equinox>=0.10.4"]