
Wrap and test some more Float16 intrinsics #2644

Open

wants to merge 14 commits into master
112 changes: 103 additions & 9 deletions src/device/intrinsics/math.jl
@@ -2,7 +2,6 @@

using Base: FastMath


## helpers

within(lower, upper) = (val) -> lower <= val <= upper
@@ -83,9 +82,13 @@ end
@device_override Base.tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x)

# TODO: enable once PTX > 7.0 is supported
# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)

@device_override function Base.tanh(x::Float16)
    if compute_capability() >= sv"7.5"
        @asmcall("tanh.approx.f16 \$0, \$1;", "=h,h", Float16, Tuple{Float16}, x)
    else
        Float16(tanh(Float32(x)))
    end
end

## inverse hyperbolic

@@ -103,17 +106,58 @@

@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)

    return r
end

@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log10(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
Member:

We should add a comment that this is just:

r = h == reinterpret(Float16, 0x338F) ? reinterpret(Float16, 0x1000) : r

But FMA is the fastest way to do that xD, and keep your flops number up.
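In other words, Float16(h == pattern) is exactly 1.0 or 0.0, so each fma either adds the Float16 correction encoded by the second bit pattern to r or leaves r untouched: a branchless select-and-add. A minimal CPU-side sketch of the idiom (helper names are illustrative, not from the PR):

# add `delta` (given as a bit pattern) to `r` only when `h` matches `pattern`, without branching
fixup(h::Float16, pattern::UInt16, delta::UInt16, r::Float16) =
    fma(Float16(h == reinterpret(Float16, pattern)), reinterpret(Float16, delta), r)

# equivalent branching form
fixup_branch(h::Float16, pattern::UInt16, delta::UInt16, r::Float16) =
    h == reinterpret(Float16, pattern) ? r + reinterpret(Float16, delta) : r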

    r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)

    return r
end
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)

    return r
end
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)

@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +171,65 @@

@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, log2(Float32(ℯ)), -0.0f0)
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
Comment on lines +182 to +183 (Member):

These two cause us to disagree with Julia.

julia> findall(==(0), exp.(all_float_16) .== Array(exp.(CuArray(all_float_16))))
2-element Vector{Int64}:
 8058
 9680

julia> all_float_16[8058]
Float16(0.007298)

julia> all_float_16[9680]
Float16(0.02269)

julia> reinterpret(UInt16, all_float_16[8058])
0x1f79

julia> reinterpret(UInt16, all_float_16[9680])
0x25cf

Member:

julia> Float16(exp(Float32(all_float_16[8058])))
Float16(1.008)

julia> exp_cu[8058]
Float16(1.007)

julia> exp_cu[8058] - Float16(exp(Float32(all_float_16[8058])))
Float16(-0.000977)

    r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)

    return r
end
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
@device_override function Base.exp2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath exp2(f)

    # one ULP adjustment
    f = muladd(f, reinterpret(Float32, 0x33800000), f)
Member:

Should this be an fma?

Contributor Author:

Probably, I can try to figure out what that magic number is too.

Member (@vchuravy, Mar 1, 2025):

I think that magic number is just a relative eps:

julia> reinterpret(UInt32, eps(0.5f0))
0x33800000

Don't know why they chose 0.5f0 though.
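To spell that out: 0x33800000 is the bit pattern of eps(0.5f0) == 2.0f0^-24, so the muladd computes roughly f * (1 + 2^-24), a tiny relative upward nudge applied before rounding down to Float16, presumably to push values that would otherwise round the wrong way across the halfway point. A quick CPU-side check of that reading (not from the PR):

# the magic constant is 2^-24, i.e. eps(0.5f0)
@assert reinterpret(Float32, 0x33800000) == eps(0.5f0) == 2.0f0^-24

# so the adjustment amounts to a relative bump of about one ulp at 0.5
nudge(f::Float32) = muladd(f, eps(0.5f0), f)    # ≈ f * (1 + 2^-24)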

    r = Float16(f)

    return r
end
@device_override FastMath.exp2_fast(x::Float64) = exp2(x)
@device_override FastMath.exp2_fast(x::Float32) =
    @asmcall("ex2.approx.f32 \$0, \$1;", "=r,r", Float32, Tuple{Float32}, x)
@device_override function FastMath.exp2_fast(x::Float16)
    if compute_capability() >= sv"7.5"
        ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x)
    else
        exp2(x)
    end
end

@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, log2(10.0f0), -0.0f0)
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)

    return r
end
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +297,7 @@ end

@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
@device_override Base.isnan(x::Float16) = isnan(Float32(x))

@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +317,14 @@ end
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
@device_override Base.abs(f::Float16) = Float16(abs(Float32(f)))
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)

## roots and powers

@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sqrt(x::Float16) = Float16(sqrt(Float32(x)))
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)

@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
34 changes: 28 additions & 6 deletions test/core/device/intrinsics/math.jl
@@ -2,7 +2,9 @@ using SpecialFunctions

@testset "math" begin
    @testset "log10" begin
        @test testf(a->log10.(a), Float32[100])
        for T in (Float32, Float64)
            @test testf(a->log10.(a), T[100])
        end
    end

    @testset "pow" begin
@@ -12,28 +14,34 @@
            @test testf((x,y)->x.^y, rand(Float32, 1), -rand(range, 1))
        end
    end

    @testset "min/max" begin
        for T in (Float32, Float64)
            @test testf((x,y)->max.(x, y), rand(Float32, 1), rand(T, 1))
            @test testf((x,y)->min.(x, y), rand(Float32, 1), rand(T, 1))
        end
    end

    @testset "isinf" begin
        for x in (Inf32, Inf, NaN32, NaN)
        for x in (Inf32, Inf, NaN16, NaN32, NaN)
            @test testf(x->isinf.(x), [x])
        end
    end

    @testset "isnan" begin
        for x in (Inf32, Inf, NaN32, NaN)
        for x in (Inf32, Inf, NaN16, NaN32, NaN)
            @test testf(x->isnan.(x), [x])
        end
    end

    for op in (exp, angle, exp2, exp10,)
        @testset "$op" begin
            for T in (Float16, Float32, Float64)
            for T in (Float32, Float64)
                @test testf(x->op.(x), rand(T, 1))
                @test testf(x->op.(x), -rand(T, 1))
            end
        end
    end

    for op in (expm1,)
        @testset "$op" begin
            # FIXME: add expm1(::Float16) to Base
@@ -50,7 +58,6 @@
                @test testf(x->op.(x), rand(T, 1))
                @test testf(x->op.(x), -rand(T, 1))
            end

        end
    end
    @testset "mod and rem" begin
@@ -97,6 +104,21 @@
        # JuliaGPU/CUDA.jl#1085: exp uses Base.sincos performing a global CPU load
        @test testf(x->exp.(x), [1e7im])
    end

    @testset "Real - $op" for op in (abs, abs2, exp, exp10, log, log10)
        @testset "$T" for T in (Float16, Float32, Float64)
            @test testf(x->op.(x), rand(T, 1))
        end
    end

    @testset "Float16 - $op" for op in (exp,exp2,exp10,log,log2,log10)
        all_float_16 = collect(reinterpret(Float16, pattern) for pattern in UInt16(0):UInt16(1):typemax(UInt16))
        all_float_16 = filter(!isnan, all_float_16)
        if op in (log, log2, log10)
            all_float_16 = filter(>=(0), all_float_16)
        end
        @test testf(x->map(op, x), all_float_16)
    end

    @testset "fastmath" begin
        # libdevice provides some fast math functions