Skip to content

Commit

Permalink
Fix argument types in sincos (#232)
Browse files Browse the repository at this point in the history
From the Apple developer manual, the sincos functions return the cosine
part via a reference argument `&T`. Changed the overrides to now return
both the sin and cosine components as a tuple.
  • Loading branch information
fjebaker authored Aug 2, 2023
1 parent d6f958c commit df0b348
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
18 changes: 15 additions & 3 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,21 @@ using Base: FastMath
@device_override Base.sin(x::Float32) = ccall("extern air.sin.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sin(x::Float16) = ccall("extern air.sin.f16", llvmcall, Float16, (Float16,), x)

@device_override FastMath.sincos_fast(x::Float32) = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sincos(x::Float32) = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sincos(x::Float16) = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16,), x)
@device_override function FastMath.sincos_fast(x::Float32)
c = Ref{Cfloat}()
s = ccall("extern air.fast_sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
(s, c[])
end
@device_override function Base.sincos(x::Float32)
c = Ref{Cfloat}()
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
(s, c[])
end
@device_override function Base.sincos(x::Float16)
c = Ref{Float16}()
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
(s, c[])
end

@device_override FastMath.sinh_fast(x::Float32) = ccall("extern air.fast_sinh.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sinh(x::Float32) = ccall("extern air.sinh.f32", llvmcall, Cfloat, (Cfloat,), x)
Expand Down
17 changes: 17 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,23 @@ end
return nothing
end
@metal intr_test2(bufferA)
synchronize()

bufferB = MtlArray{eltype(a),length(size(a)),Shared}(a)
vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1)

function intr_test3(arr_sin, arr_cos)
idx = thread_position_in_grid_1d()
s, c = sincos(arr_cos[idx])
arr_sin[idx] = s
arr_cos[idx] = c
return nothing
end

@metal intr_test3(bufferA, bufferB)
synchronize()
@test vecA sin.(a)
@test vecB cos.(a)
end

############################################################################################
Expand Down

0 comments on commit df0b348

Please sign in to comment.