
GPU matmul rules in Enzyme CUDA extension #3122

@rsenne

Following up on this comment from @wsmoses on rsenne/ParallelMCMC.jl#32: I'd like to upstream the GPU matmul EnzymeRules I wrote there.

Plan: a CUDA package extension that registers forward, augmented_primal, and reverse rules for Base.* on CuArray / CuMatrix / CuVector arguments (plus the transpose/adjoint variants). The rules compute the primal and the cotangents with plain *, so the cuBLAS call stays opaque to Enzyme. This sidesteps the gc-transition abort during LLVM lowering. Only batch width 1 to start.
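For reference, these are the derivatives the rules would need to implement for C = A B. The sketch assumes real-valued arrays (for complex element types the transposes become conjugate transposes) and writes \dot{X} for tangents and \bar{X} for cotangents. It gives the standard matmul derivatives and is not a transcription of EnzymeExt.jl:

```math
\begin{aligned}
\text{forward, } C = A B:\quad & \dot{C} = \dot{A}\,B + A\,\dot{B} \\
\text{reverse, } C = A B:\quad & \bar{A} \mathrel{+}= \bar{C}\,B^{\top}, \qquad \bar{B} \mathrel{+}= A^{\top}\,\bar{C} \\
\text{reverse, } C = A^{\top} B:\quad & \bar{A} \mathrel{+}= B\,\bar{C}^{\top}, \qquad \bar{B} \mathrel{+}= A\,\bar{C}
\end{aligned}
```

The reverse step reads A and B. If either could be overwritten between the primal and reverse passes, augmented_primal would need to cache a copy on the tape. For CuVector arguments the same formulas apply with B a column vector b, so \bar{A} \mathrel{+}= \bar{c}\,b^{\top} becomes an outer product.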

Reference implementation: ParallelMCMC.jl/ext/EnzymeExt.jl.

OK to open a PR?
