GPU fixes #32
…ault enzyme machinery.
# Conflicts:
#   Project.toml
#   ext/DynamicPPLExt.jl
#   ext/LogDensityProblemsExt.jl
#   test/test-DEER-Turing-Logistic.jl
#   test/test-Turing-Integration.jl
Codecov Report
❌ Patch coverage is 22.61%.
❌ Your patch check has failed because the patch coverage (22.61%) is below the target coverage (90.00%). You can increase the patch coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## main #32 +/- ##
==========================================
- Coverage 88.65% 80.10% -8.56%
==========================================
Files 6 8 +2
Lines 1040 1166 +126
==========================================
+ Hits 922 934 +12
- Misses 118 232 +114
Hey @rsenne, need a review here?
Hi @gdalle, yes, I would love that. This is my first pass at this, so a review would be much appreciated!
return DI.prepare_pushforward(
    f, _hvp_forward_backend(backend), x_template, (v_template,); strict=Val(false)
)
Why not use DI.hvp directly here?
If it is because you need a batched gradient, you may be interested in JuliaDiff/DifferentiationInterface.jl#991
Note that you can already batch a small number of tangents by passing a tuple, though.
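A minimal CPU sketch of the tuple-batching point above, assuming DifferentiationInterface's `hvp(f, backend, x, tx)` form where `tx` is a tuple of tangents (the same form the MWEs below use with `(v_gpu,)`). The test function `f` and the ForwardDiff-over-ForwardDiff backend are illustrative choices, not what this PR uses:

```julia
using DifferentiationInterface   # provides hvp; re-exports AutoForwardDiff and SecondOrder
import ForwardDiff               # the backend package must be loaded for AutoForwardDiff()

# Illustrative test function: f(x) = ||x||²/2 + x₁x₂,
# whose Hessian is [1 1 0; 1 1 0; 0 0 1] for every x.
f(x) = sum(abs2, x) / 2 + x[1] * x[2]

x = [1.0, 2.0, 3.0]
# A small batch of tangent directions passed as a tuple (here e₁ and e₃).
tangents = ([1.0, 0.0, 0.0], [0.0, 0.0, 1.0])

backend = SecondOrder(AutoForwardDiff(), AutoForwardDiff())  # (outer, inner)
# One call returns a tuple holding one Hessian-vector product per tangent.
Hv1, Hv2 = hvp(f, backend, x, tangents)
println(Hv1, " ", Hv2)
```

The batch size is baked into the tuple type, which is why this suits a handful of tangents rather than one tangent per chain.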
Two reasons:
- The lack of batching (happy to help tackle this!)
- Perhaps more importantly, every second-order method I've tried breaks.
Here are two MWEs, if you're interested:
#=
Direct `DI.hvp(logp, ...)` fails for every Mooncake-based config on GPU.
A single Mooncake reverse pass over the user's `gradlogp` succeeds.
Target (same shape as `test/test-GPU-AD-HVP.jl`):
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Run from repo root:
julia --project=test dev/di_hvp_gpu_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Mooncake
import ForwardDiff
import CUDA
import LinearAlgebra: dot
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== reverse-on-grad (Mooncake gradient of dot ∘ gradlogp) ==")
try
closure = β -> dot(gradlogp(β), v_gpu)
g = DI.gradient(closure, AutoMooncake(; config=nothing), β_gpu)
g_vec = g isa Tuple ? first(g) : g
Hv = Array(g_vec)
println(" Hv[1:3] = ", Hv[1:3])
println(" matches: ", isapprox(Hv, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoMooncake (DI default SecondOrder)" => AutoMooncake(; config=nothing),
"SecondOrder(AutoForwardDiff, AutoMooncake)" => SecondOrder(AutoForwardDiff(), AutoMooncake(; config=nothing)),
"SecondOrder(AutoMooncake, AutoForwardDiff)" => SecondOrder(AutoMooncake(; config=nothing), AutoForwardDiff()),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 600))
end
end
#=
Direct `DI.hvp(logp, AutoEnzyme(...))` fails for every config on GPU.
A forward-mode Enzyme pushforward of the user's `gradlogp` succeeds.
Same target as `dev/di_hvp_gpu_mwe.jl` / `test/test-GPU-AD-HVP.jl`:
logp(β) = -0.5 * (||β||^2 + ||Xβ||^2 / N)
gradlogp(β) = -(β + Xᵀ X β / N)
Hv = -v - Xᵀ X v / N
Five Enzyme variants tested. Three hard-abort the Julia process during
Enzyme compilation and therefore cannot share a script with the rest:
AutoEnzyme() — hard abort during compile
AutoEnzyme(mode=Reverse) — hard abort during compile
SecondOrder(AutoEnzyme(Forward),
AutoEnzyme(Reverse)) — hard abort during compile
To reproduce the aborts, run one variant at a time. The two variants below
throw catchable exceptions and run cleanly in the same process.
Run from repo root:
julia --project=test dev/di_hvp_gpu_enzyme_mwe.jl
=#
push!(LOAD_PATH, abspath(joinpath(@__DIR__, "..")))
using ParallelMCMC
using ADTypes
using DifferentiationInterface
const DI = DifferentiationInterface
import Enzyme
import CUDA
using Random
CUDA.functional() || error("requires functional CUDA device")
function _logp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
-oftype(zero(eltype(β)), 0.5) * (sum(abs2, β) + sum(abs2, Xβ) / N)
end
function _gradlogp_single(β, X)
Xβ = pmcmc_matmul(X, β)
N = oftype(zero(eltype(β)), size(X, 1))
Y = pmcmc_matmul(transpose(X), Xβ)
Y = Y ./ N
Y = Y .+ β
return -Y
end
const D = 20
const Ndat = 64
rng = MersenneTwister(20251231)
X_cpu = randn(rng, Float32, Ndat, D)
X_gpu = CUDA.CuMatrix(X_cpu)
β_cpu = randn(rng, Float32, D)
v_cpu = ones(Float32, D)
β_gpu = CUDA.CuArray(β_cpu)
v_gpu = CUDA.CuArray(v_cpu)
logp(β) = _logp_single(β, X_gpu)
gradlogp(β) = _gradlogp_single(β, X_gpu)
ref = -v_cpu .- (transpose(X_cpu) * (X_cpu * v_cpu)) ./ Float32(Ndat)
println("analytic Hv[1:3]: ", ref[1:3])
println()
println("== forward-on-grad (Enzyme.Forward pushforward of gradlogp) ==")
try
be = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
prep = DI.prepare_pushforward(gradlogp, be, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.pushforward(gradlogp, prep, be, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
println(" Hv[1:3] = ", host[1:3])
println(" matches: ", isapprox(host, ref; atol=1f-3, rtol=1f-2))
catch e
println(" FAILED: ", first(sprint(showerror, e), 600))
end
println()
configs = [
"AutoEnzyme(mode=Forward)" =>
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
"SecondOrder(AutoEnzyme(Reverse), AutoEnzyme(Forward))" =>
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)),
]
for (lbl, b) in configs
print("== $lbl ==\n ")
try
prep = DI.prepare_hvp(logp, b, β_gpu, (v_gpu,); strict=Val(false))
Hv = DI.hvp(logp, prep, b, β_gpu, (v_gpu,))
vec_ = Hv isa Tuple ? first(Hv) : Hv
host = Array(vec_)
ok = isapprox(host, ref; atol=1f-3, rtol=1f-2)
println("Hv[1:3]: ", host[1:3], " matches: ", ok)
catch e
println("FAILED: ", first(sprint(showerror, e), 800))
end
end
It's totally possible I'm doing something fundamentally wrong, in which case please correct me!
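Both workarounds in the MWEs above rely on the same identity: for a scalar `logp` with symmetric Hessian `H`, the gradient of `β ↦ ⟨∇logp(β), v⟩` (reverse-on-grad) and the pushforward of `∇logp` along `v` (forward-on-grad) both equal `H v`. A CPU sketch that checks this on the same quadratic target, using ForwardDiff through DifferentiationInterface as a stand-in for the Mooncake and Enzyme backends (an assumption for illustration; it does not reproduce the GPU failures):

```julia
using DifferentiationInterface
import ForwardDiff
using LinearAlgebra: dot
using Random

rng = MersenneTwister(1)
X = randn(rng, 64, 20)
N = size(X, 1)
β = randn(rng, 20)
v = ones(20)

# Same target as the MWEs: gradlogp(β) = -(β + XᵀXβ / N)
gradlogp(b) = -(b .+ X' * (X * b) ./ N)
ref = -v .- X' * (X * v) ./ N          # analytic Hv

backend = AutoForwardDiff()  # CPU stand-in for the GPU backends

# reverse-on-grad: Hv = ∇_β ⟨∇logp(β), v⟩ (uses symmetry of the Hessian)
Hv_rev = gradient(b -> dot(gradlogp(b), v), backend, β)

# forward-on-grad: Hv = J_{∇logp}(β) v, a single pushforward along v
Hv_fwd = only(pushforward(gradlogp, backend, β, (v,)))
println(Hv_rev[1:3], " ", Hv_fwd[1:3])
```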
(c::_BatchHvpReverseClosure)(X, V) = pmcmc_dotsum(c.grad_batch(X), V)
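The batched closure above appears to apply the reverse-on-grad trick to a whole batch of chains at once. A CPU sketch of why one gradient suffices, assuming `pmcmc_dotsum(A, B)` behaves like the Frobenius inner product `sum(A .* B)` (the `dotsum` below is a hypothetical stand-in, and ForwardDiff again replaces the reverse backend): because column `k` of `grad_batch(B)` depends only on column `k` of `B`, the gradient of `⟨grad_batch(B), V⟩` has `H * V[:, k]` in column `k`.

```julia
using DifferentiationInterface
import ForwardDiff
using LinearAlgebra: I

# Hypothetical stand-in for pmcmc_dotsum: Frobenius inner product Σᵢⱼ Aᵢⱼ Bᵢⱼ
dotsum(A, B) = sum(A .* B)

X = [1.0 2.0; 0.5 -1.0; 3.0 0.0]   # N = 3 data rows, D = 2 parameters
N = size(X, 1)

# Column-wise gradient of logp; column k of B is the state of chain k.
grad_batch(B) = -(B .+ X' * (X * B) ./ N)

B = [0.1 -0.3 0.7; 0.2 0.4 -0.5]    # D × K with K = 3 chains
V = [1.0 0.0 2.0; 0.0 1.0 -1.0]     # one tangent per chain

# A single gradient of the scalar ⟨grad_batch(B), V⟩ yields all K HVPs at once.
HV = gradient(Bm -> dotsum(grad_batch(Bm), V), AutoForwardDiff(), B)

H = -(I + X' * X ./ N)   # constant, symmetric Hessian of logp for this model
println(HV)
```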
#=
Pick the AD-HVP fallback strategy from the user's backend.
This whole logic exists in DI, so it would be great if we added batching there and that alleviated your troubles here.
Owned wrappers: identical semantics to their Base counterparts, but provide
stable function identities for backend-specific AD rules in `ext/EnzymeExt.jl`
without committing type piracy on `Base.*` / `LinearAlgebra.dot` / `Base.sum`. User
model code that wants those rules to fire (notably on GPU, where the default
rules trip on cuBLAS gc-transition bundles) should call these instead.
Why both a matmul and reductions: the GPU AD-HVP path runs Enzyme reverse mode
through `gradlogp` *and* through a scalar reduction wrapping it (see DEER's
`_HvpReverseClosure`). The reduction is what actually emits the
`cuMemcpyDtoHAsync_v2` that crashes Enzyme. Owning both the matmul AND the
reduction lets Enzyme treat each as opaque, so neither bundle ever enters
its IR.
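A small sketch of the "stable function identity" idea in the docstring above, with hypothetical names (`OwnedWrappers`, `owned_*`) rather than the package's `pmcmc_*` functions. Each wrapper is a new generic function whose type the module owns, so an AD rule dispatching on `Const{typeof(owned_matmul)}` touches neither `Base.*` nor `LinearAlgebra.dot`, whereas a rule on `typeof(*)` for array types from another package would be type piracy:

```julia
module OwnedWrappers
using LinearAlgebra: dot
# Hypothetical wrappers: same results as the Base/LinearAlgebra functions,
# but each is a distinct generic function owned by this module.
owned_matmul(A, B) = A * B
owned_dot(a, b) = dot(a, b)
owned_dotsum(A, B) = sum(A .* B)
end

using .OwnedWrappers: owned_matmul, owned_dot, owned_dotsum

A = [1.0 2.0; 3.0 4.0]
b = [1.0, -1.0]
println(owned_matmul(A, b), " ", owned_dot(b, b), " ", owned_dotsum(A, A))
```

Rules keyed on these types only fire when user code calls the wrappers, which matches the docstring's advice that model code call them explicitly.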
I'm really surprised that Enzyme reacts badly to LinearAlgebra; I suspect you might be holding it wrong. @wsmoses, send help.
Hey @gdalle, I addressed the two addressable points (i.e., the type instability and nixing the symbols). Thoughts? I also included two MWEs above. If I'm not crazy, I can open issues for these on the respective repos, though I'm not confident yet until someone who knows better than I do says so. I'm also happy to help tackle the linked DI issue for batching; does it seem reasonably approachable? Let me know what you think!
Hi @wsmoses, could I get your review on the Enzyme extension I put together here? The main point of the extension was to get some basic LinearAlgebra operations working on the GPU, but I'm concerned I may have overengineered it. So, before I commit any of these changes, I want to make sure what I've implemented makes sense. Thanks!
if !(B isa Const)
    Av = cache_A !== nothing ? cache_A : A.val
    B.dval .+= dr .* Av
end
you also need to fill!(dr, 0) at the end
Quick question on this: `pmcmc_dot` and `pmcmc_dotsum` return scalars, so I annotated them `Active`. In reverse mode `dr = dret.val` is therefore a `Float32` value rather than a buffer, and `fill!` on it errors. Did you mean to restructure these to carry a `Ref{T}` shadow through `AugmentedReturn` (mirroring the `dY` pattern in the matmul rule)? I want to make sure I do what you actually had in mind.
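A plain-Julia sketch (no Enzyme; `matmul_reverse!` and `dot_reverse` are hypothetical names) of the distinction raised in the question above. It assumes the convention that when an output's adjoint arrives in a shadow buffer such as `dY`, a reverse step accumulates into the input shadows and then zeroes that buffer so stale adjoint values are not picked up again; when the output is a scalar annotated `Active`, the adjoint arrives by value, so there is no buffer to zero and `fill!` has no method for it:

```julia
# Reverse step for Y = A * B in "shadow buffer" style: dY holds ∂L/∂Y.
function matmul_reverse!(dA, dB, dY, A, B)
    dA .+= dY * B'    # accumulate ∂L/∂A
    dB .+= A' * dY    # accumulate ∂L/∂B
    fill!(dY, 0)      # dY has been propagated; reset the buffer
    return nothing
end

# Reverse step for r = dot(a, b) with a by-value scalar adjoint dr:
# returns the input adjoints; there is no output buffer to reset.
dot_reverse(a, b, dr::Real) = (dr .* b, dr .* a)

A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.0; -1.0 2.0]
dA, dB, dY = zeros(2, 2), zeros(2, 2), ones(2, 2)
matmul_reverse!(dA, dB, dY, A, B)

da, db = dot_reverse([1.0, 2.0], [3.0, 4.0], 2.0f0)

# fill! on a plain scalar adjoint throws a MethodError.
scalar_fill_errors = try
    fill!(1.0f0, 0)
    false
catch e
    e isa MethodError
end
```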
if !(b isa Const)
    av = cache_a !== nothing ? cache_a : a.val
    b.dval .+= dr .* av
end
function EnzymeRules.forward(
    config::FwdConfig,
    ::Const{typeof(pmcmc_matmul)},
Why not just overload matmul for GPU arrays within Enzyme.jl proper, as an extension?
Thanks for the review, @wsmoses. Just for clarity, are you suggesting I make these changes upstream, at least for matmul? If so, I'm happy to do that and can open a PR today.
In the meantime, I'll probably integrate your other two points and keep this extension here until the upstream change is merged. What do you think?
Pull request overview
This PR overhauls the GPU + AD backend integration by routing AD through DifferentiationInterface and making Enzyme/Mooncake/ForwardDiff optional weak dependencies (via extensions). It also adds GPU-focused Enzyme rules (via owned wrapper functions) to avoid Enzyme failures on common CUDA kernels, and expands the test/docs surface for GPU execution and AD-HVP fallbacks.
Changes:
- Make `DifferentiationInterface` the unified AD entry point and require an explicit `backend` for `ParallelMALASampler`.
- Add Enzyme extension rules for owned wrappers (`pmcmc_matmul`/`pmcmc_dot`/`pmcmc_dotsum`) to make GPU AD-HVP paths Enzyme-safe.
- Add GPU AD-HVP/performance tests and new GPU documentation (limitations + worked example).
Reviewed changes
Copilot reviewed 33 out of 35 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `Project.toml` | Moves AD packages to `[weakdeps]` and wires extensions/compat for modular AD backends. |
| `src/ParallelMCMC.jl` | Adds and exports owned wrapper functions for matmul/dot/reduction to support backend-specific AD rules. |
| `src/interface.jl` | Makes `backend` a required keyword for `ParallelMALASampler` and updates the DEER-rec builder to use strategy-dispatched AD-HVP factories. |
| `src/DEER/DEER.jl` | Removes hard-coded Enzyme defaults; adds HVP strategy dispatch (forward-on-grad vs reverse-on-grad) and backend normalization hooks. |
| `src/DEER/DEERScan.jl` | Minor comment/style update in scan implementation. |
| `src/MALA/MALA.jl` | Tweaks internal quadratic forms and refactors JVP scalar computations. |
| `ext/EnzymeExt.jl` | Adds native Enzyme rules for `pmcmc_*` wrappers and backend normalization for Enzyme HVP paths. |
| `ext/DynamicPPLExt.jl` | Updates Turing/DynamicPPL `DensityModel` convenience constructor defaulting to `AutoForwardDiff()`. |
| `test/test-Owned-Matmul.jl` | New tests validating owned wrappers and Enzyme rules (forward + reverse + `Const`-arg branches). |
| `test/test-GPU-AD-HVP.jl` | New tests for GPU AD-HVP fallback across Enzyme/Mooncake/Zygote backends. |
| `test/test-GPU-Performance.jl` | New GPU-vs-CPU performance regression sanity test. |
| `test/test-Turing-Integration.jl` | Expands DynamicPPL/Turing integration tests and passes explicit sampler backend. |
| `test/test-DEER-Interface.jl` | Updates DEER interface tests for required sampler backend. |
| `test/test-DEER-Turing-Logistic.jl` | Updates logistic regression tests for explicit backend and adjusts tolerances/params for stability. |
| `test/test-Deer-vs-MALA.jl` | Updates internal AD-HVP calls to use explicit `AutoEnzyme()` backend. |
| `test/test-Jacobian-Estimator.jl` | Updates tests to avoid removed `DEFAULT_BACKEND` and use explicit backend. |
| `test/test-MALA-Kernel.jl` | Formatting-only refactors in MALA kernel tests. |
| `test/test-GPU-MALA.jl` | Improves determinism (`CUDA.seed!`) and adjusts stationary mean tolerance. |
| `test/test-Adaptive-MALA.jl` | Comment formatting change. |
| `test/test-Code-Quality.jl` | Enables Aqua + JET checks in the test suite. |
| `docs/src/10-getting-started.md` | Updates getting-started narrative and points to the new GPU execution page. |
| `docs/src/15-gpu.md` | New GPU execution page covering limitations, backend choices, and a worked logistic regression example. |
| `docs/src/assets/make_julia_deer_gif.jl` | Formatting-only refactor in docs asset script. |
| `benchmarks/ParallelMCMCBenchmarks/Project.toml` | Adds Enzyme dependency and normalizes `[sources]` layout. |
| `benchmarks/ParallelMCMCBenchmarks/src/pr_suite.jl` | Updates benchmark suite to pass explicit sampler backend. |
| `benchmarks/ParallelMCMCBenchmarks/src/models/bayes_logreg.jl` | Comment formatting tweaks in benchmark model. |
| `benchmarks/ParallelMCMCBenchmarks/src/models/bayes_linreg.jl` | Comment formatting tweaks in benchmark model. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/profile_deer_logreg_components.jl` | Updates profiling script to pass explicit backend in DEER rec + sampler. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/prof_view.jl` | Updates helper default backend to `AutoEnzyme()`. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/new_bench.jl` | Updates helper default backend to `AutoEnzyme()`. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/bench_mala_bayes.jl` | Updates DEER sampler construction to pass explicit backend. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/bench_deer_logreg.jl` | Updates GPU DEER benchmarking to pass explicit backend. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/pr_benchmarks.jl` | Formatting-only refactor. |
| `benchmarks/ParallelMCMCBenchmarks/scripts/compare_pr_benchmarks.jl` | Formatting-only refactor. |
| `.gitignore` | Ignores local debugging script directories (`/scripts`, `/dev`). |
Comments suppressed due to low confidence (1)
ext/DynamicPPLExt.jl:18
- The docstring says only `DynamicPPL` must be loaded, but the extension is configured to load only when `DynamicPPL`, `ForwardDiff`, and `LogDensityProblems` are all loaded (`Project.toml`). Either update the docstring to reflect the actual requirement (and mention that `ForwardDiff` is needed for the default `AutoForwardDiff()`), or loosen the extension deps and add a runtime check/error when `AutoForwardDiff()` is selected without `ForwardDiff` loaded.
Convenience constructor: wraps a DynamicPPL/Turing `@model` directly as a
`DensityModel`, automatically extracting parameter names and wiring up gradient
computation via DynamicPPL's `adtype` interface.
Requires `DynamicPPL` to be loaded.
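One way the runtime check suggested above could look, sketched with a hypothetical `require_loaded` helper. It relies on `Base.loaded_modules`, which is internal to Julia and may change between versions; a constructor might call something like `require_loaded("ForwardDiff")` before defaulting to `AutoForwardDiff()`:

```julia
# Hypothetical helper: fail early, with an actionable message, when a package
# needed by the selected AD backend has not been loaded in this session.
# Note: Base.loaded_modules (a Dict{Base.PkgId,Module}) is not public API.
function require_loaded(pkg::AbstractString; why::AbstractString = "")
    loaded = any(id -> id.name == pkg, keys(Base.loaded_modules))
    loaded || error("`$pkg` is not loaded; run `import $pkg` first. $why")
    return nothing
end

import LinearAlgebra
require_loaded("LinearAlgebra")   # passes: the package was just loaded

missing_errors = try
    require_loaded("NotARealPackageXYZ"; why = "(needed by the selected backend)")
    false
catch e
    e isa ErrorException
end
```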
Hi @penelopeysm, could I get a quick review of some of the statements I wrote in the docs about DynamicPPL? I want to make it clear that ParallelMCMC on the GPU does not play nicely with Turing because, as far as I'm aware, DynamicPPL is not currently set up to handle GPU-bound models. I think this is in the works, but you'd know better than I would. If any of it is incorrect, please let me know! Thanks!
That seems reasonable to me. FWIW, at one point when I was still on the project, I was told that GPU compatibility was not a priority.
Hmm, okay. All the more reason to integrate FlexiChains with first-class support, I guess.
This branch lands the GPU + AD-backend overhaul on top of main. The big themes:
Modular AD via DifferentiationInterface
EnzymeExt: GPU-safe Enzyme rules
Tests
Needed Changes Prior to Merging
Resolves #29 and provides a workaround to #25