Adding Latent SDE #104

Open
wants to merge 12 commits into
base: main
107 changes: 61 additions & 46 deletions diffrax/misc/sde_kl_divergence.py
@@ -1,11 +1,10 @@
import operator

import equinox as eqx
import jax
import jax.numpy as jnp

from ..brownian import AbstractBrownianPath
from ..custom_types import PyTree
from ..custom_types import PyTree, Scalar
from ..term import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm


def _kl(drift1, drift2, diffusion):
@@ -14,61 +13,77 @@ def _kl(drift1, drift2, diffusion):
return 0.5 * jnp.sum(scale**2)


class _AugDrift(eqx.Module):
drift1: callable
drift2: callable
diffusion: callable
context: callable
def _kl_diagonal(drift1, drift2, diffusion):
# stable division
diffusion = jnp.where(
jax.lax.stop_gradient(diffusion) > 1e-7,
diffusion,
jnp.full_like(diffusion, fill_value=1e-7) * jnp.sign(diffusion),
)
scale = (drift1 - drift2) / diffusion
return 0.5 * jnp.sum(scale**2)

def __call__(self, t, y, args):
y, _ = y
context = self.context(t)
aug_y = jnp.concatenate([y, context], axis=-1)
drift1 = self.drift1(t, aug_y, args)
drift2 = self.drift2(t, y, args)
diffusion = self.diffusion(t, y, args)
kl_divergence = jax.tree_map(_kl, drift1, drift2, diffusion)
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
return drift1, kl_divergence

class _AugControlTerm(ControlTerm):
Owner

I think this should inherit from AbstractTerm rather than ControlTerm. At the moment you're using both inheritance (from ControlTerm) and composition (passing in a ControlTerm instance as an argument); almost always you only ever need one of these approaches.

In this case I think composition is most natural, since the "base" ControlTerm already exists.
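A minimal sketch of the composition-only approach, assuming a toy scalar setting: the wrapper stores the existing term and forwards to it, without also subclassing it. `BaseControlTerm` and `AugControlWrapper` are hypothetical plain-Python stand-ins, not diffrax classes; only the method names `vf`, `contr`, and `prod` follow the code in this PR.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class BaseControlTerm:
    """Hypothetical stand-in for an existing control term (not a diffrax class)."""

    vector_field: Callable[[float, float, Any], float]
    control: Callable[[float, float], float]

    def vf(self, t, y, args):
        return self.vector_field(t, y, args)

    def contr(self, t0, t1):
        return self.control(t0, t1)

    def prod(self, vf, control):
        return vf * control


@dataclass(frozen=True)
class AugControlWrapper:
    """Composition only: holds a term and forwards to it; no parent-class init."""

    control_term: BaseControlTerm

    def vf(self, t, y, args):
        y, _ = y  # drop the running-KL component of the augmented state
        return self.control_term.vf(t, y, args), 0.0

    def contr(self, t0, t1):
        return self.control_term.contr(t0, t1), 0.0

    def prod(self, vf, control):
        vf, _ = vf
        control, _ = control
        # The noise never drives the KL component, hence the trailing 0.0.
        return self.control_term.prod(vf, control), 0.0


base = BaseControlTerm(
    vector_field=lambda t, y, args: 2.0 * y,
    control=lambda t0, t1: t1 - t0,  # deterministic placeholder for a Brownian increment
)
aug = AugControlWrapper(base)
state = (3.0, 0.0)  # (y, accumulated KL)
increment = aug.prod(aug.vf(0.0, state, None), aug.contr(0.0, 0.5))
print(increment)
```

Because the wrapper only holds a reference to the base term, there is a single source of truth for the vector field and control, and nothing has to be passed to a parent constructor.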

Author

Ah you're right. I'll update it.


control_term: ControlTerm

class _AugDiffusion(eqx.Module):
diffusion: callable
def __init__(self, control_term: ControlTerm) -> None:
super().__init__(control_term.vector_field, control_term.control)
self.control_term = control_term

def __call__(self, t, y, args):
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
y, _ = y
diffusion = self.diffusion(t, y, args)
return diffusion, 0.0
vf = self.control_term.vf(t, y, args)
return vf, 0.0

def contr(self, t0: Scalar, t1: Scalar) -> PyTree:
return self.control_term.contr(t0, t1), 0.0

class _AugBrownianPath(eqx.Module):
bm: AbstractBrownianPath
def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree:
y, _ = y
return self.control_term.vf_prod(t, y, args, control), 0.0

@property
def t0(self):
return self.bm.t0

@property
def t1(self):
return self.bm.t1
class _AugVectorField(ODETerm):

def evaluate(self, t0, t1):
return self.bm.evaluate(t0, t1), 0.0
sde1: MultiTerm
sde2: MultiTerm
context: callable
kl: callable
Owner

I think we should require an explicit drift1, drift2, and diffusion here. (Rather than wrapping them in MultiTerms.)

I'm also thinking we can probably just remove context altogether? It isn't always needed (e.g. if you just want to compute the KL divergence between fixed SDEs), and in the latent SDE case the context can be handled via the args that are passed through. So better to have a simpler API I think.
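A rough sketch of what handling the context through `args` could look like, in plain Python with scalar states. The dict layout `args["context"]` and the function names are assumptions made for illustration; the integrand mirrors the `0.5 * sum(scale**2)` expression in `_kl_diagonal`.

```python
def prior_drift(t, y, args):
    # The prior ignores the context entirely.
    return -y


def posterior_drift(t, y, args):
    # Hypothetical convention: `args` is a dict carrying a context function.
    context = args["context"](t)
    return -y + context


def kl_integrand(t, y, args, diffusion):
    # Scalar case of the integrand in the diff: 0.5 * ((f1 - f2) / sigma) ** 2
    scale = (posterior_drift(t, y, args) - prior_drift(t, y, args)) / diffusion
    return 0.5 * scale**2


args = {"context": lambda t: 2.0 * t}
value = kl_integrand(0.5, 1.0, args, diffusion=0.5)
print(value)
```

With this layout both drifts share the `(t, y, args)` signature, so the KL helper does not need a separate `context` parameter.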

Author

Yes. I just realized that we can check the class of diffusion. I will go back to the API with drift1, drift2, and diffusion.


def __init__(self, sde1, sde2, context) -> None:
super().__init__(sde1.terms[0].vector_field)
Owner

Here also -- pick only one of composition or inheritance.

if sde1.terms[1] is not sde2.terms[1]:
raise ValueError("Two SDEs should share the same control terms")
self.sde1 = sde1
self.sde2 = sde2
if isinstance(self.sde1.terms[1], WeaklyDiagonalControlTerm):
self.kl = _kl_diagonal
else:
self.kl = _kl
self.context = context

def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
y, _ = y
context = self.context(t)
aug_y = y if context is None else jnp.concatenate([y, context], axis=-1)
drift1 = self.sde1.terms[0].vf(t, aug_y, args)
drift2 = self.sde2.terms[0].vf(t, y, args)
diffusion = self.sde1.terms[1].vf(t, y, args)
kl_divergence = jax.tree_map(self.kl, drift1, drift2, diffusion)
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
return drift1, kl_divergence
Owner

So unfortunately this approach is going to break whenever a custom diffusion term is used.

At the moment, the current implementation assumes that the diffusion is either a ControlTerm -- which produces a diffusion matrix, and does a diffusion-control product as a matrix-vector multiply -- or a WeaklyDiagonalControlTerm -- which produces a diagonal diffusion matrix, and does a diffusion-control product as a (diagonal-matrix)-vector multiply.

But the API around terms is specifically chosen so that the output of .vf(...) can really be anything (dense matrix; diagonal matrix; ... someone may also wish to write something special for tridiagonal matrices, sparse matrices etc.) so that diffusion could really be an arbitrarily-structured PyTree, for which .prod(...) is the only thing that knows how to consume it.

I'd need to think a lot harder about what the general case really is here. I'd welcome any thoughts on how that might be done, but if it seems more complicated than you really want to get into right now, a simple-but-inefficient approach is to forcibly materialise the diffusion as a matrix, ignoring any custom (diagonal/sparse/whatever) structure. This won't be efficient with user-specified control terms, but at the very least won't break.

Untested, but I think something like the following would work as an implementation.

# Extra imports needed on top of those already in sde_kl_divergence.py
import warnings

import numpy as np


def materialise_vf(t, y, args, contr, vf_prod):
    # Only used for its shape/dtype/structure; value is irrelevant
    control = contr(t, t)

    # Forward mode is cheaper when the control (input) is smaller than the
    # state (output); reverse mode otherwise.
    y_size = sum(np.size(yi) for yi in jax.tree_util.tree_leaves(y))
    control_size = sum(np.size(ci) for ci in jax.tree_util.tree_leaves(control))
    if y_size > control_size:
        make_jac = jax.jacfwd
    else:
        make_jac = jax.jacrev

    # Find the tree structure of vf_prod by smuggling it out as an additional
    # result from the Jacobian calculation.
    sentinel = vf_prod_tree = object()
    control_tree = jax.tree_util.tree_structure(control)

    def _fn(_control):
        _out = vf_prod(t, y, args, _control)
        nonlocal vf_prod_tree
        structure = jax.tree_util.tree_structure(_out)
        if vf_prod_tree is sentinel:
            vf_prod_tree = structure
        else:
            assert vf_prod_tree == structure
        return _out

    jac = make_jac(_fn)(control)
    assert vf_prod_tree is not sentinel
    if jax.tree_util.tree_structure(None) in (vf_prod_tree, control_tree):
        # An unusual/not-useful edge case to handle.
        raise NotImplementedError(
            "`materialise_vf` not implemented for `None` controls or states."
        )
    return jax.tree_util.tree_transpose(vf_prod_tree, control_tree, jac)

def _assert_array(x):
    if not isinstance(x, jnp.ndarray):
        raise NotImplementedError("`sde_kl_divergence` can only handle array-valued drifts and diffusions")

class _AugDrift(AbstractTerm):
    drift1: ODETerm
    drift2: ODETerm
    diffusion: AbstractTerm

    def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        y, _ = y
        drift1 = self.drift1.vf(t, y, args)
        drift2 = self.drift2.vf(t, y, args)
        _assert_array(drift1)
        _assert_array(drift2)
        # Ugly hack special-casing built-in control terms.
        if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl_diagonal(drift1, drift2, diffusion)
        elif isinstance(self.diffusion, ControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl(drift1, drift2, diffusion)
        else:
            # TODO: think about how to handle arbitrary control terms here, without forcibly
            # materialising the whole diffusion matrix. It'll require analysing `self.diffusion.prod` or 
            # `self.diffusion.vf_prod` and looking at its structure, I think? Or possibly extending the
            # `AbstractTerm` api to require specifying how to invert things?
            warnings.warn("`sde_kl_divergence` may be slow when used with custom diffusion terms")
            diffusion = materialise_vf(t, y, args, self.diffusion.contr, self.diffusion.vf_prod)
            _assert_array(diffusion)
            kl_divergence = _kl_general(drift1, drift2, diffusion)
        kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
        return drift1, kl_divergence

    @staticmethod
    def contr(t0: Scalar, t1: Scalar) -> Scalar:
        return t1 - t0

    @staticmethod
    def prod(vf: PyTree, control: Scalar) -> PyTree:
        return jax.tree_map(lambda v: control * v, vf)

This approach works, but falls short in the general case in two main respects. As already discussed, the first is handling general AbstractTerms for the diffusion.

The second is more subtle, and is the reason for the _assert_array statements. The current approach of tree-map'ing isn't actually mathematically correct. By way of example, suppose we chose to represent the state/drift as a list-of-scalars (rather than a one-dimensional array), and the diffusion as a list-of-lists-of-scalars instead of as a matrix. Then the tree-map'ing would unpack the first list in the diffusion term, but leave the second list in place. In _kl we'd then try to compute jnp.linalg.pinv(...list of scalars...). Obviously that isn't programmatically defined, but more importantly: regardless of how we adjust our implementation of _kl we could never compute the thing desired, as we need the whole diffusion to do the inversion, and right now we only have a single column.
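To make the mismatch concrete, here is a small dependency-free sketch (plain Python rather than JAX, 2x2 only). It compares the whole-matrix integrand `0.5 * |u|^2` with `sigma @ u = delta` against one plausible slice-by-slice adjustment, where each slice of the diffusion is paired with a single drift entry and solved in the minimum-norm sense, as a pseudo-inverse would. The slicing here is by rows, and the function names and numbers are illustrative assumptions.

```python
def kl_full_2x2(delta, sigma):
    """0.5 * |u|^2 where sigma @ u = delta, solved by Cramer's rule (2x2 only)."""
    (a, b), (c, d) = sigma
    det = a * d - b * c
    u0 = (delta[0] * d - b * delta[1]) / det
    u1 = (a * delta[1] - c * delta[0]) / det
    return 0.5 * (u0**2 + u1**2)


def kl_per_row(delta, sigma):
    """What a leaf-by-leaf map sees: one drift entry paired with one row.

    Each row alone gives an underdetermined equation row . u = delta_i; we
    take its minimum-norm solution, which is what a pseudo-inverse returns.
    """
    total = 0.0
    for delta_i, row in zip(delta, sigma):
        norm_sq = sum(r * r for r in row)
        u = [r * delta_i / norm_sq for r in row]
        total += 0.5 * sum(ui * ui for ui in u)
    return total


delta = [1.0, 1.0]  # drift1 - drift2, as a list of scalars
sigma = [[1.0, 1.0], [0.0, 1.0]]  # diffusion, as a list of lists of scalars

full = kl_full_2x2(delta, sigma)
per_row = kl_per_row(delta, sigma)
print(full, per_row)
```

The two values differ for this non-diagonal diffusion, since solving for `u` requires all slices at once; no per-slice formula has access to that information.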

Once again this is something I'd need to think hard about how to handle efficiently in the general case. (As an inefficient general-case implementation you could use jax.flatten_util.ravel_pytree, though -- I'd be happy to have that in there with a warnings.warn if that branch is taken, just like the above case.)

Phew! As you can tell, all of this gets nontrivial fast.

By the way: this approach of materialising the diffusion matrix is something that's come up before in other contexts. I copied the code for doing that from AdjointTerm. If you decide to include materialise_vf in your implementation then do factor it out and use it in both places.
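For intuition about what materialising means, here is a simplified sketch that uses basis-vector probing instead of `jax.jacfwd`/`jax.jacrev`. It relies only on the diffusion-control product being linear in the control. The names are hypothetical, and `vf_prod` is reduced to a function of the control alone (with `t`, `y`, `args` assumed fixed), unlike the `materialise_vf` signature above.

```python
def materialise_by_probing(vf_prod, control_size):
    """Build the dense matrix of the linear map control -> vf_prod(control).

    Column j is the image of the j-th unit control vector.
    """
    columns = []
    for j in range(control_size):
        unit = [1.0 if k == j else 0.0 for k in range(control_size)]
        columns.append(vf_prod(unit))
    # Transpose the list of columns into a list of rows.
    return [list(row) for row in zip(*columns)]


def diagonal_vf_prod(control):
    # Hypothetical "custom" term: stores only the diagonal [2, 3], never a matrix.
    diag = [2.0, 3.0]
    return [d * c for d, c in zip(diag, control)]


matrix = materialise_by_probing(diagonal_vf_prod, control_size=2)
print(matrix)
```

The result is a dense matrix even though the term only ever stored a diagonal, which is exactly the efficiency cost mentioned above: the structure is discarded so that a generic inverse can be applied.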

Author

Thanks for such an insightful comment! I understand better how diffrax works now.

I like the idea of extending AbstractTerm to invert matrices (vectors). I will try to go in this direction with your suggested code here.



def sde_kl_divergence(
*,
drift1: callable,
drift2: callable,
diffusion: callable,
context: callable,
y0: PyTree,
bm: AbstractBrownianPath,
*, sde1: MultiTerm, sde2: MultiTerm, context: callable, y0: PyTree
Owner

I've mentioned it earlier, but just to reiterate since this is the public API: I'd make this API accept drift1: ODETerm, drift2: ODETerm, diffusion: AbstractTerm, since that's what we actually need.

(e.g. a MultiTerm could include an arbitrary number of terms)

):
if context is None:
context = lambda t: None
aug_y0 = (y0, 0.0)
return (
_AugDrift(drift1, drift2, diffusion, context),
_AugDiffusion(diffusion),
aug_y0,
_AugBrownianPath(bm),
)
aug_drift = _AugVectorField(sde1, sde2, context=context)
aug_control = _AugControlTerm(sde1.terms[1])
aug_sde = MultiTerm(aug_drift, aug_control)

return aug_sde, aug_y0
6 changes: 3 additions & 3 deletions examples/neural_sde.ipynb → examples/neural_sde_gan.ipynb