Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.6.1 release #534

Merged
merged 14 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/build_docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
mkdocs build # twice, see https://github.com/patrick-kidger/pytkdocs_tweaks

- name: Upload docs
uses: actions/upload-artifact@v2
uses: actions/upload-artifact@v4
with:
name: docs
path: site # where `mkdocs build` puts the built site
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions]
6 changes: 6 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
AbstractDIRK as AbstractDIRK,
AbstractERK as AbstractERK,
AbstractESDIRK as AbstractESDIRK,
AbstractFosterLangevinSRK as AbstractFosterLangevinSRK,
AbstractImplicitSolver as AbstractImplicitSolver,
AbstractItoSolver as AbstractItoSolver,
AbstractRungeKutta as AbstractRungeKutta,
Expand All @@ -79,6 +80,7 @@
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
ALIGN as ALIGN,
Bosh3 as Bosh3,
ButcherTableau as ButcherTableau,
CalculateJacobian as CalculateJacobian,
Expand All @@ -100,11 +102,13 @@
LeapfrogMidpoint as LeapfrogMidpoint,
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
QUICSORT as QUICSORT,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
ShOULD as ShOULD,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
Expand All @@ -125,6 +129,8 @@
ControlTerm as ControlTerm,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
)

Expand Down
56 changes: 53 additions & 3 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

Expand Down Expand Up @@ -258,12 +259,10 @@ def _maybe_static(static_x: Optional[ArrayLike], x: ArrayLike) -> ArrayLike:
# Some values (made_jump and result) are not used in many common use-cases. If we
# detect that they're unused then we make sure they're non-Array Python values, so
# that we can special case on them at trace time and get a performance boost.
if isinstance(static_x, (bool, int, float, complex)):
if isinstance(static_x, (bool, int, float, complex, np.ndarray)):
return static_x
elif static_x is None:
return x
elif type(jax.core.get_aval(static_x)) is jax.core.ConcreteArray:
return static_x
else:
return x

Expand Down Expand Up @@ -776,9 +775,60 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_if_t0_equals_t1(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
    """Fill in `save_state` for the degenerate case t0 == t1.

    When t0 == t1 the integration loop is never entered, so any times requested
    via `subsaveat.ts` would otherwise be left unfilled. Here every requested
    save slot is simply filled with the (unchanged) state at `t0`.

    NOTE(review): this is a closure — it reads `t0`, `yfinal`, `args` and
    `max_steps` from the enclosing integration routine.
    """
    if subsaveat.ts is not None:
        # Number of "real" save slots: optional t0 slot, optional t1 slot
        # (t1 is saved via the steps buffer instead when `steps` is on),
        # plus one slot per requested intermediate time.
        out_size = 1 if subsaveat.t0 else 0
        out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0
        out_size += len(subsaveat.ts)

        def _make_ys(out, old_outs):
            # Every saved output is the same value, evaluated at t0.
            outs = jnp.stack([out] * out_size)
            if subsaveat.steps:
                # The per-step buffer is unused (no steps were taken); pad it
                # with inf, matching the convention for untaken steps.
                outs = jnp.concatenate(
                    [
                        outs,
                        jnp.full(
                            (max_steps,) + out.shape, jnp.inf, dtype=out.dtype
                        ),
                    ]
                )
            assert outs.shape == old_outs.shape
            return outs

        # All saved times are t0 (== t1); unused step slots are inf.
        ts = jnp.full(out_size, t0)
        if subsaveat.steps:
            ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf, dtype=ts.dtype)))
        assert ts.shape == save_state.ts.shape
        ys = jtu.tree_map(_make_ys, subsaveat.fn(t0, yfinal, args), save_state.ys)
        save_state = SaveState(
            saveat_ts_index=out_size,
            ts=ts,
            ys=ys,
            save_index=out_size,
        )
    # If no `ts` were requested then there is nothing to fix up.
    return save_state

save_state = jtu.tree_map(
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
)

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
eqxi.unvmap_any(t0 == t1),
lambda __save_state: jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_if_t0_equals_t1, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
__save_state,
),
lambda __save_state: __save_state,
save_state,
)

final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
)
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def __init__(
):
def _calculate(_y0, _y1, _k):
with jax.numpy_dtype_promotion("standard"):
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1)
_ymid = _y0 + jnp.tensordot(self.c_mid, _k, axes=1).astype(_y0.dtype)
_f0 = _k[0]
_f1 = _k[-1]
# TODO: rewrite as matrix-vector product?
Expand Down
9 changes: 3 additions & 6 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optimistix as optx
from jaxtyping import Array, ArrayLike, PyTree, Shaped

Expand Down Expand Up @@ -146,12 +147,8 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
# predicate is statically known.
# This in turn allows us to perform some trace-time optimisations that XLA isn't
# smart enough to do on its own.
if (
type(pred) is not bool
and type(jax.core.get_aval(pred)) is jax.core.ConcreteArray
):
with jax.ensure_compile_time_eval():
pred = pred.item()
if isinstance(pred, (np.ndarray, np.generic)) and pred.shape == ():
pred = pred.item()
if pred is True:
return a
elif pred is False:
Expand Down
4 changes: 4 additions & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .align import ALIGN as ALIGN
from .base import (
AbstractAdaptiveSolver as AbstractAdaptiveSolver,
AbstractImplicitSolver as AbstractImplicitSolver,
Expand All @@ -12,6 +13,7 @@
from .dopri8 import Dopri8 as Dopri8
from .euler import Euler as Euler
from .euler_heun import EulerHeun as EulerHeun
from .foster_langevin_srk import AbstractFosterLangevinSRK as AbstractFosterLangevinSRK
from .heun import Heun as Heun
from .implicit_euler import ImplicitEuler as ImplicitEuler
from .kencarp3 import KenCarp3 as KenCarp3
Expand All @@ -26,6 +28,7 @@
ItoMilstein as ItoMilstein,
StratonovichMilstein as StratonovichMilstein,
)
from .quicsort import QUICSORT as QUICSORT
from .ralston import Ralston as Ralston
from .reversible_heun import ReversibleHeun as ReversibleHeun
from .runge_kutta import (
Expand All @@ -42,6 +45,7 @@
from .semi_implicit_euler import SemiImplicitEuler as SemiImplicitEuler
from .shark import ShARK as ShARK
from .shark_general import GeneralShARK as GeneralShARK
from .should import ShOULD as ShOULD
from .sil3 import Sil3 as Sil3
from .slowrk import SlowRK as SlowRK
from .spark import SPaRK as SPaRK
Expand Down
191 changes: 191 additions & 0 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
AbstractSpaceTimeLevyArea,
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import (
UnderdampedLangevinLeaf,
UnderdampedLangevinTuple,
UnderdampedLangevinX,
)
from .foster_langevin_srk import (
AbstractCoeffs,
AbstractFosterLangevinSRK,
UnderdampedLangevinArgs,
)


# For an explanation of the coefficients, see foster_langevin_srk.py
class _ALIGNCoeffs(AbstractCoeffs):
    """Coefficient container used by the ALIGN solver.

    Each field is a PyTree of per-leaf coefficients (either computed directly
    or as stacked Taylor-expansion coefficients); `dtype` records the common
    result dtype of all coefficient leaves.
    """

    beta: PyTree[ArrayLike]
    a1: PyTree[ArrayLike]
    b1: PyTree[ArrayLike]
    aa: PyTree[ArrayLike]
    chh: PyTree[ArrayLike]
    dtype: jnp.dtype = eqx.field(static=True)

    def __init__(self, beta, a1, b1, aa, chh):
        self.beta = beta
        self.a1 = a1
        self.b1 = b1
        self.aa = aa
        self.chh = chh
        # The common dtype is derived once from every coefficient leaf.
        leaves = jtu.tree_leaves([beta, a1, b1, aa, chh])
        self.dtype = jnp.result_type(*leaves)


# The ALIGN local error estimate is a full (position, velocity) pair, i.e. an
# `UnderdampedLangevinTuple` (the position component is zero; see
# `_compute_step`).
_ErrorEstimate = UnderdampedLangevinTuple


class ALIGN(AbstractFosterLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
    r"""The Adaptive Langevin via Interpolated Gradients and Noise method
    designed by James Foster. This is a second order solver for the
    Underdamped Langevin Diffusion, and accepts terms of the form
    `MultiTerm(UnderdampedLangevinDriftTerm, UnderdampedLangevinDiffusionTerm)`.
    Uses two evaluations of the vector
    field per step, but is FSAL, so in practice it only requires one.

    ??? cite "Reference"

        This is a modification of the Strang-Splitting method from Definition 4.2 of

        ```bibtex
        @misc{foster2021shiftedode,
            title={The shifted ODE method for underdamped Langevin MCMC},
            author={James Foster and Terry Lyons and Harald Oberhauser},
            year={2021},
            eprint={2101.03446},
            archivePrefix={arXiv},
            primaryClass={math.NA},
            url={https://arxiv.org/abs/2101.03446},
        }
        ```

    """

    interpolation_cls = LocalLinearInterpolation
    # Requires a Brownian path providing space-time Lévy area (levy.W and
    # levy.H are both read in `_compute_step`).
    minimal_levy_area = AbstractSpaceTimeLevyArea
    taylor_threshold: float = eqx.field(static=True)
    # First-same-as-last: the final vector-field evaluation `f1` of one step is
    # returned and fed back in as `prev_f` on the next step.
    _is_fsal = True

    def __init__(self, taylor_threshold: float = 0.1):
        r"""**Arguments:**

        - `taylor_threshold`: If the product `h*gamma` is less than this, then
            the Taylor expansion will be used to compute the coefficients.
            Otherwise they will be computed directly. When using float32, the
            empirically optimal value is 0.1, and for float64 about 0.01.
        """
        self.taylor_threshold = taylor_threshold

    def order(self, terms):
        # ODE order of convergence.
        del terms
        return 2

    def strong_order(self, terms):
        # SDE strong order of convergence.
        del terms
        return 2.0

    def _directly_compute_coeffs_leaf(
        self, h: RealScalarLike, c: UnderdampedLangevinLeaf
    ) -> _ALIGNCoeffs:
        del self
        # c is a leaf of gamma
        # compute the coefficients directly (as opposed to via Taylor expansion)
        # NOTE: these closed forms divide by `al = c*h` and `c`, which is
        # numerically ill-conditioned for small `h*gamma` — that regime is
        # handled by the Taylor path instead (see `taylor_threshold`).
        al = c * h
        beta = jnp.exp(-al)
        a1 = (1 - beta) / c
        b1 = (beta + al - 1) / (c * al)
        aa = a1 / h

        al2 = al**2
        chh = 6 * (beta * (al + 2) + al - 2) / (al2 * c)

        return _ALIGNCoeffs(
            beta=beta,
            a1=a1,
            b1=b1,
            aa=aa,
            chh=chh,
        )

    def _tay_coeffs_single(self, c: UnderdampedLangevinLeaf) -> _ALIGNCoeffs:
        del self
        # c is a leaf of gamma
        zero = jnp.zeros_like(c)
        one = jnp.ones_like(c)
        c2 = jnp.square(c)
        c3 = c2 * c
        c4 = c3 * c
        c5 = c4 * c

        # Coefficients of the Taylor expansion, starting from 5th power
        # to 0th power. The descending power order is because of jnp.polyval
        beta = jnp.stack([-c5 / 120, c4 / 24, -c3 / 6, c2 / 2, -c, one], axis=-1)
        a1 = jnp.stack([c4 / 120, -c3 / 24, c2 / 6, -c / 2, one, zero], axis=-1)
        b1 = jnp.stack([c4 / 720, -c3 / 120, c2 / 24, -c / 6, one / 2, zero], axis=-1)
        aa = jnp.stack([-c5 / 720, c4 / 120, -c3 / 24, c2 / 6, -c / 2, one], axis=-1)
        chh = jnp.stack([c4 / 168, -c3 / 30, 3 * c2 / 20, -c / 2, one, zero], axis=-1)

        # Each stacked coefficient array gains a trailing axis of length 6
        # (one entry per polynomial coefficient).
        correct_shape = jnp.shape(c) + (6,)
        assert (
            beta.shape == a1.shape == b1.shape == aa.shape == chh.shape == correct_shape
        )

        return _ALIGNCoeffs(
            beta=beta,
            a1=a1,
            b1=b1,
            aa=aa,
            chh=chh,
        )

    def _compute_step(
        self,
        h: RealScalarLike,
        levy: AbstractSpaceTimeLevyArea,
        x0: UnderdampedLangevinX,
        v0: UnderdampedLangevinX,
        underdamped_langevin_args: UnderdampedLangevinArgs,
        coeffs: _ALIGNCoeffs,
        rho: UnderdampedLangevinX,
        prev_f: UnderdampedLangevinX,
    ) -> tuple[
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinX,
        UnderdampedLangevinTuple,
    ]:
        """Take one ALIGN step from (x0, v0) over step size h.

        Returns `(x1, v1, f1, error_estimate)`, where `f1` is the vector-field
        evaluation at `x1` (returned for FSAL reuse as the next step's
        `prev_f`) and `error_estimate` is a (position, velocity) pair.

        NOTE(review): `rho` is presumably the diffusion coefficient (it scales
        the noise terms `w` and `hh`) — confirm against
        `AbstractFosterLangevinSRK`.
        """
        # Cast the Brownian increment W and space-time Lévy area H to match
        # the dtypes of the corresponding leaves of x0.
        dtypes = jtu.tree_map(jnp.result_type, x0)
        w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
        hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

        # gamma: friction, u: mass-like scaling, f: the drift vector field.
        gamma, u, f = underdamped_langevin_args

        uh = (u**ω * h).ω
        # FSAL: reuse the vector field evaluated at the end of the last step.
        f0 = prev_f
        # Position update: deterministic drift plus noise (W and H) terms.
        x1 = (
            x0**ω
            + coeffs.a1**ω * v0**ω
            - coeffs.b1**ω * uh**ω * f0**ω
            + rho**ω * (coeffs.b1**ω * w**ω + coeffs.chh**ω * hh**ω)
        ).ω
        # Second (and only non-FSAL) vector field evaluation, at the new x1.
        f1 = f(x1)
        # Velocity update: exponential decay of v0, interpolated gradients
        # f0/f1, plus the noise terms.
        v1 = (
            coeffs.beta**ω * v0**ω
            - u**ω * ((coeffs.a1**ω - coeffs.b1**ω) * f0**ω + coeffs.b1**ω * f1**ω)
            + rho**ω * (coeffs.aa**ω * w**ω - gamma**ω * coeffs.chh**ω * hh**ω)
        ).ω

        # Local error estimate: zero in position; in velocity it is the
        # difference the step attributes to using f1 instead of f0.
        error_estimate = (
            jtu.tree_map(jnp.zeros_like, x0),
            (-(u**ω) * coeffs.b1**ω * (f1**ω - f0**ω)).ω,
        )

        return x1, v1, f1, error_estimate
Loading
Loading