Skip to content

Commit

Permalink
Implement muladd
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett committed Jan 19, 2015
1 parent 6bc53e6 commit e87e3fe
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 2 deletions.
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ export
mod1,
modf,
mod2pi,
muladd,
nextfloat,
nextpow,
nextpow2,
Expand Down
2 changes: 2 additions & 0 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ widen(::Type{Float32}) = Float64

fma(x::Float32, y::Float32, z::Float32) = box(Float32,fma_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z)))
fma(x::Float64, y::Float64, z::Float64) = box(Float64,fma_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z)))
muladd(x::Float32, y::Float32, z::Float32) = box(Float32,muladd_float(unbox(Float32,x),unbox(Float32,y),unbox(Float32,z)))
muladd(x::Float64, y::Float64, z::Float64) = box(Float64,muladd_float(unbox(Float64,x),unbox(Float64,y),unbox(Float64,z)))

# TODO: faster floating point div?
# TODO: faster floating point fld?
Expand Down
3 changes: 3 additions & 0 deletions base/float16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,9 @@ end
function fma(a::Float16, b::Float16, c::Float16)
float16(fma(float32(a), float32(b), float32(c)))
end
function muladd(a::Float16, b::Float16, c::Float16)
float16(muladd(float32(a), float32(b), float32(c)))
end
for op in (:<,:<=,:isless)
@eval ($op)(a::Float16, b::Float16) = ($op)(float32(a), float32(b))
end
Expand Down
2 changes: 2 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ promote_to_super{T<:Number,S<:Number}(::Type{T}, ::Type{S}, ::Type) =
^(x::Number, y::Number) = ^(promote(x,y)...)

fma(x::Number, y::Number, z::Number) = fma(promote(x,y,z)...)
muladd(x::Number, y::Number, z::Number) = muladd(promote(x,y,z)...)

(&)(x::Integer, y::Integer) = (&)(promote(x,y)...)
(|)(x::Integer, y::Integer) = (|)(promote(x,y)...)
Expand Down Expand Up @@ -195,6 +196,7 @@ no_op_err(name, T) = error(name," not defined for ",T)

fma{T<:Number}(x::T, y::T, z::T) = no_op_err("fma", T)
fma(x::Integer, y::Integer, z::Integer) = x*y+z
muladd{T<:Number}(x::T, y::T, z::T) = no_op_err("muladd", T)

(&){T<:Integer}(x::T, y::T) = no_op_err("&", T)
(|){T<:Integer}(x::T, y::T) = no_op_err("|", T)
Expand Down
4 changes: 4 additions & 0 deletions doc/stdlib/math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ Mathematical Operators
``x*y``. On some systems this is significantly more expensive than
``x*y+z``.

.. function:: muladd(x, y, z)

Combined multiply-add, computes ``x*y+z`` in an efficient manner.

.. function:: div(x, y)
÷(x, y)

Expand Down
18 changes: 16 additions & 2 deletions src/intrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace JL_I {
neg_int, add_int, sub_int, mul_int,
sdiv_int, udiv_int, srem_int, urem_int, smod_int,
neg_float, add_float, sub_float, mul_float, div_float, rem_float,
fma_float,
fma_float, muladd_float,
// fast arithmetic
neg_float_fast, add_float_fast, sub_float_fast,
mul_float_fast, div_float_fast, rem_float_fast,
Expand Down Expand Up @@ -986,6 +986,20 @@ static Value *emit_intrinsic(intrinsic f, jl_value_t **args, size_t nargs,
ArrayRef<Type*>(x->getType())),
FP(x), FP(y), FP(z));
}
HANDLE(muladd_float,3)
#ifdef LLVM34
{
assert(y->getType() == x->getType());
assert(z->getType() == y->getType());
return builder.CreateCall3
(Intrinsic::getDeclaration(jl_Module, Intrinsic::fmuladd,
ArrayRef<Type*>(x->getType())),
FP(x), FP(y), FP(z));
}
#else
return math_builder(ctx, true)().
CreateFAdd(builder.CreateFMul(FP(x), FP(y)), FP(z));
#endif

HANDLE(checked_sadd,2)
HANDLE(checked_uadd,2)
Expand Down Expand Up @@ -1323,7 +1337,7 @@ extern "C" void jl_init_intrinsic_functions(void)
ADD_I(sdiv_int); ADD_I(udiv_int); ADD_I(srem_int); ADD_I(urem_int);
ADD_I(smod_int);
ADD_I(neg_float); ADD_I(add_float); ADD_I(sub_float); ADD_I(mul_float);
ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float);
ADD_I(div_float); ADD_I(rem_float); ADD_I(fma_float); ADD_I(muladd_float);
ADD_I(neg_float_fast); ADD_I(add_float_fast); ADD_I(sub_float_fast);
ADD_I(mul_float_fast); ADD_I(div_float_fast); ADD_I(rem_float_fast);
ADD_I(eq_int); ADD_I(ne_int);
Expand Down
29 changes: 29 additions & 0 deletions test/numbers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,35 @@ let eps = 1//BigInt(2)^200, one_eps = 1+eps,
@test fma(one_eps256, one_eps256, -1) == BigFloat(one_eps * one_eps - 1)
end

# muladd

let eps = 1//BigInt(2)^30, one_eps = 1+eps,
eps64 = float64(eps), one_eps64 = float64(one_eps)
@test eps64 == float64(eps)
@test one_eps64 == float64(one_eps)
@test one_eps64 * one_eps64 - 1 != float64(one_eps * one_eps - 1)
@test isapprox(muladd(one_eps64, one_eps64, -1),
float64(one_eps * one_eps - 1))
end

let eps = 1//BigInt(2)^15, one_eps = 1+eps,
eps32 = float32(eps), one_eps32 = float32(one_eps)
@test eps32 == float32(eps)
@test one_eps32 == float32(one_eps)
@test one_eps32 * one_eps32 - 1 != float32(one_eps * one_eps - 1)
@test isapprox(muladd(one_eps32, one_eps32, -1),
float32(one_eps * one_eps - 1))
end

let eps = 1//BigInt(2)^7, one_eps = 1+eps,
eps16 = float16(float32(eps)), one_eps16 = float16(float32(one_eps))
@test eps16 == float16(float32(eps))
@test one_eps16 == float16(float32(one_eps))
@test one_eps16 * one_eps16 - 1 != float16(float32(one_eps * one_eps - 1))
@test isapprox(muladd(one_eps16, one_eps16, -1),
float16(float32(one_eps * one_eps - 1)))
end

# lexing typemin(Int64)
@test (-9223372036854775808)^1 == -9223372036854775808
@test [1 -1 -9223372036854775808] == [1 -1 typemin(Int64)]
Expand Down

0 comments on commit e87e3fe

Please sign in to comment.