Fix zero-type of logjac for ReshapeTransform #851
Does this fix the Enzyme issues? Do you have an MWE I can run before and after to see? |
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #851 +/- ##
==========================================
+ Coverage 84.40% 84.43% +0.03%
==========================================
Files 34 34
Lines 3840 3849 +9
==========================================
+ Hits 3241 3250 +9
Misses 599 599
|
Pull Request Test Coverage Report for Build 13930337870 (Coveralls) |
Yes, you can run this, which is a reduced version of the MWE in TuringLang/Turing.jl#2510:

    module MWE

    using DynamicPPL: DynamicPPL
    using Distributions: Beta
    using Enzyme

    # Forward mode with runtime activity enabled.
    mode = set_runtime_activity(Forward)

    # Log-probability of `s ~ Beta()` as a function of the flattened parameters `x`.
    function f(x, vi)
        vi_x = DynamicPPL.unflatten(vi, x)
        logp = DynamicPPL.assume(
            Beta(), DynamicPPL.@varname(s), vi_x
        )[2]
        return logp
    end

    vi = DynamicPPL.VarInfo()
    vi = DynamicPPL.push!!(vi, DynamicPPL.@varname(s), 1.0, Beta())
    vi = DynamicPPL.TypedVarInfo(vi)

    @show Enzyme.gradient(mode, f, [0.5], Const(vi))

    end
|
Pull Request Test Coverage Report for Build 13945812690 (Coveralls) |
A fair question you may ask is: Why did we have false confidence in saying that these sorts of deep internals of DPPL were type stable? If not for that false confidence, I would have looked into the Enzyme error in greater detail earlier. I think the answer is that we assumed (I did, at least) that type instability in something like Not sure what to learn from that. I wonder if JET.jl would have caught this? |
This is really nice 🎉 I ran the MWE in that Turing issue too!
I would in principle be happy to approve this on its own. However, I grepped the repo for 0 (not a smart idea, I know) and it seems there are a few '0's inside the NoDist code in src/distribution_wrappers.jl (lines 57-81). It seems like a bit of an edge case, but should we fix those as well while we're here?
Awesome find @mhauru! To give us some confidence in the future, does it make sense to add Enzyme to the test suite here? |
@penelopeysm, good point, changed. I'm now having doubts about whether this is the right thing to do though. If one asks for the log determinant of a Jacobian of a transform where a vector is reshaped to a matrix, you could say that the type is ill-defined. In practice, if we make it be the element type of the vector being transformed, which is what this PR does, that will probably result in less trouble than just returning an Int. But you could in principle be transforming a vector of booleans, and then we would say that the
Note that in some sense the real root cause here, and a type instability we still aren't fixing in this PR, is that |
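The trade-off described in the comment above can be sketched roughly as follows. Both functions are hypothetical stand-ins written for illustration, not DynamicPPL's actual ReshapeTransform code: one always returns an integer zero, the other returns a zero of the input's element type.

```julia
# Hypothetical stand-ins for a reshape transform's log-Jacobian;
# these are NOT DynamicPPL's actual definitions.
logjac_int(x::AbstractVector) = 0                        # always an Int
logjac_eltype(x::AbstractVector{T}) where {T} = zero(T)  # follows the input's element type

typeof(logjac_int([0.5]))             # Int, regardless of the input
typeof(logjac_eltype([0.5]))          # Float64
typeof(logjac_eltype([true, false]))  # Bool, the awkward case raised above
```

The last line is the edge case from the comment: with a vector of booleans, the element-type approach gives a Bool "log-Jacobian", which is arguably no more meaningful than an Int.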
Is there anything blocking that? I had opened an earlier quick PR to add it here, https://github.com/TuringLang/DynamicPPL.jl/pull/743/files, which was closed, seemingly in favor of that one? |
Markus and I discussed it a bit and basically we think that
Note that VarInfo currently uses f64 for its logp, and this is hardcoded (Line 824 in 9df42bf),
whereas SimpleVarInfo is kind of all over the place (DynamicPPL.jl/src/simple_varinfo.jl, Lines 202 to 265 in 9df42bf).
|
Yes and no. Some of them segfault; due to IRL stuff I've unfortunately been way too busy to look into it carefully and to report issues upstream, although that has been the intention. It's possible with this change more of them will pass. Also it takes a very long time to run Enzyme tests and it's unlikely that we will want to run Enzyme on every combination of model + varinfo data structure in CI. I have a separate PR to make the AD testing interface more customisable (#799) but again I haven't had enough time to really get into that. This is stuff we can control on our end though so I think the most immediate blockers to that are on our end 🙂 |
Based on the conversation @penelopeysm mentioned above, I introduced const LogProbType = float(Real) and made all these logjacs be |
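A minimal sketch of how such a constant might be used, assuming the logjacs become zeros of that fixed type. The definition of `LogProbType` is taken from the comment above; the helper `zero_logjac` is hypothetical and only illustrates the idea.

```julia
# `LogProbType` is defined as in the comment above.
const LogProbType = float(Real)

# Hypothetical helper (an assumption, not DynamicPPL's code): a zero log-Jacobian
# with a fixed floating-point type, independent of the input's element type.
zero_logjac(x::AbstractArray) = zero(LogProbType)

zero_logjac([true, false]) isa AbstractFloat  # true, even for a Bool input
zero_logjac(Float32[1, 2]) isa LogProbType    # true
```

Under this approach the boolean-vector edge case no longer yields a Bool, at the cost of ignoring the input's element type entirely.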
We were always returning Ints. This caused a type instability when the transform of a variable depended on whether it was linked or not, and thus sometimes these Ints came into the mix and sometimes they didn't. Returning Ints was obviously silly for other reasons too.

I haven't added tests. We should generally go through adding @inferred tests to many DPPL internals, see #777, but that should be a separate PR.
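As an illustration of the kind of `@inferred` test mentioned above, here is a small sketch using only the `Test` standard library. The functions `unstable` and `stable` are hypothetical and merely mimic the pattern described (an Int zero on one branch, a float on the other); they are not DynamicPPL internals.

```julia
using Test

# Hypothetical functions mimicking the instability; not DynamicPPL internals.
# `unstable` returns an Int on one branch and a Float64 on the other.
unstable(x::Vector{Float64}, linked::Bool) = linked ? sum(log, x) : 0
# `stable` returns a zero of the element type, so both branches are Float64.
stable(x::Vector{Float64}, linked::Bool) = linked ? sum(log, x) : zero(eltype(x))

@testset "return type inference" begin
    # Passes: the inferred return type is concrete.
    @test (@inferred stable([0.5], false)) == 0.0
    # `@inferred` throws because the inferred return type is a Union of Float64 and Int.
    @test_throws ErrorException @inferred unstable([0.5], false)
end
```

A test like this would flag an Int zero leaking into a floating-point code path without needing an AD backend in the loop.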