Wrap and test some more Float16 intrinsics
kshyatt committed Feb 6, 2025
1 parent 4d85f27 commit 5c1db19
Showing 2 changed files with 98 additions and 10 deletions.
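Every intrinsic wrapped in this commit follows the same shape: on devices with compute capability 8.0 or newer, call the dedicated half-precision libdevice routine; otherwise promote the Float16 argument to Float32, evaluate there, and truncate the result back. A minimal host-side sketch of that shape (illustrative only: use_native_half and native_op stand in for the on-device compute_capability() check and the __nv_h* calls, they are not CUDA.jl API):

use_native_half() = false     # stands in for compute_capability() >= sv"8.0" on the device
native_op(x::Float16) = x     # placeholder for a libdevice half routine such as __nv_hlog

half_override(op, x::Float16) =
    use_native_half() ? native_op(x) : Float16(op(Float32(x)))

# e.g. half_override(log, Float16(2.0)) takes the promote-to-Float32 fallback path here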
81 changes: 77 additions & 4 deletions src/device/intrinsics/math.jl
@@ -103,17 +103,38 @@ end

@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hlog", llvmcall, Float16, (Float16,), x)
    else
        return Float16(log(Float32(x)))
    end
end
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log10(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hlog10", llvmcall, Float16, (Float16,), x)
    else
        return Float16(log10(Float32(x)))
    end
end
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log2(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hlog2", llvmcall, Float16, (Float16,), x)
    else
        return Float16(log2(Float32(x)))
    end
end
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)

@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +148,35 @@ end

@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hexp", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp(Float32(x)))
    end
end
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp2(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hexp2", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp2(Float32(x)))
    end
end
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)

@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp10(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hexp10", llvmcall, Float16, (Float16,), x)
    else
        return Float16(exp10(Float32(x)))
    end
end
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +244,13 @@ end

@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
@device_override function Base.isnan(x::Float16)
    if compute_capability() >= sv"8.0"
        return (ccall("extern __nv_hisnan", llvmcall, Int32, (Float16,), x)) != 0
    else
        return isnan(Float32(x))
    end
end

@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ ... @@
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
@device_override function Base.abs(f::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_habs", llvmcall, Float16, (Float16,), f)
    else
        return Float16(abs(Float32(f)))
    end
end
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)

## roots and powers

@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.sqrt(x::Float16)
    if compute_capability() >= sv"8.0"
        ccall("extern __nv_hsqrt", llvmcall, Float16, (Float16,), x)
    else
        return Float16(sqrt(Float32(x)))
    end
end
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)

@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
@@ -295,6 +354,13 @@ end
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
#@device_override Base.min(x::Float64, y::Float64) = ccall("extern __nv_fmin", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
#@device_override Base.min(x::Float32, y::Float32) = ccall("extern __nv_fminf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
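The commented-out wrappers above stay disabled because libdevice's fmin/fmax return the non-NaN operand when one argument is NaN, whereas Julia's min/max propagate the NaN; the overrides below keep Julia's semantics. A small host-side illustration of the difference, with fmin_like written by hand as a model of IEEE fmin behaviour (it is not the actual __nv_fmin binding):

# hand-written model of IEEE-754 fmin: prefer the non-NaN operand
fmin_like(x, y) = isnan(x) ? y : (isnan(y) ? x : min(x, y))

min(NaN32, 1.0f0)        # NaN32: Julia propagates NaN
fmin_like(NaN32, 1.0f0)  # 1.0f0: fmin-style drops the NaN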
@device_override @inline function Base.min(x::Float16, y::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hmin", llvmcall, Float16, (Float16, Float16), x, y)
    else
        return Float16(min(Float32(x), Float32(y)))
    end
end
@device_override @inline function Base.min(x::Float32, y::Float32)
    if @static LLVM.version() < v"14" ? false : (compute_capability() >= sv"8.0")
        # LLVM 14+ can do the right thing, but only on sm_80+
@@ ... @@
# JuliaGPU/CUDA.jl#2111: fmin semantics wrt. NaN don't match Julia's
#@device_override Base.max(x::Float64, y::Float64) = ccall("extern __nv_fmax", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
#@device_override Base.max(x::Float32, y::Float32) = ccall("extern __nv_fmaxf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override @inline function Base.max(x::Float16, y::Float16)
    if compute_capability() >= sv"8.0"
        return ccall("extern __nv_hmax", llvmcall, Float16, (Float16, Float16), x, y)
    else
        return Float16(max(Float32(x), Float32(y)))
    end
end
@device_override @inline function Base.max(x::Float32, y::Float32)
    if @static LLVM.version() < v"14" ? false : (compute_capability() >= sv"8.0")
        # LLVM 14+ can do the right thing, but only on sm_80+
27 changes: 21 additions & 6 deletions test/core/device/intrinsics/math.jl
@@ -2,7 +2,9 @@ using SpecialFunctions

@testset "math" begin
    @testset "log10" begin
        @test testf(a->log10.(a), Float32[100])
        for T in (Float16, Float32, Float64)
            @test testf(a->log10.(a), T[100])
        end
    end
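These tests all go through the suite's testf helper. Assuming it does the usual CUDA.jl test-suite thing, running the function once on plain Arrays and once on CuArrays and comparing the results, a rough stand-in looks like this (testf_sketch is illustrative; the real helper may differ, for instance in tolerance handling):

using CUDA, Test

# rough stand-in for the suite's testf: compare CPU and GPU evaluation of f
function testf_sketch(f, xs...)
    cpu = f(xs...)
    gpu = Array(f(map(CuArray, xs)...))
    return isequal(cpu, gpu) || isapprox(cpu, gpu)
end

@test testf_sketch(a -> log10.(a), Float16[100])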

@testset "pow" begin
@@ -12,15 +14,22 @@
@test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
end
end

    @testset "min/max" begin
        for T in (Float16, Float32, Float64)
            @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1))
            @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1))
        end
    end

    @testset "isinf" begin
        for x in (Inf32, Inf, NaN32, NaN)
        for x in (Inf32, Inf, NaN16, NaN32, NaN)
            @test testf(x->isinf.(x), [x])
        end
    end

    @testset "isnan" begin
        for x in (Inf32, Inf, NaN32, NaN)
        for x in (Inf32, Inf, NaN16, NaN32, NaN)
            @test testf(x->isnan.(x), [x])
        end
    end
@@ -33,7 +42,6 @@ using SpecialFunctions
end
end
end

for op in (expm1,)
@testset "$op" begin
# FIXME: add expm1(::Float16) to Base
@@ -50,7 +58,6 @@ using SpecialFunctions
@test testf(x->op.(x), rand(T, 1))
@test testf(x->op.(x), -rand(T, 1))
end

end
end
@testset "mod and rem" begin
@@ -97,6 +104,14 @@ end
# JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
@test testf(x->exp.(x), [1e7im])
end

    for op in (exp, abs, abs2, log, exp10, log10)
        @testset "Real - $op" begin
            for T in (Float16, Float32, Float64)
                @test testf(x->op.(x), rand(T, 1))
            end
        end
    end

@testset "fastmath" begin
# libdevice provides some fast math functions
@@ -150,7 +165,7 @@ using SpecialFunctions
end

    @testset "JuliaGPU/CUDA.jl#2111: min/max should return NaN" begin
        for T in [Float32, Float64]
        for T in [Float16, Float32, Float64]
            AT = CuArray{T}
            @test isequal(Array(min.(AT([NaN]), AT([Inf]))), [NaN])
            @test isequal(minimum(AT([NaN])), NaN)
