Skip to content

Commit

Permalink
A few more #26670 fixes
Browse files Browse the repository at this point in the history
Fixes BigFloat digit rounding (JuliaIO/Formatting.jl#56). I've also tweaked the definitions to make it easier to extend to new number formats.
  • Loading branch information
simonbyrne committed Apr 10, 2018
1 parent 89d2397 commit e396c03
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 39 deletions.
34 changes: 17 additions & 17 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,26 +348,26 @@ trunc(::Type{Integer}, x::Float64) = trunc(Int,x)
trunc(::Type{T}, x::Float16) where {T<:Integer} = trunc(T, Float32(x))

# fallbacks
floor(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,_round(x, RoundDown))
floor(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,round(x, RoundDown))
floor(::Type{T}, x::Float16) where {T<:Integer} = floor(T, Float32(x))
ceil(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,_round(x, RoundUp))
ceil(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,round(x, RoundUp))
ceil(::Type{T}, x::Float16) where {T<:Integer} = ceil(T, Float32(x))
round(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,_round(x, RoundNearest))
round(::Type{T}, x::AbstractFloat) where {T<:Integer} = trunc(T,round(x, RoundNearest))
round(::Type{T}, x::Float16) where {T<:Integer} = round(T, Float32(x))

_round(x::Float64, r::RoundingMode{:ToZero}) = trunc_llvm(x)
_round(x::Float32, r::RoundingMode{:ToZero}) = trunc_llvm(x)
_round(x::Float64, r::RoundingMode{:Down}) = floor_llvm(x)
_round(x::Float32, r::RoundingMode{:Down}) = floor_llvm(x)
_round(x::Float64, r::RoundingMode{:Up}) = ceil_llvm(x)
_round(x::Float32, r::RoundingMode{:Up}) = ceil_llvm(x)
_round(x::Float64, r::RoundingMode{:Nearest}) = rint_llvm(x)
_round(x::Float32, r::RoundingMode{:Nearest}) = rint_llvm(x)
round(x::Float64, r::RoundingMode{:ToZero}) = trunc_llvm(x)
round(x::Float32, r::RoundingMode{:ToZero}) = trunc_llvm(x)
round(x::Float64, r::RoundingMode{:Down}) = floor_llvm(x)
round(x::Float32, r::RoundingMode{:Down}) = floor_llvm(x)
round(x::Float64, r::RoundingMode{:Up}) = ceil_llvm(x)
round(x::Float32, r::RoundingMode{:Up}) = ceil_llvm(x)
round(x::Float64, r::RoundingMode{:Nearest}) = rint_llvm(x)
round(x::Float32, r::RoundingMode{:Nearest}) = rint_llvm(x)

_round(x::Float16, r::RoundingMode{:ToZero}) = Float16(_round(Float32(x), r))
_round(x::Float16, r::RoundingMode{:Down}) = Float16(_round(Float32(x), r))
_round(x::Float16, r::RoundingMode{:Up}) = Float16(_round(Float32(x), r))
_round(x::Float16, r::RoundingMode{:Nearest}) = Float16(_round(Float32(x), r))
round(x::Float16, r::RoundingMode{:ToZero}) = Float16(round(Float32(x), r))
round(x::Float16, r::RoundingMode{:Down}) = Float16(round(Float32(x), r))
round(x::Float16, r::RoundingMode{:Up}) = Float16(round(Float32(x), r))
round(x::Float16, r::RoundingMode{:Nearest}) = Float16(round(Float32(x), r))

## floating point promotions ##
promote_rule(::Type{Float32}, ::Type{Float16}) = Float32
Expand Down Expand Up @@ -660,7 +660,7 @@ for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UIn
end
end
function (::Type{$Ti})(x::$Tf)
if ($(Tf(typemin(Ti))) <= x <= $(Tf(typemax(Ti)))) && (_round(x, RoundToZero) == x)
if ($(Tf(typemin(Ti))) <= x <= $(Tf(typemax(Ti)))) && (round(x, RoundToZero) == x)
return unsafe_trunc($Ti,x)
else
throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
Expand All @@ -681,7 +681,7 @@ for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UIn
end
end
function (::Type{$Ti})(x::$Tf)
if ($(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti)))) && (_round(x, RoundToZero) == x)
if ($(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti)))) && (round(x, RoundToZero) == x)
return unsafe_trunc($Ti,x)
else
throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x))
Expand Down
20 changes: 12 additions & 8 deletions base/floatfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,14 @@ julia> round(357.913; sigdigits=4, base=2)
# Extensions
To extend `round` to new numeric types, it is typically sufficient to define `Base._round(x::NewType, ::RoundingMode)`.
To extend `round` to new numeric types, it is typically sufficient to define `Base.round(x::NewType, r::RoundingMode)`.
"""
round(T::Type, x)

round(::Type{T}, x::AbstractFloat, r::RoundingMode{:ToZero}) where {T<:Integer} = trunc(T, x)
round(::Type{T}, x::AbstractFloat, r::RoundingMode) where {T<:Integer} = trunc(T, _round(x,r))
round(::Type{T}, x::AbstractFloat, r::RoundingMode) where {T<:Integer} = trunc(T, round(x,r))

# NOTE: this relies on the current keyword dispatch behaviour (#9498).
function round(x::Real, r::RoundingMode=RoundNearest;
digits::Union{Nothing,Integer}=nothing, sigdigits::Union{Nothing,Integer}=nothing, base=10)
isfinite(x) || return x
Expand All @@ -130,12 +131,15 @@ trunc(x::Real; kwargs...) = round(x, RoundToZero; kwargs...)
floor(x::Real; kwargs...) = round(x, RoundDown; kwargs...)
ceil(x::Real; kwargs...) = round(x, RoundUp; kwargs...)

_round(x, r::RoundingMode, digits::Nothing, sigdigits::Nothing, base) = _round(x, r)
_round(x::Integer, r::RoundingMode) = x
# avoid recursive calls
round(x::Real, r::RoundingMode) = throw(MethodError(round, (x,r)))
round(x::Integer, r::RoundingMode) = x

_round(x, r::RoundingMode, digits::Nothing, sigdigits::Nothing, base) = round(x, r)

# round x to multiples of 1/invstep
function _round_invstep(x, invstep, r::RoundingMode)
y = _round(x * invstep, r) / invstep
y = round(x * invstep, r) / invstep
if !isfinite(y)
return x
end
Expand All @@ -145,7 +149,7 @@ end
# round x to multiples of step
function _round_step(x, step, r::RoundingMode)
# TODO: use div with rounding mode
y = _round(x / step, r) * step
y = round(x / step, r) * step
if !isfinite(y)
if x > 0
return (r == RoundUp ? oftype(x, Inf) : zero(x))
Expand Down Expand Up @@ -191,12 +195,12 @@ _round(x, r::RoundingMode, digits::Integer, sigdigits::Integer, base) =
throw(ArgumentError("`round` cannot use both `digits` and `sigdigits` arguments."))

# C-style round
function _round(x::AbstractFloat, ::RoundingMode{:NearestTiesAway})
function round(x::AbstractFloat, ::RoundingMode{:NearestTiesAway})
y = trunc(x)
ifelse(x==y,y,trunc(2*x-y))
end
# Java-style round
function _round(x::AbstractFloat, ::RoundingMode{:NearestTiesUp})
function round(x::AbstractFloat, ::RoundingMode{:NearestTiesUp})
y = floor(x)
ifelse(x==y,y,copysign(floor(2*x-y),x))
end
Expand Down
2 changes: 1 addition & 1 deletion base/irrationals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ for op in Symbol[:+, :-, :*, :/, :^]
end
*(x::Bool, y::AbstractIrrational) = ifelse(x, Float64(y), 0.0)

_round(x::Irrational, r::RoundingMode) = _round(float(x), r)
round(x::Irrational, r::RoundingMode) = round(float(x), r)

macro irrational(sym, val, def)
esym = esc(sym)
Expand Down
19 changes: 6 additions & 13 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -827,27 +827,20 @@ function isinteger(x::BigFloat)
return ccall((:mpfr_integer_p, :libmpfr), Int32, (Ref{BigFloat},), x) != 0
end

for f in (:ceil, :floor, :trunc)
for (f,R) in ((:roundeven, :Nearest),
(:ceil, :Up),
(:floor, :Down),
(:trunc, :ToZero),
(:round, :NearestTiesAway))
@eval begin
function ($f)(x::BigFloat)
function round(x::BigFloat, ::RoundingMode{$(QuoteNode(R))})
z = BigFloat()
ccall(($(string(:mpfr_,f)), :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}), z, x)
return z
end
end
end

function round(x::BigFloat)
z = BigFloat()
ccall((:mpfr_rint, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Cint), z, x, ROUNDING_MODE[])
return z
end
function round(x::BigFloat,::RoundingMode{:NearestTiesAway})
z = BigFloat()
ccall((:mpfr_round, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}), z, x)
return z
end

function isinf(x::BigFloat)
return ccall((:mpfr_inf_p, :libmpfr), Int32, (Ref{BigFloat},), x) != 0
end
Expand Down
7 changes: 7 additions & 0 deletions test/rounding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ end
@test round(pi, sigdigits=1) 3.
@test round(pi, sigdigits=3) 3.14
@test round(pi, sigdigits=4, base=2) 3.25
@test round(big(pi)) big"3."
@test round(big(pi), digits=0) big"3."
@test round(big(pi), digits=1) big"3.1"
@test round(big(pi), digits=3, base=2) big"3.125"
@test round(big(pi), sigdigits=1) big"3."
@test round(big(pi), sigdigits=3) big"3.14"
@test round(big(pi), sigdigits=4, base=2) big"3.25"
@test round(10*pi, digits=-1) 30.
@test round(.1, digits=0) == 0.
@test round(-.1, digits=0) == -0.
Expand Down

0 comments on commit e396c03

Please sign in to comment.