Following up on this comment from @wsmoses on rsenne/ParallelMCMC.jl#32 -- I'd like to upstream the GPU matmul EnzymeRules I wrote there.
Plan: a CUDA package extension registering forward / augmented_primal / reverse for Base.:* on CuArray / CuMatrix / CuVector (plus transpose/adjoint variants). The rules compute the primal and the cotangents with plain *, so the cuBLAS call stays opaque to Enzyme, which sidesteps the gc-transition abort during LLVM lowering. Width-1 only to start.
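For concreteness, here is a sketch of the identities I'd expect the width-1 rules to implement for C = A * B, assuming real-valued arrays (the symbols are my own notation, not taken from the reference implementation; complex inputs would need the adjoint in place of the transpose):

```math
\begin{aligned}
\text{forward:} \quad & \dot{C} = \dot{A}\,B + A\,\dot{B} \\
\text{reverse:} \quad & \bar{A} \mathrel{+}= \bar{C}\,B^{\top}, \qquad \bar{B} \mathrel{+}= A^{\top}\,\bar{C}
\end{aligned}
```

Each right-hand side is itself an ordinary matmul, which is why the rule bodies can call plain * and never expose the cuBLAS internals to Enzyme. The matrix-vector case is the same with B a vector.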
Reference implementation: ParallelMCMC.jl/ext/EnzymeExt.jl.
OK to open a PR?