Skip to content

Commit

Permalink
Port openlibm log1pf as log1p (#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
sotlampr authored Aug 15, 2023
1 parent a562261 commit e9efcff
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 0 deletions.
105 changes: 105 additions & 0 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Math function mappings to Metal intrinsics

using Base: FastMath
using Base.Math: throw_complex_domainerror

# TODO:
# - wrap all intrinsics from include/metal/metal_math
Expand Down Expand Up @@ -101,6 +102,110 @@ using Base: FastMath
@device_override Base.log10(x::Float32) = ccall("extern air.log10.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.log10(x::Float16) = ccall("extern air.log10.f16", llvmcall, Float16, (Float16,), x)

# Implementation of `log1p(::Float32)` from openlibm's `log1pf`
# https://github.com/JuliaMath/openlibm
const ln2_hi = 0.6931381f0
const ln2_lo = 9.058001f-6
const Lp1 = 0.6666667f0
const Lp2 = 0.4f0
const Lp3 = 0.2857143f0
const Lp4 = 0.22222199f0
const Lp5 = 0.18183573f0
const Lp6 = 0.15313838f0
const Lp7 = 0.14798199f0

@device_override function Base.Math.log1p(x::Float32)
hx = reinterpret(Int32, x)
ax = hx & 0x7fffffff # |x|

k = 1
if hx < 0x3ed413d0 # x < sqrt(2) - 1
if ax >= 0x3f800000 # |x| ≥ 1
if x == -1
return -Inf32
elseif isnan(x)
return NaN32
else # x < -1
# TODO: switch to throw_complex_domainerror_neg1 for next Julia release
throw_complex_domainerror(:log1p, x)
end
end

if ax < 0x38000000 # |x| < 2^-15
if ax < 0x33800000 # |x| < 2^-24
return x # Inexact
else
return x - x*x*0.5f0
end
end

if hx>0||hx<=reinterpret(Int32, 0xbe95f619) # (sqrt(2)/2)-1 <= x
k = 0
f = x
hu = 1f0
end
end # hx < 0x3ed413d0

if hx >= 0x7f800000
return x+x
end

if k 0
if hx < 0x5a000000
u = 1f0 + x
hu = reinterpret(Int32, u)
k = (hu>>23) - 127
c = k>0 ? 1f0-(u-x) : x-(u-1f0)
c /= u
else
u = x
hu = reinterpret(Int32, u)
k = (hu>>23) - 127
c = 0f0
end

hu &= 0x007fffff

if hu < 0x3504f4 # u < sqrt(2)
u = reinterpret(Float32, hu|0x3f800000)
else
k += 1
u = reinterpret(Float32, hu|0x3f000000)
hu = (0x00800000-hu)>>2
end
f = u-1f0
end

hfsq = 0.5f0*f*f

if hu == 0 # |f| < 2^-20
if f == 0
if k == 0
return 0f0
else
c += k*ln2_lo
return k*ln2_hi+c
end
end
R = hfsq*(1f0-Lp1*f)
if k == 0
return f-R
else
return k*ln2_hi - ((R-(k*ln2_lo+c))-f)
end
end

s = f/(2f0+f)
z = s*s
R = z*(Lp1+z*(Lp2+z*(Lp3+z*(Lp4+z*(Lp5+z*(Lp6+z*Lp7))))))
if k == 0
return f-(hfsq-s*(hfsq+R))
else
return k*ln2_hi-((hfsq-(s*(hfsq+R)+(k*ln2_lo+c)))-f)
end
end


@device_override FastMath.pow_fast(x::Float32, y::Float32) = ccall("extern air.fast_pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float32, y::Float32) = ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.:(^)(x::Float16, y::Float16) = ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)
Expand Down
5 changes: 5 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,11 @@ end
synchronize()
@test vecA sin.(a)
@test vecB cos.(a)

b = collect(LinRange(nextfloat(-1f0), 10f0, 20))
bufferC = MtlArray(b)
vecC = Array(log1p.(bufferC))
@test vecC log1p.(b)
end

############################################################################################
Expand Down

0 comments on commit e9efcff

Please sign in to comment.