Skip to content

Commit

Permalink
Emulated fma (#42783)
Browse files Browse the repository at this point in the history
Emulated `fma` for cases when hardware fma is not available. Generally pre-Haswell, some arm, etc.

Co-authored-by: oscarddssmith <[email protected]>
  • Loading branch information
oscardssmith and oscardssmith committed Nov 4, 2021
1 parent e7df4a6 commit ee36c13
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 8 deletions.
74 changes: 66 additions & 8 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,30 +342,88 @@ significantly more expensive than `x*y+z`. `fma` is used to improve accuracy in
algorithms. See [`muladd`](@ref).
"""
function fma end
"""
    fma_emulated(a::Float32, b::Float32, c::Float32) -> Float32

Software fused multiply-add for `Float32`, carried out in `Float64`.

The product of two `Float32` values has at most 48 significant bits, so
`Float64(a) * b` is exact. The only rounding that can go wrong is the final
`Float64 -> Float32` conversion when the `Float64` sum sits exactly halfway
between two `Float32` values; that case is detected and nudged by hand.
"""
function fma_emulated(a::Float32, b::Float32, c::Float32)::Float32
    product = Float64(a) * b
    total = product + c
    # The low 29 bits of the Float64 pattern are the ones dropped when narrowing
    # to Float32. Unless they read exactly 1000…0 (a tie), plain conversion by
    # the declared return type is all that is needed.
    if (reinterpret(UInt64, total) & 0x1fff_ffff) != 0x1000_0000
        return total
    end
    # Error compensation is unfortunately required here: recover what was
    # rounded away in `product + c`, subtracting the larger operand first.
    err = if abs(c) > abs(product)
        product - (total - c)
    else
        c - (total - product)
    end
    # A zero error means the tie is genuine; let the conversion resolve it.
    iszero(err) && return total
    # Otherwise step one ulp toward the true sum so the conversion is no longer a tie.
    return signbit(err) ? prevfloat(total) : nextfloat(total)
end

""" Splits a Float64 into a hi bit and a low bit where the high bit has 27 trailing 0s and the low bit has 26 trailing 0s"""
@inline function splitbits(x::Float64)
    # Clear the 27 least-significant bits of the raw bit pattern to form the
    # high part; sign, exponent and the upper mantissa bits are kept.
    truncated = reinterpret(UInt64, x) & 0xffff_ffff_f800_0000
    hi = reinterpret(Float64, truncated)
    # The low part is whatever is left of `x` once the high part is removed.
    lo = x - hi
    return hi, lo
end

"""
    twomul(a::Float64, b::Float64) -> (abhi, ablo)

Return the rounded product `abhi = a*b` together with a low-order term `ablo`
built from the `splitbits` pieces of `a` and `b` (presumably so that
`abhi + ablo` represents the product more precisely than `abhi` alone).
"""
@inline function twomul(a::Float64, b::Float64)
    a_top, a_bot = splitbits(a)
    b_top, b_bot = splitbits(b)
    # The low half of `b` is split once more before it is multiplied.
    b_bot_top, b_bot_bot = splitbits(b_bot)
    product = a*b
    # Peel the partial products off the rounded product one at a time. The
    # grouping is deliberate: floating-point arithmetic is not associative.
    residual = ((product - a_top*b_top) - a_bot*b_top) - a_top*b_bot
    correction = (a_bot*b_bot_top - residual) + b_bot_bot*a_bot
    return product, correction
end

# Fused multiply-add obtained by calling the C routines `fmaf` (Float32) and
# `fma` (Float64) from the library named by `libm_name`.
function fma_libm(x::Float32, y::Float32, z::Float32)
    return ccall(("fmaf", libm_name), Float32, (Float32,Float32,Float32), x, y, z)
end
function fma_libm(x::Float64, y::Float64, z::Float64)
    return ccall(("fma", libm_name), Float64, (Float64,Float64,Float64), x, y, z)
end
"""
    fma_emulated(a::Float64, b::Float64, c::Float64)

Software fused multiply-add for `Float64`, built on `twomul`.

The common case adds `c` to the two-part product returned by `twomul` using a
compensated sum. Inputs that are non-finite, zero, subnormal, or whose product
is very small or overflows are handled on a separate path that rescales `a` and
`b` by their exponents before multiplying and undoes the scaling with `ldexp`.
"""
function fma_emulated(a::Float64, b::Float64, c::Float64)
    abhi, ablo = twomul(a,b)
    if !isfinite(abhi+c) || isless(abs(abhi), nextfloat(0x1p-969)) || issubnormal(a) || issubnormal(b)
        aandbfinite = isfinite(a) && isfinite(b)
        if !(aandbfinite && isfinite(c))
            # With finite `a` and `b` the exact product is finite, so a non-finite
            # `c` is the answer. Returning `abhi+c` here would be wrong when the
            # rounded product overflowed to the opposite infinity (Inf + -Inf = NaN).
            # If `a` or `b` is itself non-finite, ordinary arithmetic is correct.
            return aandbfinite ? c : abhi+c
        end
        (iszero(a) || iszero(b)) && return abhi+c
        # Combined exponent of the operands; `c` is scaled down by the same amount
        # so that the rescaled computation below stays consistent.
        bias = exponent(a) + exponent(b)
        c_denorm = ldexp(c, -bias)
        if isfinite(c_denorm)
            # rescale a and b to [1,2) in magnitude (sign kept), equivalent to
            # ldexp(a, -exponent(a)); subnormals are first scaled into the
            # normal range so their mantissa bits are in the expected place.
            issubnormal(a) && (a *= 0x1p52)
            issubnormal(b) && (b *= 0x1p52)
            a = reinterpret(Float64, (reinterpret(UInt64, a) & 0x800fffffffffffff) | 0x3ff0000000000000)
            b = reinterpret(Float64, (reinterpret(UInt64, b) & 0x800fffffffffffff) | 0x3ff0000000000000)
            c = c_denorm
            abhi, ablo = twomul(a,b)
            # Compensated sum: subtract the larger-magnitude operand first.
            r = abhi+c
            s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
            sumhi = r+s
            # If result is subnormal, ldexp will cause double rounding because subnormals have fewer mantissa bits.
            # As such, we need to check whether round to even would lead to double rounding and manually round sumhi to avoid it.
            if issubnormal(ldexp(sumhi, bias))
                sumlo = r-sumhi+s
                bits_lost = -bias-exponent(sumhi)-1022
                sumhiInt = reinterpret(UInt64, sumhi)
                # Exactly one of the two conditions must hold (exclusive or).
                if xor(bits_lost != 1, (sumhiInt & 1) == 1)
                    # Move one ulp in the direction of the residual's sign.
                    sumhi = nextfloat(sumhi, cmp(sumlo,0))
                end
            end
            return ldexp(sumhi, bias)
        end
        # `c` could not be rescaled: if the product already overflowed with the
        # same sign as `c`, the result is that infinity.
        isinf(abhi) && signbit(c) == signbit(a*b) && return abhi
        # fall through
    end
    # Compensated sum of the two-part product and `c`.
    r = abhi+c
    s = (abs(abhi) > abs(c)) ? (abhi-r+c+ablo) : (c-r+abhi+ablo)
    return r+s
end
# Fused multiply-add delegated to the `fma_float` intrinsic, one method per
# hardware floating-point width handled here.
function fma_llvm(x::Float32, y::Float32, z::Float32)
    return fma_float(x, y, z)
end
function fma_llvm(x::Float64, y::Float64, z::Float64)
    return fma_float(x, y, z)
end
# Disable LLVM's fma if it is incorrect, e.g. because LLVM falls back
# onto a broken system libm; if so, use a software emulated fma
# (`fma_emulated`) in place of openlibm's fma
# 1.0000305f0 = 1 + 1/2^15
# 1.0000000009313226 = 1 + 1/2^30
# If fma_llvm() clobbers the rounding mode, the result of 0.1 + 0.2 will be 0.3
# instead of the properly-rounded 0.30000000000000004; check after calling fma
# TODO actually detect fma in hardware and switch on that.
# Choose the `fma` implementation for Float32/Float64 once, when this file is
# loaded. `fma_llvm` is used only if the architecture is not `:i686`, it gives
# the correctly rounded answer for one Float32 and one Float64 probe, and
# `0.1 + 0.2` still rounds properly after those calls. Otherwise `fma` is
# routed to the software implementation `fma_emulated`.
if (Sys.ARCH !== :i686 && fma_llvm(1.0000305f0, 1.0000305f0, -1.0f0) == 6.103609f-5 &&
    (fma_llvm(1.0000000009313226, 1.0000000009313226, -1.0) ==
     1.8626451500983188e-9) && 0.1 + 0.2 == 0.30000000000000004)
    fma(x::Float32, y::Float32, z::Float32) = fma_llvm(x,y,z)
    fma(x::Float64, y::Float64, z::Float64) = fma_llvm(x,y,z)
else
    # Each method is defined exactly once; defining it twice would only make
    # the second definition overwrite the first.
    fma(x::Float32, y::Float32, z::Float32) = fma_emulated(x,y,z)
    fma(x::Float64, y::Float64, z::Float64) = fma_emulated(x,y,z)
end
# `fma` for Float16: widen the operands to Float32, combine them with `muladd`,
# and narrow the result back to Float16.
function fma(a::Float16, b::Float16, c::Float16)
    # don't use fma if the hardware doesn't have it; `muladd` is free to pick
    # whichever of `a*b+c` or a fused operation is cheaper.
    return Float16(muladd(Float32(a), Float32(b), Float32(c)))
end

# This is necessary at least on 32-bit Intel Linux, since fma_llvm may
Expand Down
24 changes: 24 additions & 0 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1286,3 +1286,27 @@ end
@test_throws MethodError f(x)
end
end

@testset "fma" begin
    # Every check runs against both the dispatching `fma` and the software
    # implementation `Base.fma_emulated`.
    for func in (fma, Base.fma_emulated)
        # Results that differ from the naively rounded `x*y + z`.
        @test func(nextfloat(1.0), nextfloat(1.0), -1.0) === 4.440892098500626e-16
        @test func(nextfloat(1f0), nextfloat(1f0), -1f0) === 2.3841858f-7
        @testset "$T" for T in (Float32, Float64)
            fmax = floatmax(T)
            # Overflowing intermediate product, finite final result.
            @test func(fmax, T(2), -fmax) === fmax
            @test func(fmax, T(1), eps(fmax)) === T(Inf)
            # Infinities, NaN results and signed zeros.
            @test func(T(Inf), T(Inf), T(Inf)) === T(Inf)
            @test isnan_type(T, func(T(Inf), T(1), -T(Inf)))
            @test isnan_type(T, func(T(Inf), T(0), -T(0)))
            @test func(-zero(T), zero(T), -zero(T)) === -zero(T)
            # Random bit patterns: `func` must agree with `fma`. On failure the
            # offending triple is surfaced through the `||` operand.
            for _ in 1:2^18
                bits = rand(Base.uinttype(T), 3)
                a, b, c = reinterpret.(T, bits)
                @test isequal(func(a, b, c), fma(a, b, c)) || (a,b,c)
            end
        end
        @test func(floatmax(Float64), nextfloat(1.0), -floatmax(Float64)) === 3.991680619069439e292
        @test func(floatmax(Float32), nextfloat(1f0), -floatmax(Float32)) === 4.0564817f31
        @test func(1.6341681540852291e308, -2.0, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(a)*c*a == Inf
        @test func(-2.0, 1.6341681540852291e308, floatmax(Float64)) == -1.4706431733081426e308 # case where inv(b)*c*b == Inf
        @test func(-1.9369631f13, 2.1513551f-7, -1.7354427f-24) == -4.1670958f6
    end
end

0 comments on commit ee36c13

Please sign in to comment.