Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Störmer-Verlet method + symplectic test changes #303

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
Ralston,
ReversibleHeun,
SemiImplicitEuler,
StormerVerlet,
Sil3,
StratonovichMilstein,
Tsit5,
Expand Down
1 change: 1 addition & 0 deletions diffrax/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,6 @@
MultiButcherTableau,
)
from .semi_implicit_euler import SemiImplicitEuler
from .stormer_verlet import StormerVerlet
from .sil3 import Sil3
from .tsit5 import Tsit5
78 changes: 78 additions & 0 deletions diffrax/solver/stormer_verlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from typing import Tuple

from equinox.internal import ω

from ..custom_types import Bool, DenseInfo, PyTree, Scalar
from ..local_interpolation import LocalLinearInterpolation
from ..solution import RESULTS
from ..term import AbstractTerm
from .base import AbstractSolver

_ErrorEstimate = None
_SolverState = None

class StormerVerlet(AbstractSolver):
""" Störmer-Verlet method.

Symplectic method. Does not support adaptive step sizing. Uses 1st order local
linear interpolation for dense/ts output.
"""

term_structure = (AbstractTerm, AbstractTerm)
interpolation_cls = LocalLinearInterpolation

def order(self, terms):
return 2

def init(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: PyTree,
args: PyTree,
) -> _SolverState:
return None

def step(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
t1: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree,
solver_state: _SolverState,
made_jump: Bool,
) -> Tuple[Tuple[PyTree, PyTree], _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

term_1, term_2 = terms
y0_1, y0_2 = y0
midpoint = (t1 + t0)/2

control1_half_1 = term_1.contr(t0, midpoint)
control1_half_2 = term_1.contr(midpoint, t1)
control2 = term_2.contr(t0, t1)

yhalf_1 = (y0_1 ** ω + term_1.vf_prod(t0, y0_2, args, control1_half_1) ** ω).ω
y1_2 = (y0_2 ** ω + term_2.vf_prod(midpoint, yhalf_1, args, control2) ** ω).ω
y1_1 = (yhalf_1 ** ω + term_1.vf_prod(t1, y1_2, args, control1_half_2 ** ω)).ω
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this just semi-implicit Euler written in kick-drift-kick form? (I.e. offset by half a step.) Justification: it looks to me like y1_2 on this step is y0_2 on the next step, so the term_1.vf_prod(...) evaluations happen at the same point twice. (Which also means that this is increasing runtime/compiletime.)

Copy link
Contributor Author

@packquickly packquickly Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is pretty subtle actually. long story short: the half step difference has an impact, and they are indeed different (you can check numerically, Störmer-Verlet is order 2, symplectic Euler is order 1.)

Störmer-Verlet is the composition of the symplectic euler method and it's adjoint (reverse method,) ie. it is both variants of symplectic Euler stacked with step-size $h/2$:

$$ \begin{aligned} q_{n + 1/2} &= q_n + \frac{h}{2} f(p_n) \\ p_{n + 1/2} &= p_n + \frac{h}{2} g(q_{n + 1/2}) \\ p_{n + 1} &= p_{n + 1/2} + \frac{h}{2} f(q_{n + 1/2}) \\ q_{n + 1} &= q_{n + 1/2} + \frac{h}{2} g(p_{n + 1}) \end{aligned} $$

The implementation in non kick-drift-kick form, ie.

$$ \begin{aligned} q_{n + 1/2} &= q_{n - 1/2} + h f(p_n)\\ p_{n + 1} &= p_n + h g(q_{n + 1/2}) \end{aligned} $$

is distinct from symplectic Euler primarily because of the initialization. Looking at the second-order diffeq case ($g(q) = q$) the initial $p_1$ for symplectic Euler is

$$ p_1 = p_0 + h q_0 + h^2 f(p_0)$$

and for Störmer-Verlet it's:

$$ p_1 = p_0 + hq_0 + \frac{h^2}{2} f(p_0).$$

This is pretty much the only difference for these two though, and the non kick-drift-kick (often called the leapfrog-Verlet implementation, ugh) is definitely the better one when we don't care about knowing $q_{n+1}$. However, I assumed we needed to be able to return the tuple $(p_{n+1}, q_{n+1})$ in each call to step, hence the more expensive kick-drift-kick implementation.

If you see a way to switch to the leapfrog-Verlet implementation though by all means I'll do that instead.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, you might find it interesting that Hairer wrote a long article about Störmer-Verlet as a precursor to "Geometric Numerical Integration," where he used the method to demonstrate a bunch of the ideas later expanded on in the book


y1 = (y1_1, y1_2)
dense_info = dict(y0=y0, y1=y1)
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: Tuple[AbstractTerm, AbstractTerm],
t0: Scalar,
y0: Tuple[PyTree, PyTree],
args: PyTree
) -> Tuple[PyTree, PyTree]:
term_1, term_2 = terms
y0_1, y0_2 = y0
f1 = term_1.func(t0, y0_2, args)
f2 = term_2.func(t0, y0_1, args)
return (f1, f2)


6 changes: 5 additions & 1 deletion test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import jax.random as jrandom
import jax.tree_util as jtu


all_ode_solvers = (
diffrax.Bosh3(),
diffrax.Dopri5(),
Expand All @@ -32,6 +31,11 @@
diffrax.KenCarp5(),
)

all_symplectic_solvers = (
diffrax.SemiImplicitEuler(),
diffrax.StormerVerlet(),
)


def implicit_tol(solver):
if isinstance(solver, diffrax.AbstractImplicitSolver):
Expand Down
59 changes: 57 additions & 2 deletions test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .helpers import (
all_ode_solvers,
all_split_solvers,
all_symplectic_solvers,
implicit_tol,
random_pytree,
shaped_allclose,
Expand Down Expand Up @@ -165,6 +166,59 @@ def f(t, y, args):
assert -0.9 < order - solver.order(term) < 0.9


@pytest.mark.parametrize("solver", all_symplectic_solvers)
def test_symplectic_ode_order(solver):
solver = implicit_tol(solver)
key = jrandom.PRNGKey(17)
p_key, q_key, k_key = jrandom.split(key, 3)
p0 = jrandom.uniform(p_key, shape=(), minval=0, maxval=1)
q0 = jrandom.uniform(q_key, shape=(), minval=0, maxval=1)
k = jrandom.uniform(k_key, shape=(), minval=0.1, maxval=10)
y0 = (p0, q0)
t0 = 0
t1 = 4

def p_vector_field(t, q, k):
return q

def q_vector_field(t, p, k):
return -k * p

def analytic_solution(t, k, p0, q0):
φ = jnp.sqrt(k)
p_t = p0 * jnp.cos(φ * t) + (q0/φ) * jnp.sin(φ * t)
q_t = -p0 * φ * jnp.sin(φ * t) + q0 * jnp.cos(φ * t)
return p_t, q_t


term = (
diffrax.ODETerm(p_vector_field),
diffrax.ODETerm(q_vector_field),
)

true_pT, true_qT = analytic_solution(t1, k, p0, q0)
exponents = []
errors_p = []
errors_q = []
for exponent in [0, -1, -2, -3, -4, -6, -8, -12]:
dt0 = 2**exponent
sol = diffrax.diffeqsolve(term, solver, t0, t1, dt0, y0, k, max_steps=None)
pT, qT = sol.ys
error_p = jnp.sum(jnp.abs(pT - true_pT))
error_q = jnp.sum(jnp.abs(qT - true_qT))
if error_p < 2**-28 and error_q < 2**-28:
break
exponents.append(exponent)
errors_p.append(jnp.log2(error_q))
errors_q.append(jnp.log2(error_q))

order_p = scipy.stats.linregress(exponents, errors_p). slope
order_q = scipy.stats.linregress(exponents, errors_q). slope
# Same wide range as for general ODE solvers, but we
# require this approximate order both for `p` and `q`
assert -0.9 < order_p - solver.order(term) < 0.9
assert -0.9 < order_q - solver.order(term) < 0.9

def _squareplus(x):
return 0.5 * (x + jnp.sqrt(x**2 + 4))

Expand Down Expand Up @@ -338,14 +392,15 @@ def f(t, y, args):
assert shaped_allclose(sol1.derivative(ti), -sol2.derivative(-ti))


def test_semi_implicit_euler():
@pytest.mark.parametrize("solver", all_symplectic_solvers)
def test_symplectic_solvers(solver):
term1 = diffrax.ODETerm(lambda t, y, args: -y)
term2 = diffrax.ODETerm(lambda t, y, args: y)
y0 = (1.0, -0.5)
dt0 = 0.00001
sol1 = diffrax.diffeqsolve(
(term1, term2),
diffrax.SemiImplicitEuler(),
solver,
0,
1,
dt0,
Expand Down