Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-implementation of Tang (1990) log procedures in pure float32 #236

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Math function mappings to Metal intrinsics

using Base: FastMath
using Base.Math: @horner

# TODO:
# - wrap all intrinsics from include/metal/metal_math
Expand Down Expand Up @@ -101,6 +102,77 @@ using Base: FastMath
# Device overrides of `Base.log10`: each method forwards its argument to the
# external symbol named in the `ccall` string (`air.log10.f32` for Float32,
# `air.log10.f16` for Float16) through `llvmcall`, returning the same float type.
@device_override Base.log10(x::Float32) = ccall("extern air.log10.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.log10(x::Float16) = ccall("extern air.log10.f16", llvmcall, Float16, (Float16,), x)

# Re-implementation of Base.Math.log_proc1 and Base.Math.log_proc2 without upcasting to Float64

# Lookup table of 129 Float32 values used by `log_proc1`. Entry `j` holds
# log(1 + (j - 1)/128): the first entry is 0 and the last is log(2).
#
# Stored as an immutable tuple rather than a `Vector`: a tuple of Float32 is an
# isbits value, so device code indexing it does not depend on a mutable,
# heap-allocated host array, and its contents cannot change behind a function
# that is declared `:consistent`. Indexing (`t_log_Float32[j]`) is unchanged.
const t_log_Float32 = (0.0f0, 0.0077821403f0, 0.015504187f0, 0.023167059f0,
    0.030771658f0, 0.038318865f0, 0.045809537f0, 0.053244516f0, 0.06062462f0, 0.06795066f0,
    0.07522342f0, 0.08244367f0, 0.089612156f0, 0.09672963f0, 0.103796795f0, 0.11081436f0,
    0.11778303f0, 0.12470348f0, 0.13157636f0, 0.13840233f0, 0.14518201f0, 0.15191604f0,
    0.15860502f0, 0.16524957f0, 0.17185026f0, 0.17840765f0, 0.18492234f0, 0.19139485f0,
    0.19782574f0, 0.20421554f0, 0.21056476f0, 0.21687394f0, 0.22314355f0, 0.2293741f0, 0.23556606f0,
    0.24171993f0, 0.24783616f0, 0.25391522f0, 0.25995752f0, 0.26596355f0, 0.2719337f0, 0.27786845f0,
    0.28376818f0, 0.2896333f0, 0.29546422f0, 0.30126134f0, 0.30702505f0, 0.3127557f0, 0.31845373f0,
    0.32411948f0, 0.32975328f0, 0.33535555f0, 0.3409266f0, 0.34646678f0, 0.35197642f0, 0.35745588f0,
    0.3629055f0, 0.36832556f0, 0.3737164f0, 0.37907836f0, 0.3844117f0, 0.38971674f0, 0.3949938f0,
    0.40024316f0, 0.4054651f0, 0.41065994f0, 0.4158279f0, 0.4209693f0, 0.4260844f0, 0.43117347f0,
    0.43623677f0, 0.44127455f0, 0.4462871f0, 0.45127463f0, 0.45623744f0, 0.4611757f0, 0.46608973f0,
    0.47097972f0, 0.4758459f0, 0.48068854f0, 0.48550782f0, 0.490304f0, 0.49507725f0, 0.49982786f0,
    0.504556f0, 0.5092619f0, 0.51394576f0, 0.51860774f0, 0.52324814f0, 0.5278671f0, 0.5324648f0,
    0.5370415f0, 0.5415973f0, 0.54613245f0, 0.55064714f0, 0.5551415f0, 0.5596158f0, 0.56407017f0,
    0.56850475f0, 0.5729197f0, 0.5773154f0, 0.58169174f0, 0.586049f0, 0.59038746f0, 0.59470713f0,
    0.5990082f0, 0.60329086f0, 0.60755527f0, 0.61180156f0, 0.61602986f0, 0.6202404f0, 0.6244333f0,
    0.62860864f0, 0.63276666f0, 0.63690746f0, 0.6410312f0, 0.64513797f0, 0.6492279f0, 0.6533013f0,
    0.65735805f0, 0.6613985f0, 0.6654226f0, 0.6694307f0, 0.6734227f0, 0.6773988f0, 0.68135923f0,
    0.685304f0, 0.6892333f0, 0.6931472f0)

# Float32 scale factor applied to a natural-log result to express it in the
# requested base, selected by dispatch on a `Val` tag:
#   Val(2)  -> 1.442695f0   (≈ 1/log(2))
#   Val(:ℯ) -> 1f0          (natural log, no rescaling)
#   Val(10) -> 0.4342945f0  (≈ 1/log(10))
@inline function logb(::Val{2})
    return 1.442695f0
end

@inline function logb(::Val{:ℯ})
    return 1f0
end

@inline function logb(::Val{10})
    return 0.4342945f0
end

# Device override of `Base.Math.log_proc1` that performs every step in Float32.
#
# Arguments (all Float32 except `base`):
#   y    - added to `F` to form the denominator of `u` in Step 3'
#   mf   - multiplied by 0.6931472f0 (log(2) in Float32) in Steps 1 and 2
#   F    - selects the table entry: index = trunc(128*F) - 127
#   f    - numerator term of `u` in Step 3'
#   base - `Val` tag passed to `logb` to pick the output scale factor
#
# Returns `logb(base) * (l + (u + q))`, a Float32.
#
# NOTE (from review): `logb` in Base returns a double (Float64); the `logb`
# methods defined in this file return Float32, so the final scaling stays in
# Float32.
@device_override Base.@assume_effects :consistent @inline function Base.Math.log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,base=Val(:ℯ))
    # 1-based index into `t_log_Float32`; F == 1 maps to 1 and F == 2 maps to 129.
    jp = unsafe_trunc(Int, 128.0f0 * F) - 127

    ## Steps 1 and 2
    # `@inbounds` relies on `jp` being within the 129-entry table.
    @inbounds hi = t_log_Float32[jp]
    l = mf * 0.6931472f0 + hi

    ## Step 3
    # @inbounds u = f*c_invF[jp]
    # q = u*u*@horner(u,
    # Float32(-0x1.00006p-1),
    # Float32(0x1.55546cp-2))

    ## Step 3' (alternative)
    # `2f0 * f` is the explicit form of the literal-coefficient product `2f0f`.
    u = (2f0 * f) / (y + F)
    v = u * u
    q = u * v * 0.08333351f0

    ## Step 4
    return logb(base) * (l + (u + q))
end

# Clear the low-order bits of a Float32 by masking its raw bit pattern with
# 0xfff0_0000, i.e. keep only the 12 most significant bits (sign, the 8
# exponent bits and the top 3 fraction bits) and zero the remaining 20.
# NOTE(review): confirm that keeping only 3 fraction bits is the intended
# split width for `log_proc2`.
@inline function truncbits(x::Float32)
    keep_mask = 0xfff0_0000
    raw = reinterpret(UInt32, x)
    return reinterpret(Float32, raw & keep_mask)
end

# Device override of `Base.Math.log_proc2` that performs every step in Float32.
#
# Arguments:
#   f    - Float32 input; the computation is built around u = 2f/(2 + f)
#   base - `Val` tag passed to `logb` to pick the output scale factor
#
# Returns `logb(base) * (u_hi + (u_lo + q))`, a Float32, where `u_hi` is `u`
# with its low bits cleared by `truncbits`, `u_lo` is a correction term and
# `q` is a polynomial term in u.
@device_override @inline function Base.Math.log_proc2(f::Float32,base=Val(:ℯ))
    ## Step 1
    inv_den = 1f0 / (2f0 + f)
    u = 2 * (f * inv_den)
    usq = u * u

    ## Step 2
    # Two-coefficient polynomial in u^2, then scaled by u^3.
    poly = @horner(usq,
                   0.08333332f0,
                   0.012512346f0)
    q = u * usq * poly

    ## Step 3
    # Split u and f into truncated leading parts and remainders.
    u_hi = truncbits(u)
    f_hi = truncbits(f)
    f_lo = f - f_hi
    u_lo = ((2 * (f - u_hi) - u_hi * f_hi) - u_hi * f_lo) * inv_den

    ## Step 4
    return logb(base) * (u_hi + (u_lo + q))
end

# Device overrides for exponentiation. Each method forwards both operands to the
# external symbol named in its `ccall` string through `llvmcall`:
#   FastMath.pow_fast(Float32, Float32) -> air.fast_pow.f32
#   Base.:(^)(Float32, Float32)         -> air.pow.f32
#   Base.:(^)(Float16, Float16)         -> air.pow.f16
@device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern air.fast_pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float16, y::Float16) = ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)
Expand Down