Wrap and test some more Float16 intrinsics #2644
base: master
Changes from all commits
b905e0d
7dfff4c
736d836
edbbe73
6feddcf
94f3790
74961f9
b609b21
c3c1530
a8786b7
9acefb7
ebe2558
ed8ee84
9158c15
@@ -2,7 +2,6 @@

using Base: FastMath

## helpers

within(lower, upper) = (val) -> lower <= val <= upper
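As a side note, the `within` helper above builds a range predicate. A minimal usage sketch (the call sites are my assumption, they are not shown in this hunk):

```julia
within(lower, upper) = (val) -> lower <= val <= upper

pred = within(-1, 1)
pred(0.5)               # true
pred(2.0)               # false
all(pred, [0.1, -0.9])  # true: every element lies in [-1, 1]
```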
@@ -83,9 +82,13 @@ end
@device_override Base.tanh(x::Float64) = ccall("extern __nv_tanh", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.tanh(x::Float32) = ccall("extern __nv_tanhf", llvmcall, Cfloat, (Cfloat,), x)

# TODO: enable once PTX > 7.0 is supported
# @device_override Base.tanh(x::Float16) = @asmcall("tanh.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)

@device_override function Base.tanh(x::Float16)
    if compute_capability() >= sv"7.5"
        @asmcall("tanh.approx.f16 \$0, \$1;", "=r,r", Float16, Tuple{Float16}, x)
    else
        Float16(tanh(Float32(x)))
    end
end

## inverse hyperbolic
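A hedged host-side sketch of how such an override can be validated (this harness is my assumption, not part of the PR): Float16 has only 2^16 bit patterns, so accuracy can be checked exhaustively against a wider-precision reference.

```julia
using CUDA

# Enumerate every finite Float16 value.
xs = filter(isfinite, [reinterpret(Float16, i) for i in typemin(UInt16):typemax(UInt16)])

ys  = Array(tanh.(CuArray(xs)))      # results from the device override
ref = Float16.(tanh.(Float64.(xs)))  # correctly-rounded reference

count(ys .!= ref)                    # number of mismatching inputs
```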
@@ -103,17 +106,58 @@ end

@device_override Base.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3BFE)), reinterpret(Float16, 0x8010), r)
    r = fma(Float16(h == reinterpret(Float16, 0x3C0B)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(h == reinterpret(Float16, 0x6051)), reinterpret(Float16, 0x1C00), r)

    return r
end
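For readers puzzling over the `fma` lines (a reviewer makes the same point at the bottom of this page): `h == pattern` yields a `Bool`, which converts to `Float16` 0 or 1, so each `fma` adds its correction constant only for the one input whose Float32 round-trip rounds the wrong way. A branchy sketch of the first fixup, my gloss rather than code from the PR:

```julia
# Branchy equivalent of:
#   r = fma(Float16(h == reinterpret(Float16, 0x160D)), reinterpret(Float16, 0x9C00), r)
if h == reinterpret(Float16, 0x160D)   # the single offending input
    r += reinterpret(Float16, 0x9C00)  # -2^-8: a one-ULP nudge at this magnitude
end
```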
@device_override FastMath.log_fast(x::Float32) = ccall("extern __nv_fast_logf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log10(x::Float64) = ccall("extern __nv_log10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log10(x::Float32) = ccall("extern __nv_log10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log10(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x338F)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x33F8)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x57E1)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x719D)), reinterpret(Float16, 0x9C00), r)

    return r
end
@device_override FastMath.log10_fast(x::Float32) = ccall("extern __nv_fast_log10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.log1p(x::Float64) = ccall("extern __nv_log1p", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log1p(x::Float32) = ccall("extern __nv_log1pf", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.log2(x::Float64) = ccall("extern __nv_log2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.log2(x::Float32) = ccall("extern __nv_log2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.log2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath log2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(r == reinterpret(Float16, 0xA2E2)), reinterpret(Float16, 0x8080), r)
    r = fma(Float16(r == reinterpret(Float16, 0xBF46)), reinterpret(Float16, 0x9400), r)

    return r
end
@device_override FastMath.log2_fast(x::Float32) = ccall("extern __nv_fast_log2f", llvmcall, Cfloat, (Cfloat,), x)

@device_function logb(x::Float64) = ccall("extern __nv_logb", llvmcall, Cdouble, (Cdouble,), x)
@@ -127,16 +171,65 @@ end

@device_override Base.exp(x::Float64) = ccall("extern __nv_exp", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp(x::Float32) = ccall("extern __nv_expf", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, log2(Float32(ℯ)), -0.0f0)
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x1F79)), reinterpret(Float16, 0x9400), r)
    r = fma(Float16(h == reinterpret(Float16, 0x25CF)), reinterpret(Float16, 0x9400), r)
Comment on lines +182 to +183: These two cause us to disagree with Julia.
    r = fma(Float16(h == reinterpret(Float16, 0xC13B)), reinterpret(Float16, 0x0400), r)
    r = fma(Float16(h == reinterpret(Float16, 0xC1EF)), reinterpret(Float16, 0x0200), r)

(vchuravy marked this conversation as resolved.)

    return r
end
@device_override FastMath.exp_fast(x::Float32) = ccall("extern __nv_fast_expf", llvmcall, Cfloat, (Cfloat,), x)
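A note on the structure (my gloss, not from the PR): the Float32 path evaluates `exp` through the hardware-friendly `exp2`, using the identity exp(x) = 2^(x · log2 e); the `fma(f, log2(Float32(ℯ)), -0.0f0)` spelling presumably forces a fused multiply and preserves signed zero. A quick sanity check:

```julia
# exp(x) == 2^(x * log2(e)), up to rounding:
x = 1.5f0
exp2(fma(x, log2(Float32(ℯ)), -0.0f0)) ≈ exp(x)  # true
```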
@device_override Base.exp2(x::Float64) = ccall("extern __nv_exp2", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp2(x::Float32) = ccall("extern __nv_exp2f", llvmcall, Cfloat, (Cfloat,), x)
@device_override FastMath.exp2_fast(x::Union{Float32, Float64}) = exp2(x)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.exp2(x::Float16) = @asmcall("ex2.approx.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
@device_override function Base.exp2(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = @fastmath exp2(f)

    # one ULP adjustment
    f = muladd(f, reinterpret(Float32, 0x33800000), f)
Comment: Should this be an […]?
Comment: Probably, I can try to figure out what that magic number is too.
Comment: I think that magic number is just a relative […]. Don't know why they chose […].
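For what it's worth (my decoding, not from the thread): the constant is exactly 2^-24, so the `muladd` computes f · (1 + 2^-24), nudging the Float32 result upward by half an ULP before it is rounded to Float16.

```julia
reinterpret(Float32, 0x33800000)  # 5.9604645f-8
2.0f0^-24                         # 5.9604645f-8, the same value
# so muladd(f, reinterpret(Float32, 0x33800000), f) == f * (1 + 2f0^-24)
```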
    r = Float16(f)

    return r
end
@device_override FastMath.exp2_fast(x::Float64) = exp2(x)
@device_override FastMath.exp2_fast(x::Float32) =
    @asmcall("ex2.approx.f32 \$0, \$1;", "=r,r", Float32, Tuple{Float32}, x)

(vchuravy marked this conversation as resolved.)

@device_override function FastMath.exp2_fast(x::Float16)
    if compute_capability() >= sv"7.5"
        ccall("llvm.nvvm.ex2.approx.f16", llvmcall, Float16, (Float16,), x)
    else
        exp2(x)
    end
end
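A usage sketch (mine, not from the PR): inside a kernel, `@fastmath` routes `exp2` to the `FastMath.exp2_fast` override above, which picks the `ex2.approx.f16` hardware instruction on sm_75+ devices and falls back to the accurate path elsewhere.

```julia
using CUDA

function kernel(out, x)
    i = threadIdx().x
    @inbounds out[i] = @fastmath exp2(x[i])  # lowers to FastMath.exp2_fast
    return
end

x = CuArray(rand(Float16, 32))
out = similar(x)
@cuda threads=32 kernel(out, x)
```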
@device_override Base.exp10(x::Float64) = ccall("extern __nv_exp10", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.exp10(x::Float32) = ccall("extern __nv_exp10f", llvmcall, Cfloat, (Cfloat,), x)
@device_override function Base.exp10(h::Float16)
    # perform computation in Float32 domain
    f = Float32(h)
    f = fma(f, log2(10.0f0), -0.0f0)
    f = @fastmath exp2(f)
    r = Float16(f)

    # handle degenerate cases
    r = fma(Float16(h == reinterpret(Float16, 0x34DE)), reinterpret(Float16, 0x9800), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9766)), reinterpret(Float16, 0x9000), r)
    r = fma(Float16(h == reinterpret(Float16, 0x9972)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xA5C4)), reinterpret(Float16, 0x1000), r)
    r = fma(Float16(h == reinterpret(Float16, 0xBF0A)), reinterpret(Float16, 0x8100), r)

    return r
end
@device_override FastMath.exp10_fast(x::Float32) = ccall("extern __nv_fast_exp10f", llvmcall, Cfloat, (Cfloat,), x)

@device_override Base.expm1(x::Float64) = ccall("extern __nv_expm1", llvmcall, Cdouble, (Cdouble,), x)
@@ -204,6 +297,7 @@ end

@device_override Base.isnan(x::Float64) = (ccall("extern __nv_isnand", llvmcall, Int32, (Cdouble,), x)) != 0
@device_override Base.isnan(x::Float32) = (ccall("extern __nv_isnanf", llvmcall, Int32, (Cfloat,), x)) != 0
@device_override Base.isnan(x::Float16) = isnan(Float32(x))

@device_function nearbyint(x::Float64) = ccall("extern __nv_nearbyint", llvmcall, Cdouble, (Cdouble,), x)
@device_function nearbyint(x::Float32) = ccall("extern __nv_nearbyintf", llvmcall, Cfloat, (Cfloat,), x)
@@ -223,14 +317,14 @@ end
@device_override Base.abs(x::Int32) = ccall("extern __nv_abs", llvmcall, Int32, (Int32,), x)
@device_override Base.abs(f::Float64) = ccall("extern __nv_fabs", llvmcall, Cdouble, (Cdouble,), f)
@device_override Base.abs(f::Float32) = ccall("extern __nv_fabsf", llvmcall, Cfloat, (Cfloat,), f)
# TODO: enable once PTX > 7.0 is supported
# @device_override Base.abs(x::Float16) = @asmcall("abs.f16 \$0, \$1", "=h,h", Float16, Tuple{Float16}, x)
@device_override Base.abs(f::Float16) = Float16(abs(Float32(f)))
@device_override Base.abs(x::Int64) = ccall("extern __nv_llabs", llvmcall, Int64, (Int64,), x)

## roots and powers

@device_override Base.sqrt(x::Float64) = ccall("extern __nv_sqrt", llvmcall, Cdouble, (Cdouble,), x)
@device_override Base.sqrt(x::Float32) = ccall("extern __nv_sqrtf", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sqrt(x::Float16) = Float16(sqrt(Float32(x)))
@device_override FastMath.sqrt_fast(x::Union{Float32, Float64}) = sqrt(x)

@device_function rsqrt(x::Float64) = ccall("extern __nv_rsqrt", llvmcall, Cdouble, (Cdouble,), x)
Comment: We should add a comment that this is just: […] But FMA is the fastest way to do that xD, and keeps your flops number up.
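The comment's inline snippet was lost in extraction; presumably it showed the select form of the fixups. A reconstruction under that assumption, with hypothetical helper names:

```julia
# Both forms apply `correction` only when `h` matches `pattern`; the fma
# variant does it on the FPU, without a compare-and-select on the result.
fixup_fma(r, h, pattern, correction)    = fma(Float16(h == pattern), correction, r)
fixup_select(r, h, pattern, correction) = ifelse(h == pattern, r + correction, r)

h = reinterpret(Float16, 0x160D)
c = reinterpret(Float16, 0x9C00)
fixup_fma(Float16(-6.5), h, h, c) == fixup_select(Float16(-6.5), h, h, c)  # true
```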