
Fix zero-type of logjac for ReshapeTransform #851


Merged
mhauru merged 5 commits into main from mhauru/reshapetransform-jac-eltype on Mar 19, 2025

Conversation

@mhauru (Member) commented Mar 18, 2025

We were always returning Ints. This caused a type instability when the transform of a variable depended on whether it was linked or not, and thus sometimes these Ints came into the mix and sometimes they didn't. Returning Ints was obviously silly for other reasons too.

I haven't added tests. We should generally go through adding @inferred tests to many DPPL internals, see #777, but that should be a separate PR.
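
A sketch of the kind of @inferred test meant here (toy functions for illustration, not DPPL code): Test.@inferred errors when a call's return type cannot be inferred as the concrete type of the value actually returned.

using Test

stable(x) = x + 1.0
unstable(linked, x) = linked ? x : 0    # return type inferred as Union{Float64, Int64}

@inferred stable(0.5)             # passes: inferred and actual type are both Float64
# @inferred unstable(true, 0.5)   # would throw, which is exactly the kind of
#                                 # regression these tests are meant to catch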

github-actions bot (Contributor) commented Mar 18, 2025

Benchmark Report for Commit 9aa8d7bb5795530e6dd45a62b663dbb163d56959

Computer Information

Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.4 |                 1.6 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                616.3 |                39.6 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                367.9 |                50.2 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1231.3 |                26.5 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               4146.5 |                18.3 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1450.2 |                29.5 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                942.0 |                 5.3 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5608.2 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1059.0 |                 8.7 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              61904.2 |                 3.6 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8701.0 |                 9.9 |
|               Dynamic |        10 |    mooncake |             typed |   true |                122.3 |                12.1 |
|              Submodel |         1 |    mooncake |             typed |   true |                 25.0 |                 7.5 |
|                   LDA |        12 | reversediff |             typed |   true |                453.4 |                 4.7 |

@penelopeysm (Member)

Does this fix the Enzyme issues? Do you have an MWE I can run before/after to see?

codecov bot commented Mar 18, 2025

Codecov Report

Attention: Patch coverage is 80.00000% with 5 lines in your changes missing coverage. Please review.

Project coverage is 84.43%. Comparing base (9df42bf) to head (9aa8d7b).
Report is 1 commit behind head on main.

| Files with missing lines | Patch % | Lines |
|-----------------------------|---------|--------------|
| src/distribution_wrappers.jl | 61.53% | 5 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #851      +/-   ##
==========================================
+ Coverage   84.40%   84.43%   +0.03%     
==========================================
  Files          34       34              
  Lines        3840     3849       +9     
==========================================
+ Hits         3241     3250       +9     
  Misses        599      599              


@coveralls

Pull Request Test Coverage Report for Build 13930337870

Details

  • 10 of 10 (100.0%) changed or added relevant lines in 1 file are covered.
  • 22 unchanged lines in 4 files lost coverage.
  • Overall coverage decreased (-3.3%) to 81.204%

| Files with Coverage Reduction | New Missed Lines | % |
|-------------------------------|------------------|--------|
| src/model.jl | 1 | 85.83% |
| src/simple_varinfo.jl | 5 | 75.72% |
| src/varinfo.jl | 6 | 82.04% |
| src/threadsafe.jl | 10 | 60.55% |

Totals:
  • Change from base Build 13927957181: -3.3%
  • Covered Lines: 3115
  • Relevant Lines: 3836

💛 - Coveralls

@mhauru (Member, Author) commented Mar 18, 2025

Yes, you can run this, which is a reduced version of the MWE in TuringLang/Turing.jl#2510:

module MWE

using DynamicPPL: DynamicPPL
using Distributions: Beta
using Enzyme

# Forward-mode Enzyme with runtime activity enabled.
mode = set_runtime_activity(Forward)

# The log-density contribution of `s ~ Beta()`, evaluated at the values in `x`.
function f(x, vi)
    vi_x = DynamicPPL.unflatten(vi, x)
    logp = DynamicPPL.assume(
        Beta(), DynamicPPL.@varname(s), vi_x
    )[2]
    return logp
end

# Build a TypedVarInfo containing just `s`.
vi = DynamicPPL.VarInfo()
vi = DynamicPPL.push!!(vi, DynamicPPL.@varname(s), 1.0, Beta())
vi = DynamicPPL.TypedVarInfo(vi)

@show Enzyme.gradient(mode, f, [0.5], Const(vi))

end

@coveralls commented Mar 18, 2025

Pull Request Test Coverage Report for Build 13945812690

Details

  • 20 of 25 (80.0%) changed or added relevant lines in 3 files are covered.
  • 48 unchanged lines in 14 files lost coverage.
  • Overall coverage increased (+0.04%) to 84.525%

| Changes Missing Coverage | Covered Lines | Changed/Added Lines | % |
|------------------------------|---------------|---------------------|--------|
| src/distribution_wrappers.jl | 8 | 13 | 61.54% |

| Files with Coverage Reduction | New Missed Lines | % |
|-------------------------------|------------------|--------|
| ext/DynamicPPLEnzymeCoreExt.jl | 1 | 0.0% |
| ext/DynamicPPLForwardDiffExt.jl | 1 | 63.64% |
| src/distribution_wrappers.jl | 1 | 42.86% |
| src/sampler.jl | 1 | 89.06% |
| src/varnamedvector.jl | 1 | 90.1% |
| src/logdensityfunction.jl | 2 | 52.27% |
| src/utils.jl | 2 | 73.7% |
| src/contexts.jl | 3 | 64.62% |
| src/model.jl | 3 | 84.17% |
| src/values_as_in_model.jl | 3 | 59.52% |

Totals:
  • Change from base Build 13927957181: +0.04%
  • Covered Lines: 3250
  • Relevant Lines: 3845

💛 - Coveralls

@mhauru (Member, Author) commented Mar 18, 2025

A fair question you may ask is: Why did we have false confidence in saying that these sorts of deep internals of DPPL were type stable? If not for that false confidence, I would have looked into the Enzyme error in greater detail earlier.

I think the answer is that we assumed (I did, at least) that a type instability in something like invlink_with_logpdf would cause a big performance hit, which we would have noticed. The bug fixed in this PR turns out not to cause a noticeable performance hit (compare the benchmark results above to those of any recently merged PR; they are nearly identical), presumably because the type instability is quite "mild" and short-lived: one variable gets the type Union{Float64, Int64} and is very soon afterwards added to something that is definitely a Float64, which stabilises types from that point onwards. This doesn't really cause any trouble performance-wise, but it does cause trouble for Enzyme.

Not sure what to learn from that. I wonder if JET.jl would have caught this?
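
For illustration, a toy version of the kind of "mild" instability described above (hypothetical function, not DynamicPPL code): the intermediate variable has a small Union type, but it is folded into a Float64 immediately, so the function's return type still infers as Float64 and plain benchmarks barely notice.

function accumulate_logp(linked::Bool, x::Float64)
    logjac = linked ? log(abs(x)) : 0    # inferred as Union{Float64, Int64}
    return x + logjac                    # Float64 again from here onwards
end

# @code_warntype accumulate_logp(true, 0.5) shows the Union on the intermediate
# variable even though the overall return type is concrete.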

@penelopeysm (Member) left a comment

This is really nice 🎉 I ran the MWE in that Turing issue too!

I would in principle be happy to approve this on its own. However, I grepped the repo for 0 (not a smart idea, I know) and it seems there are a few '0's inside the NoDist code in src/distribution_wrappers.jl (lines 57-81). It seems like a bit of an edge case, but should we fix those as well while we're here?

@wsmoses (Contributor) commented Mar 18, 2025

Awesome find, @mhauru!

To give us some confidence in the future, does it make sense to add Enzyme to the test suite here?

@mhauru (Member, Author) commented Mar 18, 2025

@penelopeysm, good point, changed.

I'm now having doubts about whether this is the right thing to do, though. If one asks for the log determinant of the Jacobian of a transform that reshapes a vector into a matrix, you could argue that the type of that value is ill-defined. In practice, making it the element type of the vector being transformed, which is what this PR does, will probably cause less trouble than just returning an Int. But you could in principle be transforming a vector of booleans, and then we would say that the logjac is false, which isn't great. Probably the least bad option is to wrap things in float_type_with_fallback; I'll add that.

Note that in some sense the real root cause here, and a type instability we still aren't fixing in this PR, is that from_maybe_linked_internal_transform returns a different function depending on whether some variable in a VarInfo is linked or not. That instability I just don't know how to fix.
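
A sketch of the options being weighed here (illustrative functions, not the actual DynamicPPL code; Base's float stands in for the float_type_with_fallback internal mentioned above):

logjac_int(x)    = 0                       # old behaviour: always an Int
logjac_eltype(x) = zero(eltype(x))         # follows the input's element type
logjac_float(x)  = zero(float(eltype(x)))  # float fallback

logjac_eltype([true, false])   # false -- the Bool edge case mentioned above
logjac_float([true, false])    # 0.0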

@mhauru (Member, Author) commented Mar 18, 2025

@wsmoses, that would be good, see #813

@wsmoses (Contributor) commented Mar 18, 2025

Is there anything blocking that? I had opened an earlier quick PR to add that here (https://github.com/TuringLang/DynamicPPL.jl/pull/743/files), which was closed, seemingly in favor of that one?

@penelopeysm (Member) commented Mar 18, 2025

Markus and I discussed it a bit and basically we think that:

  • The type of logp and logjac should be the same, but it should be independent of the type of the values. We think this should default to float(Real), i.e. f64 on 64-bit systems and f32 on 32-bit.
  • This is because values and logp are two different things; we wouldn't ever add one to the other.
  • Right now we use float_type_with_fallback, which is not the same: if the value is an f32, then float_type_with_fallback(typeof(value)) will be f32, and we don't want the logp to then be an f32 (see the sketch after this list).
  • It would be pretty neat if the user could control the precision that varinfo uses for its logp field; however, we think it'd be easier to tidy up the dispatch chains in src/varinfo.jl before trying to change this.
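
To make the float_type_with_fallback point concrete, a rough sketch (Base's float is used as a stand-in for the DynamicPPL internal):

float(Real)       # Float64 on a standard 64-bit Julia build -- independent of the values
float(Float32)    # Float32 -- roughly what float_type_with_fallback(typeof(value)) gives
                  # for an f32 value, i.e. the behaviour the bullets above argue against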

Note that VarInfo currently uses f64 for its logp and this is hardcoded in

VarInfo(meta=Metadata()) = VarInfo(meta, Ref{Float64}(0.0), Ref(0))

whereas SimpleVarInfo is kind of all over the place

# Makes things a bit more readable vs. putting `Float64` everywhere.
const SIMPLEVARINFO_DEFAULT_ELTYPE = Float64
function SimpleVarInfo{NT,T}(values, logp) where {NT,T}
    return SimpleVarInfo{NT,T,NoTransformation}(values, logp, NoTransformation())
end
function SimpleVarInfo{T}(θ) where {T<:Real}
    return SimpleVarInfo{typeof(θ),T}(θ, zero(T))
end
# Constructors without type-specification.
SimpleVarInfo(θ) = SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
function SimpleVarInfo(θ::Union{<:NamedTuple,<:AbstractDict})
    return if isempty(θ)
        # Can't infer from values, so we just use default.
        SimpleVarInfo{SIMPLEVARINFO_DEFAULT_ELTYPE}(θ)
    else
        # Infer from `values`.
        SimpleVarInfo{float_type_with_fallback(infer_nested_eltype(typeof(θ)))}(θ)
    end
end
SimpleVarInfo(values, logp) = SimpleVarInfo{typeof(values),typeof(logp)}(values, logp)
# Using `kwargs` to specify the values.
function SimpleVarInfo{T}(; kwargs...) where {T<:Real}
    return SimpleVarInfo{T}(NamedTuple(kwargs))
end
function SimpleVarInfo(; kwargs...)
    return SimpleVarInfo(NamedTuple(kwargs))
end
# Constructor from `Model`.
function SimpleVarInfo(
    model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
)
    return SimpleVarInfo{Float64}(model, args...)
end
function SimpleVarInfo{T}(
    model::Model, args::Union{AbstractVarInfo,AbstractSampler,AbstractContext}...
) where {T<:Real}
    return last(evaluate!!(model, SimpleVarInfo{T}(), args...))
end
# Constructor from `VarInfo`.
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
    return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
end
function SimpleVarInfo{T}(
    vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
) where {T<:Real,names,D}
    values = values_as(vi, D)
    return SimpleVarInfo(values, convert(T, getlogp(vi)))
end
function untyped_simple_varinfo(model::Model)
    varinfo = SimpleVarInfo(OrderedDict())
    return last(evaluate!!(model, varinfo, SamplingContext()))
end
function typed_simple_varinfo(model::Model)
    varinfo = SimpleVarInfo{Float64}()
    return last(evaluate!!(model, varinfo, SamplingContext()))
end

@penelopeysm (Member) commented Mar 18, 2025

Is there anything blocking that?

Yes and no. Some of them segfault; due to IRL stuff I've unfortunately been way too busy to look into it carefully and report the issues upstream, although that has been the intention. It's possible that with this change more of them will pass. Also, Enzyme tests take a very long time to run, and it's unlikely that we'll want to run Enzyme on every combination of model + varinfo data structure in CI. I have a separate PR to make the AD testing interface more customisable (#799), but again I haven't had enough time to really get into it. This is stuff we can control on our end, though, so I think the most immediate blockers are on our side 🙂

@mhauru (Member, Author) commented Mar 19, 2025

Based on the conversation @penelopeysm mentioned above, I introduced

const LogProbType = float(Real)

and made all these logjacs be zero(LogProbType).
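
For reference, a minimal sketch of what this gives (LogProbType is from this PR; reshape_logjac is illustrative, not the actual DynamicPPL method):

const LogProbType = float(Real)    # Float64 on a standard Julia build

# Reshaping never changes volume, so the log-abs-det-Jacobian is exactly zero;
# the fix is to return that zero as a stable floating-point type instead of 0::Int.
reshape_logjac(x::AbstractArray) = zero(LogProbType)

reshape_logjac([true, false])      # 0.0, regardless of eltype(x)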

@mhauru requested a review from penelopeysm March 19, 2025 11:53
@mhauru added this pull request to the merge queue Mar 19, 2025
Merged via the queue into main with commit 915059b Mar 19, 2025
17 of 18 checks passed
@mhauru deleted the mhauru/reshapetransform-jac-eltype branch March 19, 2025 17:19