KL Divergence for Latent SDEs #463
Okay, I really like the example here!
I'm afraid this might still take a bit more iteration to get to something clean enough to merge, though -- see my comments. :)
diffrax/_solver/kl.py
The input must be a `MultiTerm` composed of the first SDE with drift `f`
and diffusion `g` and the second either an SDE or just the drift term
(since the diffusion is assumed to be the same). For example, a type
of: `MultiTerm(MultiTerm(ODETerm, _DiffusionTerm), ODETerm)`.
As per this comment:
#402 (review)
and also the updated term docs:
https://docs.kidger.site/diffrax/api/terms/
then this outer `MultiTerm` isn't really in keeping. We're not adding all of these extra terms on to the same evolving state. Bear in mind that the rest of Diffrax has to see this as just another SDE solve.
I think this one might take a bit more iteration to get to something that's obeying the abstractions in the way they're designed, I'm afraid.
I think I get what you're saying: a `MultiTerm` implies a single differential equation "unit", so composing `MultiTerm`s is bad form. However, I'm not sure I see the difficulty going forward. I can replace it with a tuple `(MultiTerm, MultiTerm)` or even a tuple `(MultiTerm, ODETerm)`. That seems to adhere to the principle of one `MultiTerm` = one SDE unit, since we are integrating two simultaneous SDEs, while also falling in line with other solvers (such as implicit Euler, as you remarked).
On the terms vs. solver approach, I am open to both. Over my many iterations and experiments, I found the solver approach more in line with how I think about the problem. Specifically, the original idea of `(terms, kl_term)` didn't appeal to me, since the `kl_term` relies on information from the other term and I didn't see a clean way to provide that. However, having terms with a term wrapper is very doable (but may not mesh with the repo as well).
Maybe introducing a totally new term is OK (given the remarks in #453), in which case the approach of a `KLTerm` (rather than a solver) is doable. Given the restricted nature of terms so far, I originally thought that wasn't in line with the package.
Happy to iterate on cleaning it up. I think the biggest question is the design one (on solvers, terms, and how to represent the problem in Diffrax). Once that is resolved, I can iterate quickly to get the rest in :)
Ok, I took the feedback from Andraz's Langevin PR regarding terms and incorporated it into this PR. I think it made things simpler and more in line with the Diffrax philosophy; let me know what you think. Basically, like Langevin, there's now just a function that accepts a `MultiTerm` and returns a `MultiTerm` of private terms that can be consumed by any solver. The reason I went with returning a single `MultiTerm` is that you are really only solving the one SDE. You use the prior SDE to inform the KL divergence, but it's not fully integrated or anything.
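For anyone following the design discussion, here is a rough, library-free sketch of the idea of treating the KL divergence as one extra state component in a single SDE solve. It is not the code in this PR and does not use the Diffrax API: `augmented_drift` and `euler_maruyama` are hypothetical names, the state is scalar, and the solver is a plain Euler-Maruyama loop. It assumes the posterior and prior share the same diffusion `g`, so the KL rate is `0.5 * ((f - h) / g) ** 2`, where `f` is the posterior drift and `h` the prior drift.

```python
import math
import random


def augmented_drift(posterior_drift, prior_drift, diffusion):
    """Build a drift for the augmented state (y, kl).

    Hypothetical helper for illustration only; scalar state, shared diffusion.
    """

    def drift(t, state):
        y, _kl = state
        f = posterior_drift(t, y)  # drift of the SDE actually being solved
        h = prior_drift(t, y)      # prior drift, evaluated on the same path
        g = diffusion(t, y)        # diffusion shared by prior and posterior
        u = (f - h) / g
        # First component drives y; second is the KL integrand.
        return f, 0.5 * u * u

    return drift


def euler_maruyama(drift, diffusion, y0, t0, t1, n_steps, seed=0):
    """Integrate (y, kl) with a fixed-step Euler-Maruyama scheme."""
    rng = random.Random(seed)
    dt = (t1 - t0) / n_steps
    t, y, kl = t0, y0, 0.0
    for _ in range(n_steps):
        dy, dkl = drift(t, (y, kl))
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment
        y = y + dy * dt + diffusion(t, y) * dw
        kl = kl + dkl * dt  # the KL component receives no noise
        t = t + dt
    return y, kl


# Example: the posterior drift differs from the prior by a constant 0.5.
prior = lambda t, y: -y
posterior = lambda t, y: -y + 0.5
sigma = lambda t, y: 0.5

y1, kl = euler_maruyama(augmented_drift(posterior, prior, sigma), sigma, 1.0, 0.0, 1.0, 100)
```

In this example the drift gap divided by the diffusion is constant at 1, so the KL estimate over `[0, 1]` comes out to about 0.5. The prior drift is only ever evaluated along the posterior path, which matches the point above that only the one SDE is being solved.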
A continuation of #402 with the new 0.6.0 lineax changes