Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KL Divergence for Latent SDEs #463

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,13 @@
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
initialize_kl as initialize_kl,
ItoMilstein as ItoMilstein,
KenCarp3 as KenCarp3,
KenCarp4 as KenCarp4,
KenCarp5 as KenCarp5,
KLSolver as KLSolver,
KLState as KLState,
Kvaerno3 as Kvaerno3,
Kvaerno4 as Kvaerno4,
Kvaerno5 as Kvaerno5,
Expand Down
3 changes: 3 additions & 0 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
ItoMilstein,
StratonovichMilstein,
)
from ._solver.kl import KLState
from ._step_size_controller import (
AbstractAdaptiveStepSizeController,
AbstractStepSizeController,
Expand Down Expand Up @@ -124,6 +125,8 @@ def _term_compatible(
contr_kwargs: PyTree[dict],
) -> bool:
error_msg = "term_structure"
if isinstance(y, KLState):
y = y.y
lockwo marked this conversation as resolved.
Show resolved Hide resolved

def _check(term_cls, term, term_contr_kwargs, yi):
if get_origin_no_specials(term_cls, error_msg) is MultiTerm:
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .kencarp3 import KenCarp3 as KenCarp3
from .kencarp4 import KenCarp4 as KenCarp4
from .kencarp5 import KenCarp5 as KenCarp5
from .kl import initialize_kl as initialize_kl, KLSolver as KLSolver, KLState as KLState
from .kvaerno3 import Kvaerno3 as Kvaerno3
from .kvaerno4 import Kvaerno4 as Kvaerno4
from .kvaerno5 import Kvaerno5 as Kvaerno5
Expand Down
291 changes: 291 additions & 0 deletions diffrax/_solver/kl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
import operator
from typing import Optional

import equinox as eqx
import jax.tree_util as jtu
import lineax as lx
from jax import numpy as jnp
from jaxtyping import Array, PyTree

from .._custom_types import (
Args,
BoolScalarLike,
Control,
DenseInfo,
RealScalarLike,
VF,
Y,
)
from .._heuristics import is_sde
from .._solution import RESULTS
from .._term import (
AbstractTerm,
ControlTerm,
MultiTerm,
ODETerm,
)
from .base import (
_SolverState,
AbstractSolver,
AbstractWrappedSolver,
)


class KLState(eqx.Module):
    """Joint state carried through a KL-augmented SDE solve.

    Bundles the evolving SDE state together with the KL divergence value so
    that both travel through the solver as a single PyTree.
    """

    # State of the underlying SDE.
    y: Y
    # KL divergence value associated with `y`.
    kl_metric: Array


def _compute_kl_integral(
    drift_term1: ODETerm,
    drift_term2: ODETerm,
    diffusion_term: ControlTerm,
    t0: RealScalarLike,
    y0: Y,
    args: Args,
    linear_solver: lx.AbstractLinearSolver,
) -> KLState:
    """Evaluate the first drift and the KL integrand at `(t0, y0)`.

    **Arguments:**

    - `drift_term1`: drift of the first SDE.
    - `drift_term2`: drift of the second SDE.
    - `diffusion_term`: diffusion term; it is evaluated once and assumed to be
        shared by both SDEs.
    - `t0`: time at which every vector field is evaluated.
    - `y0`: state at which every vector field is evaluated.
    - `args`: extra arguments forwarded to every vector field.
    - `linear_solver`: lineax solver used for the solve against the diffusion.

    **Returns:**

    A `KLState` whose `y` is the value of the first drift and whose
    `kl_metric` is half the summed squares of `u`, where `u` solves
    `diffusion @ u = drift1 - drift2`.
    """

    def _half_sum_of_squares(leaf):
        return 0.5 * jnp.sum(leaf**2)

    f_val = drift_term1.vf(t0, y0, args)
    h_val = drift_term2.vf(t0, y0, args)
    # Only one diffusion is evaluated: both SDEs are assumed to share it.
    g_val = diffusion_term.vf(t0, y0, args)

    drift_gap = jtu.tree_map(operator.sub, f_val, h_val)
    scaled_gap = lx.linear_solve(g_val, drift_gap, solver=linear_solver).value

    # Per-leaf contributions are summed into a single scalar integrand.
    per_leaf = jtu.tree_map(_half_sum_of_squares, scaled_gap)
    integrand = jtu.tree_reduce(operator.add, per_leaf)

    return KLState(f_val, jnp.squeeze(integrand))
lockwo marked this conversation as resolved.
Show resolved Hide resolved


class _KLDrift(AbstractTerm):
    """Drift term over a `KLState`.

    Its vector field pairs the first SDE's drift with the KL integrand, and
    its control is the elapsed time `t1 - t0`.
    """

    drift1: ODETerm
    drift2: ODETerm
    diffusion: ControlTerm
    linear_solver: lx.AbstractLinearSolver

    def vf(self, t: RealScalarLike, y: Y, args: Args) -> KLState:
        """Evaluate the drift and KL integrand at the SDE state held in `y.y`."""
        return _compute_kl_integral(
            self.drift1,
            self.drift2,
            self.diffusion,
            t,
            y.y,  # the wrapped terms only ever see the raw SDE state
            args,
            self.linear_solver,
        )

    def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> Control:
        """Return the length of the time interval `[t0, t1]`."""
        return t1 - t0

    def prod(self, vf: VF, control: RealScalarLike) -> Y:
        """Scale every leaf of `vf` by the scalar `control`."""

        def _scale(leaf):
            return control * leaf

        return jtu.tree_map(_scale, vf)


class _KLControlTerm(AbstractTerm):
    """Adapter that lets a `ControlTerm` act on a `KLState`.

    Every method unwraps the `.y` field of its `KLState` inputs, delegates to
    the wrapped control term, and re-wraps the result with a zero `kl_metric`.
    """

    control_term: ControlTerm

    def vf(self, t: RealScalarLike, y: Y, args: Args) -> KLState:
        """Evaluate the wrapped vector field at the SDE state `y.y`."""
        inner_vf = self.control_term.vf(t, y.y, args)
        return KLState(inner_vf, jnp.array(0.0))

    def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> KLState:
        """Evaluate the wrapped control over `[t0, t1]`.

        `**kwargs` are accepted but not forwarded to the wrapped term.
        """
        inner_control = self.control_term.contr(t0, t1)
        return KLState(inner_control, jnp.array(0.0))

    def vf_prod(
        self, t: RealScalarLike, y: Y, args: Args, control: Control
    ) -> KLState:
        """Evaluate the wrapped vector-field/control product in one call."""
        inner_prod = self.control_term.vf_prod(t, y.y, args, control.y)
        return KLState(inner_prod, jnp.array(0.0))

    def prod(self, vf: VF, control: Control) -> KLState:
        """Combine an already-evaluated vector field with a control."""
        inner_prod = self.control_term.prod(vf.y, control.y)
        return KLState(inner_prod, jnp.array(0.0))


class KLSolver(AbstractWrappedSolver[_SolverState]):
r"""Given an SDE of the form

$$
\mathrm{d}y(t) = f_\theta (t, y(t)) dt + g_\phi (t, y(t)) dW(t) \qquad \zeta_\theta (ts[0]) = y_0
$$

$$
\mathrm{d}z(t) = h_\psi (t, z(t)) dt + g_\phi (t, z(t)) dW(t) \qquad \nu_\psi (ts[0]) = z_0
$$

compute:

$$
\int_{ts[i-1]}^{ts[i]} \frac{1}{2} \left\| g_\phi (t, y(t))^{-1} (f_\theta (t, y(t)) - h_\psi (t, y(t))) \right\|_2^2 dt
$$

for every time interval. This is useful for KL based latent SDEs. The output
`solution.ys` will be a `KLState` whose `y` field holds the SDE state and
whose `kl_metric` field holds the KL divergence integral at that time. Unless
the noise is diagonal, this inverse can be extremely costly in higher dimensions.

The input must be a `MultiTerm` composed of the first SDE with drift `f`
and diffusion `g`, and the second either an SDE or just the drift term
(since the diffusion is assumed to be the same). For example, a type
of: `MultiTerm(MultiTerm(ODETerm, _DiffusionTerm), ODETerm)`.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per this comment:
#402 (review)
and also the updated term docs:
https://docs.kidger.site/diffrax/api/terms/
then this outer MultiTerm isn't really in-keeping. We're not adding all of these extra terms on to the same evolving state.

Bearing in mind that the rest of Diffrax has to see this as just another SDE solve.

I think this one might take a bit more iteration to get to something that's obeying the abstractions in the way they're designed, I'm afraid.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I get what you're saying, multiterm implies a single differential equation "unit". So the composed multi terms is bad form. However, I'm not sure I see the difficulty going forward, I can replace it with tuple (multiterm, multiterm) or even tuple (multiterm, ode term). Which seems to adhere to this principle of multiterm = sde unit, since we are integrating two simultaneous SDEs, while also falling in line with other solvers (such as implicit Euler as you remarked).

On the terms vs solver approach, I am open to both. I think in my many iterations/experimentations I found the solver approach more in line with my thinking about the nature of the problem, specifically the original idea of (terms, kl_term) I didn't see as appealing since the KL_term relies on information from the other term and I didn't see a clean way to do that. However, having terms with a term wrapper is very doable (but may not mesh with the repo as well).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe introducing a totally new term is ok (given the remarks in #453), in which case the approach of a KLTerm (rather than a solver), is doable. Given the restricted nature of terms so far, I originally thought that wasn't in line with the package


??? cite "References"

See section 5 of:

```bibtex
@inproceedings{li2020scalable,
title={Scalable gradients for stochastic differential equations},
author={Li, Xuechen and Wong, Ting-Kam Leonard and Chen, Ricky TQ and Duvenaud, David},
booktitle={International Conference on Artificial Intelligence and Statistics},
pages={3870--3882},
year={2020},
organization={PMLR}
}
```

Or section 4.3.2 of:

```bibtex
@article{kidger2022neural,
title={On neural differential equations},
author={Kidger, Patrick},
journal={arXiv preprint arXiv:2202.02435},
year={2022}
}
```
""" # noqa: E501

solver: AbstractSolver[_SolverState]
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None)

def order(self, terms: PyTree[AbstractTerm]) -> Optional[int]:
    """Return the order reported by the wrapped solver for `terms`."""
    wrapped = self.solver
    return wrapped.order(terms)

def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
    """Return the strong order reported by the wrapped solver for `terms`."""
    wrapped = self.solver
    return wrapped.strong_order(terms)

def error_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
    """Return the order used for error estimation.

    This is the strong order when `is_sde(terms)` holds, and the plain order
    otherwise.
    """
    if not is_sde(terms):
        return self.order(terms)
    return self.strong_order(terms)

@property
def term_structure(self):
    """Term structure advertised by the wrapped solver."""
    wrapped = self.solver
    return wrapped.term_structure

@property
def interpolation_cls(self):  # pyright: ignore
    """Interpolation class advertised by the wrapped solver."""
    wrapped = self.solver
    return wrapped.interpolation_cls

def init(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y0: KLState,
    args: Args,
) -> _SolverState:
    """Create the initial solver state by delegating to the wrapped solver.

    All arguments are forwarded unchanged.

    NOTE(review): `step` hands the wrapped solver a rebuilt
    `MultiTerm(_KLDrift, _KLControlTerm)`, whereas this method forwards the
    caller's `terms` as-is together with a `KLState` `y0`. Confirm that this is
    intended for wrapped solvers whose `init` evaluates the vector field.
    """
    wrapped = self.solver
    initial_state = wrapped.init(terms, t0, t1, y0, args)
    return initial_state

def step(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    t1: RealScalarLike,
    y0: Y,
    args: Args,
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
    """Advance the KL-augmented system by one step of the wrapped solver.

    **Arguments:**

    - `terms`: an object whose `.terms` unpacks into exactly two entries: the
        first SDE (holding one `ODETerm` and one `ControlTerm`) and the second
        SDE (holding one `ODETerm`). Only the first SDE's `ControlTerm` is
        used; the diffusion is assumed to be shared.
    - `t0`, `t1`: start and end of the step.
    - `y0`: state at `t0`; the KL terms read its `.y` attribute.
    - `args`: extra arguments forwarded to the vector fields.
    - `solver_state`: state of the wrapped solver.
    - `made_jump`: forwarded unchanged to the wrapped solver.

    **Returns:**

    Whatever the wrapped solver's `step` returns when run on
    `MultiTerm(_KLDrift(...), _KLControlTerm(...))`.

    **Raises:**

    An error via `eqx.error_if` if either SDE does not contain exactly one
    `ODETerm`, or the first SDE does not contain exactly one `ControlTerm`.
    """

    def _only_term(tree, term_cls, message):
        # Return the single `term_cls` instance inside `tree`, else error.
        def _is_match(node):
            return isinstance(node, term_cls)

        # Treating matches as leaves stops flattening from descending into
        # them; every other leaf is then filtered out.
        matches = [
            leaf
            for leaf in jtu.tree_leaves(tree, is_leaf=_is_match)
            if _is_match(leaf)
        ]
        matches = eqx.error_if(matches, len(matches) != 1, message)
        return matches[0]

    first_sde, second_sde = terms.terms

    drift_term1 = _only_term(
        first_sde, ODETerm, "First SDE doesn't have one ODETerm!"
    )
    drift_term2 = _only_term(
        second_sde, ODETerm, "Second SDE doesn't have one ODETerm!"
    )
    # The check also fires when no control term is present, so the message
    # says "one" rather than "multiple".
    diffusion_term = _only_term(
        first_sde, ControlTerm, "First SDE doesn't have one ControlTerm!"
    )

    kl_terms = MultiTerm(
        _KLDrift(drift_term1, drift_term2, diffusion_term, self.linear_solver),
        _KLControlTerm(diffusion_term),
    )
    return self.solver.step(kl_terms, t0, t1, y0, args, solver_state, made_jump)

def func(
    self,
    terms: PyTree[AbstractTerm],
    t0: RealScalarLike,
    y0: Y,
    args: Args,
) -> VF:
    """Evaluate the vector field by delegating to the wrapped solver.

    All arguments are forwarded unchanged.

    NOTE(review): unlike `step`, this does not rebuild the terms into
    `MultiTerm(_KLDrift, _KLControlTerm)` before delegating. Confirm that
    callers pass terms able to consume a `KLState` `y0`.
    """
    wrapped = self.solver
    return wrapped.func(terms, t0, y0, args)


# `KLSolver` declares its fields without writing an `__init__` of its own, so
# the constructor's documentation is attached to it here instead.
KLSolver.__init__.__doc__ = """**Arguments:**

- `solver`: The solver to wrap.
- `linear_solver`: The lineax solver to use when computing $g^{-1}f$.
"""


def initialize_kl(
    solver: AbstractSolver,
    y0: Y,
    linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None),
) -> tuple[KLSolver, KLState]:
    """
    Build the KL-wrapped solver and the matching augmented initial state.


    **Arguments**

    - `solver`: the method for solving the SDE.
    - `y0`: the initial state of the SDE.
    - `linear_solver`: the lineax solver handed to the `KLSolver`.

    **Returns**

    A tuple `(kl_solver, kl_state)`: a `KLSolver` wrapping `solver`, and a
    `KLState` holding `y0` with its `kl_metric` set to zero. Both of these can
    be directly fed into `diffeqsolve`.

    """
    wrapped_solver = KLSolver(solver, linear_solver)
    # The KL accumulator starts at zero.
    zero_kl = jnp.array(0.0)
    augmented_y0 = KLState(y=y0, kl_metric=zero_kl)
    return wrapped_solver, augmented_y0
7 changes: 7 additions & 0 deletions docs/api/solvers/sde_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,10 @@ These are reversible in the same way as when applied to ODEs. [See here.](./ode_
selection:
members:
- __init__

::: diffrax.initialize_kl

::: diffrax.KLSolver
selection:
members:
- __init__
Loading
Loading