Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/device/atomics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Atomic operation device overrides and fallbacks

# Fallback wrappers for Float32 atomic_inc!/atomic_dec!
# Intel Level Zero doesn't support these directly for floating-point types,
# so we implement them using atomic_add!/atomic_sub!

@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float32,AS}) where {AS}
SPIRVIntrinsics.atomic_add!(p, Float32(1))
end

@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float32,AS}) where {AS}
SPIRVIntrinsics.atomic_sub!(p, Float32(1))
end

# Float64 fallbacks (if Float64 is supported on device)
@device_override @inline function SPIRVIntrinsics.atomic_inc!(p::LLVMPtr{Float64,AS}) where {AS}
SPIRVIntrinsics.atomic_add!(p, Float64(1))
end

@device_override @inline function SPIRVIntrinsics.atomic_dec!(p::LLVMPtr{Float64,AS}) where {AS}
SPIRVIntrinsics.atomic_sub!(p, Float64(1))
end
1 change: 1 addition & 0 deletions src/oneAPI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Base.Experimental.@MethodTable(method_table)
include("device/runtime.jl")
include("device/array.jl")
include("device/quirks.jl")
include("device/atomics.jl")

# essential stuff
include("context.jl")
Expand Down
16 changes: 8 additions & 8 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ end

@testset "atomics (low level)" begin

@testset "atomic_add($T)" for T in [Int32, UInt32]
@testset "atomic_add($T)" for T in [Int32, UInt32, Float32]
a = oneArray([zero(T)])

function kernel(a, b)
Expand All @@ -288,7 +288,7 @@ end
@test Array(a)[1] == T(256)
end

@testset "atomic_sub($T)" for T in [Int32, UInt32]
@testset "atomic_sub($T)" for T in [Int32, UInt32, Float32]
a = oneArray([T(256)])

function kernel(a, b)
Expand All @@ -300,7 +300,7 @@ end
@test Array(a)[1] == T(0)
end

@testset "atomic_inc($T)" for T in [Int32, UInt32]
@testset "atomic_inc($T)" for T in [Int32, UInt32, Float32]
a = oneArray([zero(T)])

function kernel(a)
Expand All @@ -312,7 +312,7 @@ end
@test Array(a)[1] == T(256)
end

@testset "atomic_dec($T)" for T in [Int32, UInt32]
@testset "atomic_dec($T)" for T in [Int32, UInt32, Float32]
a = oneArray([T(256)])

function kernel(a)
Expand All @@ -324,25 +324,25 @@ end
@test Array(a)[1] == T(0)
end

@testset "atomic_min($T)" for T in [Int32, UInt32]
@testset "atomic_min($T)" for T in [Int32, UInt32, Float32]
a = oneArray([T(256)])

function kernel(a, T)
i = get_global_id()
oneAPI.atomic_min!(pointer(a), i%T)
oneAPI.atomic_min!(pointer(a), T(i))
return
end

@oneapi items=256 kernel(a, T)
@test Array(a)[1] == one(T)
end

@testset "atomic_max($T)" for T in [Int32, UInt32]
@testset "atomic_max($T)" for T in [Int32, UInt32, Float32]
a = oneArray([zero(T)])

function kernel(a, T)
i = get_global_id()
oneAPI.atomic_max!(pointer(a), i%T)
oneAPI.atomic_max!(pointer(a), T(i))
return
end

Expand Down
Loading