Adding Latent SDE #104
Conversation
+ add Latent SDE (notebook, mkdocs)
+ change `neural_sde.ipynb` to `neural_sde_gan.ipynb`
+ fix doc links according to the change
Do you want to bump the version number? If we update the docs we should do a new release so people can use the updated `sde_kl_divergence` functionality.
diffrax/misc/sde_kl_divergence.py
Outdated
```python
inv_diffusion = jnp.linalg.pinv(diffusion)
scale = inv_diffusion @ (drift1 - drift2)
if diffusion.ndim == 1:
    scale = (drift1 - drift2) / diffusion
```
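(For reference, assuming both SDEs share the same diffusion -- which is the setting here -- the quantity being built up is the standard Girsanov expression for the KL divergence between the two path measures:

$$\mathrm{KL}(\mu_1 \,\|\, \mu_2) = \mathbb{E}\left[\int_0^T \tfrac{1}{2}\,\bigl\|\,\sigma(t, y_t)^{-1}\bigl(f_1(t, y_t) - f_2(t, y_t)\bigr)\bigr\|^2\,\mathrm{d}t\right],$$

so `scale` above is $\sigma^{-1}(f_1 - f_2)$, and half its squared norm is accumulated as an extra state alongside the solve.)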
So my original code here in `sde_kl_divergence` was pretty hacky and not library-ready, and I think it'll still need some more work to get ready.

In particular I think it would make most sense to operate at the level of terms. This would allow for abstracting over the kind of diffusion used -- e.g. `ControlTerm` versus `WeaklyDiagonalControlTerm` etc. -- rather than the current vector-field-based approach.
I think we may need to bump the version number. Here, I have changed the `sde_kl_divergence` API from taking drift functions and a diffusion function to taking two `MultiTerm`s. Although there is some duplication in the control terms (the two SDEs share the same one), this seems more natural, as we are comparing two SDEs.
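For concreteness, a rough sketch of what constructing the two `MultiTerm`s looks like -- the vector fields here are placeholders, not the notebook's actual models:

```python
import jax.numpy as jnp
import jax.random as jr
from diffrax import ControlTerm, MultiTerm, ODETerm, VirtualBrownianTree

# Placeholder vector fields, just to illustrate the shapes involved.
posterior_drift_vf = lambda t, y, args: -y
prior_drift_vf = lambda t, y, args: -0.5 * y
diffusion_vf = lambda t, y, args: 0.1 * jnp.eye(3)

bm = VirtualBrownianTree(t0=0.0, t1=1.0, tol=1e-3, shape=(3,), key=jr.PRNGKey(0))
shared_diffusion = ControlTerm(diffusion_vf, bm)  # the duplicated/shared control term
sde1 = MultiTerm(ODETerm(posterior_drift_vf), shared_diffusion)  # posterior SDE
sde2 = MultiTerm(ODETerm(prior_drift_vf), shared_diffusion)      # prior SDE
```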
diffrax/misc/sde_kl_divergence.py
Outdated
```diff
@@ -23,7 +26,7 @@ class _AugDrift(eqx.Module):
     def __call__(self, t, y, args):
         y, _ = y
         context = self.context(t)
-        aug_y = jnp.concatenate([y, context], axis=-1)
+        aug_y = jnp.concatenate([y, context], axis=-1) if context is not None else y
```
Nit: flipping the `if` and `else` branches allows for switching `if context is not None` down to just `if context is None`.
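i.e. something like:

```python
aug_y = y if context is None else jnp.concatenate([y, context], axis=-1)
```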
examples/neural_sde_gan.ipynb
Outdated
```diff
@@ -725,7 +728,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.7"
+   "version": "3.10.4"
```
Can the spurious changes to this file be removed? You don't need to actually re-run a notebook when you just make changes to the documentation, as it's just a big JSON file you can edit.
examples/neural_sde_vae.ipynb
Outdated
"source": [ | ||
"# Neural SDE (VAE)\n", | ||
"\n", | ||
"This implementation is based on the Pytorch version of Latent SDE from [`torchsde`](https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py) library. \n", |
Nit: put this at the bottom of this introduction; the first line is the most important and this isn't the most important piece of information.
Honestly, the theory of latent SDEs is pretty nontrivial, and the remainder of this section is pretty impenetrable, so I'd start off with a sentence here that just says something very simple, to the effect of "this is a VAE".
If you want to give folks a readable reference for this topic then I'd recommend also adding a link to the appropriate section of On Neural Differential Equations. (I'm biased I suppose, but I definitely didn't find the original paper that clear on this front.)
examples/neural_sde_vae.ipynb
Outdated
"from diffrax import (MultiTerm, ODETerm, ControlTerm,\n", | ||
" diffeqsolve, Euler,\n", | ||
" SaveAt, VirtualBrownianTree)\n", | ||
"from diffrax.misc import sde_kl_divergence\n", |
So anything not imported as `diffrax.*` is considered private API. If we're going to expose this publicly then `sde_kl_divergence` should be offered as `diffrax.sde_kl_divergence`.
examples/neural_sde_vae.ipynb
Outdated
"from diffrax.misc import sde_kl_divergence\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import seaborn as sns" |
To minimise dependencies, can we do without `seaborn`?
examples/neural_sde_vae.ipynb
Outdated
" maxval=1.6,\n", | ||
" shape=(16,),\n", | ||
" key=key))\n", | ||
" ys = jnp.sin(ts * 2 * 3.14)[:, None] * 0.8\n", |
`jnp.pi` instead of `3.14`.
Looks like the formatting is failing. Have a look at …

I've not really gone through most of the example yet; I'll leave a proper review of that once everything so far has been organised. I will say that I don't think I really believe what's happening here, mathematically. In the infinite-training limit you're just matching the SDE against a single trajectory, so it collapses to an ODE (zero noise). Have a look at the Lorenz example in torchsde for a more convincing (to me) example of training a latent SDE as a generative model, rather than this case, which I think is pretty much just supervised learning. (The real giveaway here is that you're using …)
Thanks for the detailed review. I will get back to this after a few days :)
Sorry for taking so long. In the recent commits I have changed … I've implemented the notebook of Latent SDE for Lorenz data as you suggested, to make it more like a VAE than just supervised learning. This took some time for me to get running. (It seems KL annealing is the trick to train this model.) Your other comments are included in the recent changes as well.
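(KL annealing here just means scaling the KL term in the ELBO by a weight that ramps from 0 up to 1 over training. Schematically -- illustrative numbers only, the notebook's exact schedule may differ:)

```python
import jax.numpy as jnp

def kl_weight(step, warmup_steps=1000):
    # Linearly ramp the KL coefficient from 0 to 1 over `warmup_steps` steps.
    return jnp.minimum(step / warmup_steps, 1.0)

# loss = -log_likelihood + kl_weight(step) * kl_divergence
```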
Okay, so -- sorry for taking so long to do a review! Thanks again for this implementation, which I think is now nearly there.
As you can see I've got one big comment against the KL divergence implementation, but I've also provided a possible solution. So I think tweaking that file should be straightforward.
Apparently the VAE code is too large for me to leave a diff against it line-by-line, so some comments here instead:
In `generate_lorenz`:

- Use `jnp.stack([foo, bar])` rather than `jnp.concatenate([foo[None], bar[None]])` (see the small example below).
- typo: normialize -> normalize
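A tiny example of the equivalence (hypothetical arrays):

```python
import jax.numpy as jnp

foo = jnp.zeros((3,))
bar = jnp.ones((3,))
a = jnp.concatenate([foo[None], bar[None]])  # shape (2, 3)
b = jnp.stack([foo, bar])                    # same result, clearer intent
assert (a == b).all()
```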
In the modules:

- The `super().__init__()` is essentially unnecessary (`eqx.Module.__init__` does nothing). Good practice in Python is either (a) not to include `super().__init__`, and treat the class as final (meaning "not subclassable"), or (b) to include `super().__init__` but also accept `**kwargs` in the `__init__` and then forward them on as `super().__init__(**kwargs)`; this is known as co-operative multiple inheritance. (See the sketch below.)
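A minimal sketch of option (b), with a hypothetical module:

```python
import equinox as eqx

class Encoder(eqx.Module):
    hidden_size: int

    def __init__(self, hidden_size, **kwargs):
        # Forward any extra keyword arguments up the MRO, so this class plays
        # nicely with co-operative multiple inheritance.
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
```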
In the training:
- I think we could probably train for less time; the samples get good enough about halfway through.
I really like the visualisation throughout training; that looks really cool.
In passing, it's interesting to note just how small the diffusion is in the learnt model; much smaller than the dataset. This has always been a big weakness of latent SDEs. I feel like there's probably a way to tweak the loss function to try and fix that somehow. (I'm just musing about an open research question here though.)
Overall I like both the `sde_kl_divergence` implementation in terms of terms, and the new example showing it off.
diffrax/misc/sde_kl_divergence.py
Outdated
```python
class _AugControlTerm(ControlTerm):
```
I think this should inherit from `AbstractTerm` rather than `ControlTerm`. At the moment you're using both inheritance (from `ControlTerm`) and composition (passing in a `ControlTerm` instance as an argument); almost always you only ever need one of these approaches.

In this case I think composition is most natural, since the "base" `ControlTerm` already exists.
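Roughly, the composition shape would be something like this (untested sketch; the exact method bodies depend on how the augmented `(y, kl)` state is represented):

```python
class _AugControlTerm(AbstractTerm):
    term: ControlTerm  # composition: wrap the existing control term

    def vf(self, t, y, args):
        y, _ = y
        return self.term.vf(t, y, args)

    def contr(self, t0, t1):
        return self.term.contr(t0, t1)

    def prod(self, vf, control):
        # The KL component of the augmented state receives no diffusion.
        return self.term.prod(vf, control), 0.0
```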
Ah you're right. I'll update it.
diffrax/misc/sde_kl_divergence.py
Outdated
```python
sde1: MultiTerm
sde2: MultiTerm
context: callable
kl: callable
```
I think we should require an explicit drift1, drift2, and diffusion here. (Rather than wrapping them in `MultiTerm`s.)

I'm also thinking we can probably just remove `context` altogether? This isn't always used -- if you just want to compute the KL divergence between fixed SDEs -- and in the latent SDE case then the context can be handled via the `args` that are passed through. So better to have a simpler API I think.
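For example (hypothetical names and `args` layout; just to illustrate passing the context through `args` rather than a dedicated argument):

```python
import jax.numpy as jnp
from diffrax import ODETerm

def posterior_drift_vf(t, y, args):
    # Hypothetical layout: `args` carries the encoded context as an
    # interpolation that can be evaluated at time t.
    ctx = args["context_path"].evaluate(t)
    # Stand-in for the real model, which would feed [y, ctx] through an MLP.
    return -y + 0.1 * jnp.sum(ctx) * jnp.ones_like(y)

drift1 = ODETerm(posterior_drift_vf)  # no separate `context` argument needed
```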
Yes. I just realized that we can check the class of the diffusion. I will go back to the API with drift1, drift2, and diffusion.
diffrax/misc/sde_kl_divergence.py
Outdated
```python
kl: callable

def __init__(self, sde1, sde2, context) -> None:
    super().__init__(sde1.terms[0].vector_field)
```
Here also -- pick only one of composition or inheritance.
diffrax/misc/sde_kl_divergence.py
Outdated
```python
diffusion = self.sde1.terms[1].vf(t, y, args)
kl_divergence = jax.tree_map(self.kl, drift1, drift2, diffusion)
kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
return drift1, kl_divergence
```
So unfortunately this approach is going to break whenever a custom diffusion term is used.

At the moment, the current implementation assumes that the diffusion is either a `ControlTerm` -- which produces a diffusion matrix, and does a diffusion-control product as a matrix-vector multiply -- or a `WeaklyDiagonalControlTerm` -- which produces a diagonal diffusion matrix, and does a diffusion-control product as a (diagonal-matrix)-vector multiply.

But the API around terms is specifically chosen so that the output of `.vf(...)` can really be anything (dense matrix; diagonal matrix; ... someone may also wish to write something special for tridiagonal matrices, sparse matrices, etc.), so that `diffusion` could really be an arbitrarily-structured PyTree, for which `.prod(...)` is the only thing that knows how to consume it.

I'd need to think a lot harder about what the general case really is here. I'd welcome any thoughts on how that might be done, but if it seems more complicated than you really want to get into right now, a simple-but-inefficient approach is to forcibly materialise the diffusion as a matrix, ignoring any custom (diagonal/sparse/whatever) structure. This won't be efficient with user-specified control terms, but at the very least won't break.
Untested, but I think something like the following would work as an implementation.
```python
def materialise_vf(t, y, args, contr, vf_prod):
    # Only used for its shape/dtype/structure; value is irrelevant
    control = contr(t, t)
    y_size = sum(np.size(yi) for yi in jax.tree_leaves(y))
    control_size = sum(np.size(ci) for ci in jax.tree_leaves(control))
    if y_size > control_size:
        make_jac = jax.jacfwd
    else:
        make_jac = jax.jacrev

    # Find the tree structure of vf_prod by smuggling it out as an additional
    # result from the Jacobian calculation.
    sentinel = vf_prod_tree = object()
    control_tree = jax.tree_structure(control)

    def _fn(_control):
        _out = vf_prod(t, y, args, _control)
        nonlocal vf_prod_tree
        structure = jax.tree_structure(_out)
        if vf_prod_tree is sentinel:
            vf_prod_tree = structure
        else:
            assert vf_prod_tree == structure
        return _out

    jac = make_jac(_fn)(control)
    assert vf_prod_tree is not sentinel
    if jax.tree_structure(None) in (vf_prod_tree, control_tree):
        # An unusual/not-useful edge case to handle.
        raise NotImplementedError(
            "`materialise_vf` not implemented for `None` controls or states."
        )
    return jax.tree_transpose(vf_prod_tree, control_tree, jac)


def _assert_array(x):
    if not isinstance(x, jnp.ndarray):
        raise NotImplementedError(
            "`sde_kl_divergence` can only handle array-valued drifts and diffusions"
        )


class _AugDrift(AbstractTerm):
    drift1: ODETerm
    drift2: ODETerm
    diffusion: AbstractTerm

    def vf(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree:
        y, _ = y
        drift1 = self.drift1.vf(t, y, args)
        drift2 = self.drift2.vf(t, y, args)
        _assert_array(drift1)
        _assert_array(drift2)
        # Ugly hack special-casing built-in control terms.
        if isinstance(self.diffusion, WeaklyDiagonalControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl_diagonal(drift1, drift2, diffusion)
        elif isinstance(self.diffusion, ControlTerm):
            diffusion = self.diffusion.vf(t, y, args)
            _assert_array(diffusion)
            kl_divergence = _kl(drift1, drift2, diffusion)
        else:
            # TODO: think about how to handle arbitrary control terms here, without
            # forcibly materialising the whole diffusion matrix. It'll require
            # analysing `self.diffusion.prod` or `self.diffusion.vf_prod` and looking
            # at its structure, I think? Or possibly extending the `AbstractTerm` api
            # to require specifying how to invert things?
            warnings.warn(
                "`sde_kl_divergence` may be slow when used with custom diffusion terms"
            )
            diffusion = materialise_vf(
                t, y, args, self.diffusion.contr, self.diffusion.vf_prod
            )
            _assert_array(diffusion)
            kl_divergence = _kl_general(drift1, drift2, diffusion)
        kl_divergence = jax.tree_util.tree_reduce(operator.add, kl_divergence)
        return drift1, kl_divergence

    @staticmethod
    def contr(t0: Scalar, t1: Scalar) -> Scalar:
        return t1 - t0

    @staticmethod
    def prod(vf: PyTree, control: Scalar) -> PyTree:
        return jax.tree_map(lambda v: control * v, vf)
```
This approach works, but falls short in the general case in two main respects. As already discussed, the first is handling general `AbstractTerm`s for the diffusion.

The second is more subtle, and is the reason for the `_assert_array` statements. The current approach of tree-map'ing isn't actually mathematically correct. By way of example, suppose we chose to represent the state/drift as a list-of-scalars (rather than a one-dimensional array), and the diffusion as a list-of-lists-of-scalars instead of as a matrix. Then the tree-map'ing would unpack the first list in the diffusion term, but leave the second list in place. In `_kl` we'd then try to compute `jnp.linalg.pinv(...list of scalars...)`. Obviously that isn't programmatically defined, but more importantly: regardless of how we adjust our implementation of `_kl`, we could never compute the thing desired, as we need the whole diffusion to do the inversion, and right now we only have a single column.

Once again this is something I'd need to think hard about how to handle efficiently in the general case. (As an inefficient general-case implementation you could use `jax.flatten_util.ravel_pytree`, though -- I'd be happy to have that in there with a `warnings.warn` if that branch is taken, just like the above case.)

Phew! As you can tell, all of this gets nontrivial fast.

By the way: this approach of materialising the diffusion matrix is something that's come up before in other contexts. I copied the code for doing that from AdjointTerm. If you decide to include `materialise_vf` in your implementation then do factor it out and use it in both places.
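To make the `ravel_pytree` fallback concrete, an untested sketch (names hypothetical, and deliberately ignoring efficiency) of flattening everything down to plain arrays, after which the dense-matrix KL applies:

```python
import jax
import jax.flatten_util
import jax.numpy as jnp


def _flatten_drifts_and_diffusion(t, y, args, drift1, drift2, diffusion):
    # Flatten both drifts to 1D arrays.
    d1, _ = jax.flatten_util.ravel_pytree(drift1.vf(t, y, args))
    d2, _ = jax.flatten_util.ravel_pytree(drift2.vf(t, y, args))

    # Materialise the diffusion as a dense (state_size, control_size) matrix by
    # differentiating the flattened diffusion-control product.
    control0 = diffusion.contr(t, t)  # used only for its shape/structure
    flat_control0, unravel_control = jax.flatten_util.ravel_pytree(control0)

    def flat_vf_prod(flat_control):
        out = diffusion.vf_prod(t, y, args, unravel_control(flat_control))
        flat_out, _ = jax.flatten_util.ravel_pytree(out)
        return flat_out

    diffusion_matrix = jax.jacfwd(flat_vf_prod)(flat_control0)
    return d1, d2, diffusion_matrix  # feed these into the dense-matrix KL
```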
Thanks for such an insightful comment! I understand better how diffrax works now.

I like the idea of extending `AbstractTerm` to invert matrices (vectors). I will try to go in this direction with your suggested code here.
diffrax/misc/sde_kl_divergence.py
Outdated
```diff
-    context: callable,
-    y0: PyTree,
-    bm: AbstractBrownianPath,
+    *, sde1: MultiTerm, sde2: MultiTerm, context: callable, y0: PyTree
```
I've mentioned it earlier, but just to reiterate since this is the public API: I'd make this API accept `drift1: ODETerm, drift2: ODETerm, diffusion: AbstractTerm`, since that's what we actually need. (e.g. a `MultiTerm` could include an arbitrary number of terms)
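i.e. a public signature along these lines (return values elided, as they're up for discussion):

```python
def sde_kl_divergence(
    *,
    drift1: ODETerm,
    drift2: ODETerm,
    diffusion: AbstractTerm,
    y0: PyTree,
):
    ...
```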
Hi all, thank you for the great work! I was just trying to run the code from the pull request and encountered this error: …

Was not sure if there was something I missed? I didn't change anything in the code. This occurs when I call the …

Thanks in advance!
Hmm. This looks like a bug in core JAX -- I'm not sure when …
Thanks for the quick reply Patrick! I see, this makes sense. It's strange that the other examples in Diffrax do still seem to work; I'll investigate this a bit more.
@patrick-kidger Sorry for taking so long. I will try to get back to this pull request in a couple of days. What I can think of now is to handle the case where the diffusion matrices are diagonal.

@harrisonzhu508 I will take a look at the bug. If the other examples do not have the problem, it may be because of the current implementation of the latent SDE.
Hi @harrisonzhu508, you can run the current code with …

As Patrick mentioned, it must be something to do with JAX core. I also find that the bug occurs when we use …
Hi @patrick-kidger, I tried to make … If I understand correctly, the goal of … If we can restrict our case to where both …

The current implementation can handle block-diagonal diffusion matrices having PyTrees such as

```python
drift = {
    "block1": jnp.zeros((2,)),
    "block2": jnp.zeros((2,)),
    "block3": jnp.zeros((3,)),
}
diffusion = {
    "block1": jnp.ones((2,)),
    "block2": jnp.ones((2, 3)),
    "block3": jnp.ones((3, 4)),
}
```

The first block corresponds to … I also pass …

The difficulty I encounter when handling the more general case can be described in this code:

```python
import jax.tree_util as jtu
import jax.numpy as jnp

vf_prod = {'block1': jnp.ones((2,)), 'block2': jnp.ones((1,))}
diffusion = {'block1': jnp.ones((2,)), "block2": [[1., 1., 1.]]}

# vf_prod_tree obtained either from `materialise_vf` or input `drift`
vf_prod_tree = jtu.tree_structure(vf_prod)      # PyTreeDef({'block1': *, 'block2': *})
diffusion_tree = jtu.tree_structure(diffusion)  # PyTreeDef({'block1': *, 'block2': [[*, *, *]]})

transposed = jtu.tree_map(lambda *xs: list(xs), *[vf_prod, diffusion])
# PyTreeDef({'block1': [*, *], 'block2': [*, [[*, *, *]]]})

# The next step is to convert the diffusion part of `block2` to an array, but we don't know how --
# maybe we can use `is_leaf` in `jtu.tree_map`. But what is the condition to decide a leaf?
```
Thanks a lot!
Hi @anh-tong, thanks a lot for the very clean implementation again! I was trying to reproduce an example that is very similar to https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py. Running the latter script yields the attached. But using your notebook implementation, I've noticed that the posterior sample paths seem to collapse to a deterministic function (even in intervals where there's no data); I was wondering if you noticed something similar too? Thanks a lot!
Hi, I guess this happens because the current parameter setting with …
That makes sense, thanks for the explanation! I haven't got it working (I'm training on samples from a stochastic process) but I'll try and play around with the KL annealing!
Hi Patrick,

Following up on the last discussion, I've created a pull request containing

- `diffrax.misc.sde_kl_divergence`, i.e., handling context and computing the KL
- `examples/neural_sde_vae.ipynb`
- renaming `examples/neural_sde.ipynb` to `examples/neural_sde_gan.ipynb` (fix link in the description as well)