From e87e3fe6892a47efd9263aea36a88756dd5b09d6 Mon Sep 17 00:00:00 2001 From: Erik Schnetter Date: Mon, 19 Jan 2015 09:55:09 -0500 Subject: [PATCH] Implement `muladd` --- base/exports.jl | 1 + base/float.jl | 2 ++ base/float16.jl | 3 +++ base/promotion.jl | 2 ++ doc/stdlib/math.rst | 4 ++++ src/intrinsics.cpp | 18 ++++++++++++++++-- test/numbers.jl | 29 +++++++++++++++++++++++++++++ 7 files changed, 57 insertions(+), 2 deletions(-) diff --git a/base/exports.jl b/base/exports.jl index df72a4e05f70a..4b3454829d412 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -408,6 +408,7 @@ export mod1, modf, mod2pi, + muladd, nextfloat, nextpow, nextpow2, diff --git a/base/float.jl b/base/float.jl index 15097c53a7f80..89ca9c9e404c0 100644 --- a/base/float.jl +++ b/base/float.jl @@ -200,6 +200,8 @@ widen(::Type{Float32}) = Float64 fma(x::Float32, y::Float32, z::Float32) = box(Float32,fma_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z))) fma(x::Float64, y::Float64, z::Float64) = box(Float64,fma_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z))) +muladd(x::Float32, y::Float32, z::Float32) = box(Float32,muladd_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z))) +muladd(x::Float64, y::Float64, z::Float64) = box(Float64,muladd_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z))) # TODO: faster floating point div? # TODO: faster floating point fld? diff --git a/base/float16.jl b/base/float16.jl index 62dee7f60da50..9427812a3c995 100644 --- a/base/float16.jl +++ b/base/float16.jl @@ -136,6 +136,9 @@ end function fma(a::Float16, b::Float16, c::Float16) float16(fma(float32(a), float32(b), float32(c))) end +function muladd(a::Float16, b::Float16, c::Float16) + float16(muladd(float32(a), float32(b), float32(c))) +end for op in (:<,:<=,:isless) @eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b)) end diff --git a/base/promotion.jl b/base/promotion.jl index e48df7b58a089..657eb39fa9aaf 100644 --- a/base/promotion.jl +++ b/base/promotion.jl @@ -161,6 +161,7 @@ promote_to_super{T<:Number,S<:Number}(::Type{T}, ::Type{S}, ::Type) = ^(x::Number, y::Number) = ^(promote(x,y)...) fma(x::Number, y::Number, z::Number) = fma(promote(x,y,z)...) +muladd(x::Number, y::Number, z::Number) = muladd(promote(x,y,z)...) (&)(x::Integer, y::Integer) = (&)(promote(x,y)...) (|)(x::Integer, y::Integer) = (|)(promote(x,y)...) @@ -195,6 +196,7 @@ no_op_err(name, T) = error(name," not defined for ",T) fma{T<:Number}(x::T, y::T, z::T) = no_op_err("fma", T) fma(x::Integer, y::Integer, z::Integer) = x*y+z +muladd{T<:Number}(x::T, y::T, z::T) = no_op_err("muladd", T) (&){T<:Integer}(x::T, y::T) = no_op_err("&", T) (|){T<:Integer}(x::T, y::T) = no_op_err("|", T) diff --git a/doc/stdlib/math.rst b/doc/stdlib/math.rst index aea02def55f79..2c788bf5f4d23 100644 --- a/doc/stdlib/math.rst +++ b/doc/stdlib/math.rst @@ -83,6 +83,10 @@ Mathematical Operators ``x*y``. On some systems this is significantly more expensive than ``x*y+z``. +.. function:: muladd(x, y, z) + + Combined multiply-add, computes ``x*y+z`` in an efficient manner. + .. function:: div(x, y) รท(x, y) diff --git a/src/intrinsics.cpp b/src/intrinsics.cpp index a7e0b62ddad3f..6aec85e4ae388 100644 --- a/src/intrinsics.cpp +++ b/src/intrinsics.cpp @@ -6,7 +6,7 @@ namespace JL_I { neg_int, add_int, sub_int, mul_int, sdiv_int, udiv_int, srem_int, urem_int, smod_int, neg_float, add_float, sub_float, mul_float, div_float, rem_float, - fma_float, + fma_float, muladd_float, // fast arithmetic neg_float_fast, add_float_fast, sub_float_fast, mul_float_fast, div_float_fast, rem_float_fast, @@ -986,6 +986,20 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs, ArrayRef(x->getType())), FP(x), FP(y), FP(z)); } + HANDLE(muladd_float,3) +#ifdef LLVM34 + { + assert(y->getType() == x->getType()); + assert(z->getType() == y->getType()); + return builder.CreateCall3 + (Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd, + ArrayRef(x->getType())), + FP(x), FP(y), FP(z)); + } +#else + return math_builder(ctx, true)(). + CreateFAdd(builder.CreateFMul(FP(x), FP(y)), FP(z)); +#endif HANDLE(checked_sadd,2) HANDLE(checked_uadd,2) @@ -1323,7 +1337,7 @@ extern "C" void jl_init_intrinsic_functions(void) ADD_I(sdiv_int); ADD_I(udiv_int); ADD_I(srem_int); ADD_I(urem_int); ADD_I(smod_int); ADD_I(neg_float); ADD_I(add_float); ADD_I(sub_float); ADD_I(mul_float); - ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float); + ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float); ADD_I(muladd_float); ADD_I(neg_float_fast); ADD_I(add_float_fast); ADD_I(sub_float_fast); ADD_I(mul_float_fast); ADD_I(div_float_fast); ADD_I(rem_float_fast); ADD_I(eq_int); ADD_I(ne_int); diff --git a/test/numbers.jl b/test/numbers.jl index 89ec4e44ad8b9..e6d3b26eb96a7 100644 --- a/test/numbers.jl +++ b/test/numbers.jl @@ -105,6 +105,35 @@ let eps = 1//BigInt(2)^200, one_eps = 1+eps, @test fma(one_eps256, one_eps256, -1) == BigFloat(one_eps * one_eps - 1) end +# muladd + +let eps = 1//BigInt(2)^30, one_eps = 1+eps, + eps64 = float64(eps), one_eps64 = float64(one_eps) + @test eps64 == float64(eps) + @test one_eps64 == float64(one_eps) + @test one_eps64 * one_eps64 - 1 != float64(one_eps * one_eps - 1) + @test isapprox(muladd(one_eps64, one_eps64, -1), + float64(one_eps * one_eps - 1)) +end + +let eps = 1//BigInt(2)^15, one_eps = 1+eps, + eps32 = float32(eps), one_eps32 = float32(one_eps) + @test eps32 == float32(eps) + @test one_eps32 == float32(one_eps) + @test one_eps32 * one_eps32 - 1 != float32(one_eps * one_eps - 1) + @test isapprox(muladd(one_eps32, one_eps32, -1), + float32(one_eps * one_eps - 1)) +end + +let eps = 1//BigInt(2)^7, one_eps = 1+eps, + eps16 = float16(float32(eps)), one_eps16 = float16(float32(one_eps)) + @test eps16 == float16(float32(eps)) + @test one_eps16 == float16(float32(one_eps)) + @test one_eps16 * one_eps16 - 1 != float16(float32(one_eps * one_eps - 1)) + @test isapprox(muladd(one_eps16, one_eps16, -1), + float16(float32(one_eps * one_eps - 1))) +end + # lexing typemin(Int64) @test (-9223372036854775808)^1 == -9223372036854775808 @test [1 -1 -9223372036854775808] == [1 -1 typemin(Int64)]