-
-
Notifications
You must be signed in to change notification settings - Fork 134
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding Latent SDE #104
base: main
Are you sure you want to change the base?
Adding Latent SDE #104
Changes from 9 commits
0378b23
dbed128
71e3ccc
47c87ec
1513384
9217c06
a59a111
7ab9bb0
92a1842
da636e5
d8a68b2
dc5fde4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,10 @@ | ||
import operator | ||
|
||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
|
||
from ..brownian import AbstractBrownianPath | ||
from ..custom_types import PyTree | ||
from ..custom_types import PyTree, Scalar | ||
from ..term import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm | ||
|
||
|
||
def _kl(drift1, drift2, diffusion): | ||
|
@@ -14,61 +13,77 @@ def _kl(drift1, drift2, diffusion): | |
return 0.5 * jnp.sum(scale**2) | ||
|
||
|
||
class _AugDrift(eqx.Module): | ||
drift1: callable | ||
drift2: callable | ||
diffusion: callable | ||
context: callable | ||
def _kl_diagonal(drift1, drift2, diffusion): | ||
# stable division | ||
diffusion = jnp.where( | ||
jax.lax.stop_gradient(diffusion) > 1e-7, | ||
diffusion, | ||
jnp.full_like(diffusion, fill_value=1e-7) * jnp.sign(diffusion), | ||
) | ||
scale = (drift1 - drift2) / diffusion | ||
return 0.5 * jnp.sum(scale**2) | ||
|
||
def __call__(self, t, y, args): | ||
y, _ = y | ||
context = self.context(t) | ||
aug_y = jnp.concatenate([y, context], axis=-1) | ||
drift1 = self.drift1(t, aug_y, args) | ||
drift2 = self.drift2(t, y, args) | ||
diffusion = self.diffusion(t, y, args) | ||
kl_divergence = jax.tree_map(_kl, drift1, drift2, diffusion) | ||
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence) | ||
return drift1, kl_divergence | ||
|
||
class _AugControlTerm(ControlTerm): | ||
|
||
control_term: ControlTerm | ||
|
||
class _AugDiffusion(eqx.Module): | ||
diffusion: callable | ||
def __init__(self, control_term: ControlTerm) -> None: | ||
super().__init__(control_term.vector_field, control_term.control) | ||
self.control_term = control_term | ||
|
||
def __call__(self, t, y, args): | ||
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree: | ||
y, _ = y | ||
diffusion = self.diffusion(t, y, args) | ||
return diffusion, 0.0 | ||
vf = self.control_term.vf(t, y, args) | ||
return vf, 0.0 | ||
|
||
def contr(self, t0: Scalar, t1: Scalar) -> PyTree: | ||
return self.control_term.contr(t0, t1), 0.0 | ||
|
||
class _AugBrownianPath(eqx.Module): | ||
bm: AbstractBrownianPath | ||
def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree: | ||
y, _ = y | ||
return self.control_term.vf_prod(t, y, args, control), 0.0 | ||
|
||
@property | ||
def t0(self): | ||
return self.bm.t0 | ||
|
||
@property | ||
def t1(self): | ||
return self.bm.t1 | ||
class _AugVectorField(ODETerm): | ||
|
||
def evaluate(self, t0, t1): | ||
return self.bm.evaluate(t0, t1), 0.0 | ||
sde1: MultiTerm | ||
sde2: MultiTerm | ||
context: callable | ||
kl: callable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should require an explicit drift1, drift2, and diffusion here. (Rather than wrapping them in I'm also thinking we can probably just remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I just realize now that we can check the class of diffusion. I will go back to the API with drift1, drift2, and diffusion. |
||
|
||
def __init__(self, sde1, sde2, context) -> None: | ||
super().__init__(sde1.terms[0].vector_field) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here also -- pick only one of composition or inheritance. |
||
if sde1.terms[1] is not sde2.terms[1]: | ||
raise ValueError("Two SDEs should share the same control terms") | ||
self.sde1 = sde1 | ||
self.sde2 = sde2 | ||
if isinstance(self.sde1.terms[1], WeaklyDiagonalControlTerm): | ||
self.kl = _kl_diagonal | ||
else: | ||
self.kl = _kl | ||
self.context = context | ||
|
||
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree: | ||
y, _ = y | ||
context = self.context(t) | ||
aug_y = y if context is None else jnp.concatenate([y, context], axis=-1) | ||
drift1 = self.sde1.terms[0].vf(t, aug_y, args) | ||
drift2 = self.sde2.terms[0].vf(t, y, args) | ||
diffusion = self.sde1.terms[1].vf(t, y, args) | ||
kl_divergence = jax.tree_map(self.kl, drift1, drift2, diffusion) | ||
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence) | ||
return drift1, kl_divergence | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So unfortunately this approach is going to break whenever a custom diffusion term is used. At the moment, the current implementation assumes that the diffusion is either a But the API around terms is specifically chosen so that the output of I'd need to think a lot harder about what the general case really is here. I'd welcome any thoughts on how that might be done, but if seems more complicated than you really want to get in to right now, a simple-but-inefficient approach is to forcibly materialise the diffusion as a matrix, ignoring any custom (diagonal/sparse/whatever) structure. This won't be efficient with user-specified control terms, but at the very least won't break. Untested, but I think something like the following would work as an implementation. def materialise_vf(t, y, args, contr, vf_prod):
# Only used for its shape/dtype/structure; value is irrelevant
control = contr(t, t)
y_size = sum(np.size(yi) for yi in jax.tree_leaves(y))
control_size = sum(np.size(ci) for ci in jax.tree_leaves(control))
if y_size > control_size:
make_jac = jax.jacfwd
else:
make_jac = jax.jacrev
# Find the tree structure of vf_prod by smuggling it out as an additional
# result from the Jacobian calculation.
sentinel = vf_prod_tree = object()
control_tree = jax.tree_structure(control)
def _fn(_control):
_out = vf_prod(t, y, args, _control)
nonlocal vf_prod_tree
structure = jax.tree_structure(_out)
if vf_prod_tree is sentinel:
vf_prod_tree = structure
else:
assert vf_prod_tree == structure
return _out
jac = make_jac(_fn)(control)
assert vf_prod_tree is not sentinel
if jax.tree_structure(None) in (vf_prod_tree, control_tree):
# An unusual/not-useful edge case to handle.
raise NotImplementedError(
"`materialise_vf` not implemented for `None` controls or states."
)
return jax.tree_transpose(vf_prod_tree, control_tree, jac)
def _assert_array(x):
if not isinstance(x, jnp.ndarray):
raise NotImplementedError("`sde_kl_divergence` can only handle array-valued drifts and diffusions")
class _AugDrift(AbstractTerm):
drift1: ODETerm
drift2: ODETerm
diffusion: AbstractTerm
def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
y, _ = y
drift1 = self.drift1.vf(t, y, args)
drift2 = self.drift2.vf(t, y, args)
_assert_array(drift1)
_assert_array(drift2)
# Ugly hack special-casing built-in control terms.
if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
diffusion = self.diffusion.vf(t, y, args)
_assert_array(diffusion)
kl_divergence = _kl_diagonal(drift1, drift2, diffusion)
elif isinstance(self.diffusion, ControlTerm):
diffusion = self.diffusion.vf(t, y, args)
_assert_array(diffusion)
kl_divergence = _kl(drift1, drift2, diffusion)
else:
# TODO: think about how to handle arbitrary control terms here, without forcibly
# materialising the whole diffusion matrix. It'll require analysing `self.diffusion.prod` or
# `self.diffusion.vf_prod` and looking at its structure, I think? Or possibly extending the
# `AbstractTerm` api to require specifying how to invert things?
warnings.warn("`sde_kl_divergence` may be slow when used with custom diffusion terms")
diffusion = materialise_vf(t, y, args, self.diffusion.contr, self.diffusion.vf_prod)
_assert_array(diffusion)
kl_divergence = _kl_general(drift1, drift2, diffusion)
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
return drift1, kl_divergence
@staticmethod
def contr(t0: Scalar, t1: Scalar) -> Scalar:
return t1 - t0
@staticmethod
def prod(vf: PyTree, control: Scalar) -> PyTree:
return jax.tree_map(lambda v: control * v, vf) This approach works, but falls short in the general case in two main respects. As already discussed the first is handling general The second is more subtle, and is the reason for the Once again this is something I'd need to think hard about how to handle efficiently in the general case. (As an inefficient general-case implementation you could use Phew! As you can tell, all of this gets nontrivial fast. By the way: this approach of materialising the diffusion matrix is something that's come up before in other contexts. I copied the code for doing that from AdjointTerm. If you decide to include There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for such an insightful comment! I understand better how diffrax works now. I like the idea of extending |
||
|
||
|
||
def sde_kl_divergence( | ||
*, | ||
drift1: callable, | ||
drift2: callable, | ||
diffusion: callable, | ||
context: callable, | ||
y0: PyTree, | ||
bm: AbstractBrownianPath, | ||
*, sde1: MultiTerm, sde2: MultiTerm, context: callable, y0: PyTree | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've mentioned it earlier, but just to reiterate since this is the public API: I'd make this API accept (e.g. a |
||
): | ||
if context is None: | ||
context = lambda t: None | ||
aug_y0 = (y0, 0.0) | ||
return ( | ||
_AugDrift(drift1, drift2, diffusion, context), | ||
_AugDiffusion(diffusion), | ||
aug_y0, | ||
_AugBrownianPath(bm), | ||
) | ||
aug_drift = _AugVectorField(sde1, sde2, context=context) | ||
aug_control = _AugControlTerm(sde1.terms[1]) | ||
aug_sde = MultiTerm(aug_drift, aug_control) | ||
|
||
return aug_sde, aug_y0 |
Large diffs are not rendered by default.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should inherit from
AbstractTerm
rather thanControlTerm
. At the moment you're using both inheritance (fromControlTerm
) and composition (passing in aControlTerm
instance as an argument); almost always you only ever need one of these approaches.In this case I think composition is most natural, since the "base"
ControlTerm
already exists.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah you're right. I'll update it.