Skip to content

Commit 972f3f0

Browse files
authored
Enzyme: add make_zero of cuarrays (#2600)
1 parent aa28279 commit 972f3f0

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

.buildkite/pipeline.yml

+3-2
Original file line numberDiff line numberDiff line change
@@ -246,8 +246,9 @@ steps:
246246
build.message !~ /\[only/ && !build.pull_request.draft &&
247247
build.message !~ /\[skip tests\]/ &&
248248
build.message !~ /\[skip downstream\]/
249-
timeout_in_minutes: 30
250-
soft_fail: true
249+
timeout_in_minutes: 60
250+
soft_fail:
251+
- exit_status: 3
251252

252253
- group: ":eyes: Special"
253254
depends_on: "cuda"

ext/EnzymeCoreExt.jl

+53
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,59 @@ function EnzymeCore.EnzymeRules.noalias(::Type{CT}, ::UndefInitializer, args...)
555555
return nothing
556556
end
557557

558+
@inline function EnzymeCore.make_zero(
559+
x::DenseCuArray{FT},
560+
) where {FT<:AbstractFloat}
561+
return Base.zero(x)
562+
end
563+
@inline function EnzymeCore.make_zero(
564+
x::DenseCuArray{Complex{FT}},
565+
) where {FT<:AbstractFloat}
566+
return Base.zero(x)
567+
end
568+
569+
@inline function EnzymeCore.make_zero(
570+
::Type{CT},
571+
seen::IdDict,
572+
prev::CT,
573+
::Val{copy_if_inactive} = Val(false),
574+
)::CT where {copy_if_inactive, FT<:AbstractFloat, CT <: Union{DenseCuArray{FT},DenseCuArray{Complex{FT}}}}
575+
if haskey(seen, prev)
576+
return seen[prev]
577+
end
578+
newa = Base.zero(prev)
579+
seen[prev] = newa
580+
return newa
581+
end
582+
583+
@inline function EnzymeCore.make_zero!(
584+
prev::DenseCuArray{FT},
585+
seen::ST,
586+
)::Nothing where {FT<:AbstractFloat,ST}
587+
if !isnothing(seen)
588+
if prev in seen
589+
return nothing
590+
end
591+
push!(seen, prev)
592+
end
593+
fill!(prev, zero(FT))
594+
return nothing
595+
end
596+
597+
@inline function EnzymeCore.make_zero!(
598+
prev::DenseCuArray{Complex{FT}},
599+
seen::ST,
600+
)::Nothing where {FT<:AbstractFloat,ST}
601+
if !isnothing(seen)
602+
if prev in seen
603+
return nothing
604+
end
605+
push!(seen, prev)
606+
end
607+
fill!(prev, zero(Complex{FT}))
608+
return nothing
609+
end
610+
558611
function EnzymeCore.EnzymeRules.forward(config, ofn::Const{typeof(GPUArrays.mapreducedim!)},
559612
::Type{RT},
560613
f::EnzymeCore.Const{typeof(Base.identity)},

test/extensions/enzyme.jl

+9
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,15 @@ using CUDA
77
@test EnzymeCore.compiler_job_from_backend(CUDABackend(), typeof(()->nothing), Tuple{}) isa GPUCompiler.CompilerJob
88
end
99

10+
@testset "Make_zero" begin
11+
A = CUDA.ones(64)
12+
dA = Enzyme.make_zero(A)
13+
@test all(dA .≈ 0)
14+
dA = CUDA.ones(64)
15+
Enzyme.make_zero!(dA)
16+
@test all(dA .≈ 0)
17+
end
18+
1019
function square_kernel!(x)
1120
i = threadIdx().x
1221
x[i] *= x[i]

0 commit comments

Comments
 (0)