Skip to content

Commit

Permalink
exact BigFloat to IEEE FP conversion in pure Julia (JuliaLang#50691)
Browse files Browse the repository at this point in the history
There's lots of code, but most of it seems like it will be useful in
general. For example, I think I'll use the changes in float.jl and
rounding.jl to improve the JuliaLang#49749 PR. The changes in float.jl could also
be used to refactor float.jl to remove many magic constants.

Benchmarking script:
```julia
using BenchmarkTools
f(::Type{T} = BigFloat, n::Int = 2000) where {T} = rand(T, n)
g!(u, v) = map!(eltype(u), u, v)
@btime g!(u, v) setup=(u = f(Float16); v = f();)
@btime g!(u, v) setup=(u = f(Float32); v = f();)
@btime g!(u, v) setup=(u = f(Float64); v = f();)
```

On master (dc06468):
```
  46.116 μs (0 allocations: 0 bytes)
  38.842 μs (0 allocations: 0 bytes)
  37.039 μs (0 allocations: 0 bytes)
```

With both this commit and JuliaLang#50674 applied:
```
  42.310 μs (0 allocations: 0 bytes)
  42.661 μs (0 allocations: 0 bytes)
  41.608 μs (0 allocations: 0 bytes)
```

So, with this benchmark at least, on an AMD Zen 2 laptop, conversion to
`Float16` is faster, but there's a slowdown for `Float32` and `Float64`.

Fixes JuliaLang#50642 (exact conversion to `Float16`)

Co-authored-by: Oscar Smith <[email protected]>
  • Loading branch information
nsajko and oscardssmith committed Aug 21, 2023
1 parent 61ebaf6 commit ac607dc
Show file tree
Hide file tree
Showing 8 changed files with 445 additions and 29 deletions.
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ include("hashing.jl")
include("rounding.jl")
using .Rounding
include("div.jl")
include("rawbigints.jl")
include("float.jl")
include("twiceprecision.jl")
include("complex.jl")
Expand Down
62 changes: 62 additions & 0 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,68 @@ i.e. the maximum integer value representable by [`exponent_bits(T)`](@ref) bits.
"""
function exponent_raw_max end

"""
    ieee754_exponent_min(T)

Minimum exponent of `T` in the sense of the IEEE 754 definition, computed as
`1 - exponent_max(T)` and returned as an `Int`.
"""
function ieee754_exponent_min(::Type{T}) where {T<:IEEEFloat}
    emax = exponent_max(T)
    return Int(1 - emax)::Int
end

# `exponent_min` for each concrete IEEE type forwards to the IEEE 754 definition.
for F in (:Float16, :Float32, :Float64)
    @eval exponent_min(::Type{$F}) = ieee754_exponent_min($F)
end

# Assemble the raw bit pattern of a value of type `F` from its three fields.
#
# Layout, most significant first: `sign_bit`, then `exponent_bits(F)` bits holding
# `exponent_field`, then `significand_bits(F)` bits holding `significand_field`.
# The fields are OR-ed in as given; each intermediate result is converted to
# `uinttype(F)` by the typed local `ret`.
#
# Always returns a `uinttype(F)`.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, exponent_field::Integer, significand_field::Integer
) where {F<:IEEEFloat}
    T = uinttype(F)
    ret::T = sign_bit
    ret <<= exponent_bits(F)
    ret |= exponent_field
    ret <<= significand_bits(F)
    ret |= significand_field
    # An assignment expression evaluates to its (unconverted) right-hand side, so
    # ending the function on `ret |= ...` would make the return type follow the
    # type of `significand_field`. Return the typed local instead.
    return ret
end

# Bit pattern of ±floatmax(F): exponent field one below `exponent_raw_max(F)`,
# significand field set to `significand_mask(F)`.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:omega}
) where {F<:IEEEFloat}
    largest_finite_exponent = exponent_raw_max(F) - 1
    full_significand = significand_mask(F)
    return ieee754_representation(F, sign_bit, largest_finite_exponent, full_significand)
end

# Bit pattern of a NaN or an infinity: the exponent field is `exponent_raw_max(F)`
# and the caller chooses the significand field.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:nan}
) where {F<:IEEEFloat}
    top_exponent = exponent_raw_max(F)
    return ieee754_representation(F, sign_bit, top_exponent, significand_field)
end

# Bit pattern of a NaN with the default payload: only the most significant bit of
# the significand field is set.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:nan}
) where {F<:IEEEFloat}
    default_payload = one(uinttype(F)) << (significand_bits(F) - 1)
    return ieee754_representation(F, sign_bit, default_payload, Val(:nan))
end

# Bit pattern of an infinity: the NaN-or-infinity encoding with a zero significand
# field.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:inf}
) where {F<:IEEEFloat}
    empty_significand = false
    return ieee754_representation(F, sign_bit, empty_significand, Val(:nan))
end

# Bit pattern of a subnormal number or a zero: the exponent field is zero and the
# caller chooses the significand field.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, significand_field::Integer, ::Val{:subnormal}
) where {F<:IEEEFloat}
    zero_exponent = false
    return ieee754_representation(F, sign_bit, zero_exponent, significand_field)
end

# Bit pattern of a (signed) zero: the subnormal-or-zero encoding with a zero
# significand field.
function ieee754_representation(
    ::Type{F}, sign_bit::Bool, ::Val{:zero}
) where {F<:IEEEFloat}
    empty_significand = false
    return ieee754_representation(F, sign_bit, empty_significand, Val(:subnormal))
end

"""
uabs(x::Integer)
Expand Down
112 changes: 84 additions & 28 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,15 @@ import
cbrt, typemax, typemin, unsafe_trunc, floatmin, floatmax, rounding,
setrounding, maxintfloat, widen, significand, frexp, tryparse, iszero,
isone, big, _string_n, decompose, minmax,
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand
sinpi, cospi, sincospi, tanpi, sind, cosd, tand, asind, acosd, atand,
uinttype, exponent_max, exponent_min, ieee754_representation, significand_mask,
RawBigIntRoundingIncrementHelper, truncated, RawBigInt


using .Base.Libc
import ..Rounding: rounding_raw, setrounding_raw
import ..Rounding:
rounding_raw, setrounding_raw, rounds_to_nearest, rounds_away_from_zero,
tie_breaker_is_to_even, correct_rounding_requires_increment

import ..GMP: ClongMax, CulongMax, CdoubleMax, Limb, libgmp

Expand Down Expand Up @@ -89,6 +93,21 @@ function convert(::Type{RoundingMode}, r::MPFRRoundingMode)
end
end

# `true` exactly when `m` is `MPFRRoundNearest`.
function rounds_to_nearest(m::MPFRRoundingMode)
    return m == MPFRRoundNearest
end
# Whether rounding mode `m` rounds away from zero for a value whose sign bit is
# `sign_bit`.
function rounds_away_from_zero(m::MPFRRoundingMode, sign_bit::Bool)
    m == MPFRRoundToZero && return false
    m == MPFRRoundUp && return !sign_bit
    m == MPFRRoundDown && return sign_bit
    # Every remaining mode is handled as if `m == MPFRRoundFromZero`.
    return true
end
# Every `MPFRRoundingMode` reports ties-to-even as its tie breaker.
function tie_breaker_is_to_even(::MPFRRoundingMode)
    return true
end

# Module-level `Ref` holding the rounding mode that `rounding_raw(BigFloat)` reads
# and `setrounding_raw(BigFloat, r)` writes; starts as `MPFRRoundNearest`.
const ROUNDING_MODE = Ref{MPFRRoundingMode}(MPFRRoundNearest)
# Module-level `Ref{Clong}` initialised to 256.
# NOTE(review): presumably the default `BigFloat` precision in bits — confirm against its users.
const DEFAULT_PRECISION = Ref{Clong}(256)

Expand Down Expand Up @@ -136,6 +155,9 @@ mutable struct BigFloat <: AbstractFloat
end
end

# Number of `Limb`-sized words in the significand buffer `x._d`.
# The rounding mode passed to `div` shouldn't matter here.
function significand_limb_count(x::BigFloat)
    buffer_bytes = sizeof(x._d)
    limb_bytes = sizeof(Limb)
    return div(buffer_bytes, limb_bytes, RoundToZero)
end

# Current raw rounding mode for `BigFloat`, read from `ROUNDING_MODE`.
function rounding_raw(::Type{BigFloat})
    return ROUNDING_MODE[]
end
# Store `r` in `ROUNDING_MODE` and return `r`.
function setrounding_raw(::Type{BigFloat}, r::MPFRRoundingMode)
    ROUNDING_MODE[] = r
    return r
end

Expand Down Expand Up @@ -386,35 +408,69 @@ function (::Type{T})(x::BigFloat) where T<:Integer
trunc(T,x)
end

## BigFloat -> AbstractFloat
# If `x` is NaN and its sign bit differs from that of `y`, return `-x`;
# otherwise return `x` unchanged.
function _cpynansgn(x::AbstractFloat, y::BigFloat)
    if isnan(x) && signbit(x) != signbit(y)
        return -x
    end
    return x
end

# Convert `x` to `Float64` through libmpfr's `mpfr_get_d` under rounding mode `r`,
# then pass the result through `_cpynansgn` together with `x`.
function Float64(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[])
    raw = ccall((:mpfr_get_d,libmpfr), Float64, (Ref{BigFloat}, MPFRRoundingMode), x, r)
    return _cpynansgn(raw, x)
end
# Same conversion, accepting a `RoundingMode` that is first converted to
# `MPFRRoundingMode`.
function Float64(x::BigFloat, r::RoundingMode)
    mpfr_mode = convert(MPFRRoundingMode, r)
    return Float64(x, mpfr_mode)
end

# Convert `x` to `Float32` through libmpfr's `mpfr_get_flt` under rounding mode `r`,
# then pass the result through `_cpynansgn` together with `x`.
function Float32(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[])
    raw = ccall((:mpfr_get_flt,libmpfr), Float32, (Ref{BigFloat}, MPFRRoundingMode), x, r)
    return _cpynansgn(raw, x)
end
# Same conversion, accepting a `RoundingMode` that is first converted to
# `MPFRRoundingMode`.
function Float32(x::BigFloat, r::RoundingMode)
    mpfr_mode = convert(MPFRRoundingMode, r)
    return Float32(x, mpfr_mode)
end

function Float16(x::BigFloat) :: Float16
res = Float32(x)
resi = reinterpret(UInt32, res)
if (resi&0x7fffffff) < 0x38800000 # if Float16(res) is subnormal
#shift so that the mantissa lines up where it would for normal Float16
shift = 113-((resi & 0x7f800000)>>23)
if shift<23
resi |= 0x0080_0000 # set implicit bit
resi >>= shift
# Convert `x` to the floating-point type `T`, rounding according to `rm`.
#
# `rm` may be any rounding-mode object accepted by `rounds_to_nearest`,
# `rounds_away_from_zero` and `correct_rounding_requires_increment`.
#
# The result is assembled as a raw `uinttype(T)` bit pattern with
# `ieee754_representation` and then reinterpreted as `T`.
function to_ieee754(::Type{T}, x::BigFloat, rm) where {T<:AbstractFloat}
    sb = signbit(x)
    is_zero = iszero(x)
    is_inf = isinf(x)
    is_nan = isnan(x)
    is_regular = !is_zero & !is_inf & !is_nan
    # NOTE(review): the `- 1` presumably translates MPFR's exponent convention into
    # the IEEE 754 one — confirm against the MPFR documentation.
    ieee_exp = Int(x.exp) - 1
    ieee_precision = precision(T)
    ieee_exp_max = exponent_max(T)
    ieee_exp_min = exponent_min(T)
    exp_diff = ieee_exp - ieee_exp_min
    is_normal = 0 <= exp_diff
    # For a round-to-nearest mode both flags are false; otherwise exactly one is set,
    # depending on the sign of `x`.
    (rm_is_to_zero, rm_is_from_zero) = if rounds_to_nearest(rm)
        (false, false)
    else
        let from = rounds_away_from_zero(rm, sb)
            (!from, from)
        end
    end::NTuple{2,Bool}
    exp_is_huge_p = ieee_exp_max < ieee_exp
    exp_is_huge_n = signbit(exp_diff + ieee_precision)
    rounds_to_inf = is_regular & exp_is_huge_p & !rm_is_to_zero
    rounds_to_zero = is_regular & exp_is_huge_n & !rm_is_from_zero
    U = uinttype(T)

    ret_u = if is_regular & !rounds_to_inf & !rounds_to_zero
        if !exp_is_huge_p
            # `v` wraps the raw pointer `x.d`; keep `x` rooted while it is read.
            GC.@preserve x begin
                # significand
                v = RawBigInt(x.d, significand_limb_count(x))
                # A negative `exp_diff` leaves fewer significand bits available.
                len = max(ieee_precision + min(exp_diff, 0), 0)::Int
                signif = truncated(U, v, len) & significand_mask(T)

                # round up if necessary
                rh = RawBigIntRoundingIncrementHelper(v, len)
                incr = correct_rounding_requires_increment(rh, rm, sb)

                # exponent
                exp_field = max(exp_diff, 0) + is_normal

                ieee754_representation(T, sb, exp_field, signif) + incr
            end
        else
            # Exponent too large, but the rounding mode rounds toward zero.
            ieee754_representation(T, sb, Val(:omega))
        end
    else
        if is_zero | rounds_to_zero
            ieee754_representation(T, sb, Val(:zero))
        elseif is_inf | rounds_to_inf
            ieee754_representation(T, sb, Val(:inf))
        else
            ieee754_representation(T, sb, Val(:nan))
        end
    end::U

    return reinterpret(T, ret_u)
end

# `BigFloat` -> `Float16`/`Float32`/`Float64` constructors. Each forwards to
# `to_ieee754`; the rounding mode is either an `MPFRRoundingMode` (defaulting to
# `ROUNDING_MODE[]`) or a `RoundingMode`.
for F in (:Float16, :Float32, :Float64)
    @eval begin
        $F(x::BigFloat, r::MPFRRoundingMode=ROUNDING_MODE[]) = to_ieee754($F, x, r)
        $F(x::BigFloat, r::RoundingMode) = to_ieee754($F, x, r)
    end
end

# Promotion involving `BigFloat`, or `BigInt` with any `AbstractFloat`, yields
# `BigFloat`.
function promote_rule(::Type{BigFloat}, ::Type{<:Real})
    return BigFloat
end
function promote_rule(::Type{BigInt}, ::Type{<:AbstractFloat})
    return BigFloat
end
function promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat})
    return BigFloat
end
Expand Down
149 changes: 149 additions & 0 deletions base/rawbigints.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
    RawBigInt{T<:Unsigned}

A run of raw machine words, viewed as one big integer. Words are stored from
least significant to most significant; within a word, bits are in
machine-native order.

Fields: `d` points at the first (least significant) word and `word_count` is the
number of words.
"""
struct RawBigInt{T<:Unsigned}
    d::Ptr{T}
    word_count::Int

    # The only inner constructor: the element type must be given explicitly.
    function RawBigInt{T}(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
        return new{T}(d, word_count)
    end
end

# Convenience constructor that takes the word type from the pointer's element type.
function RawBigInt(d::Ptr{T}, word_count::Int) where {T<:Unsigned}
    return RawBigInt{T}(d, word_count)
end
# Number of words in `x`.
function elem_count(x::RawBigInt, ::Val{:words})
    return x.word_count
end
# Number of bits in the unsigned value `x`: its size in bytes times eight.
function elem_count(x::Unsigned, ::Val{:bits})
    byte_count = sizeof(x)
    return byte_count * 8
end
# Number of bits in one word of a `RawBigInt{T}`.
function word_length(::RawBigInt{T}) where {T}
    return elem_count(zero(T), Val(:bits))
end
# Total number of bits in `x`: bits per word times the number of words.
function elem_count(x::RawBigInt{T}, ::Val{:bits}) where {T}
    bits_per_word = word_length(x)
    n_words = elem_count(x, Val(:words))
    return bits_per_word * n_words
end
# Map the zero-based index `i` in a sequence of `n` elements to the zero-based
# index counted from the other end.
function reversed_index(n::Int, i::Int)
    return n - i - 1
end
# Same, with `n` taken from `elem_count(x, v)`.
function reversed_index(x, i::Int, v::Val)
    total = elem_count(x, v)
    return reversed_index(total, i)::Int
end
# Split the bit index `i` into a (word index, bit index within the word) pair,
# using the word length of `x`.
function split_bit_index(x::RawBigInt, i::Int)
    bits_per_word = word_length(x)
    return divrem(i, bits_per_word, RoundToZero)
end

"""
Return the word of `x` at zero-based index `i`, counting from the least
significant word.
"""
function get_elem(x::RawBigInt, i::Int, ::Val{:words}, ::Val{:ascending})
    # `unsafe_load` uses one-based indexing.
    one_based = i + 1
    return unsafe_load(x.d, one_based)
end

# Descending access: flip the index with `reversed_index`, then defer to the
# ascending method for the same element kind `v`.
function get_elem(x, i::Int, v::Val, ::Val{:descending})
    return get_elem(x, reversed_index(x, i, v), v, Val(:ascending))
end

# Whether the `i`-th word of `x`, indexed in the order given by `v`, is nonzero.
function word_is_nonzero(x::RawBigInt, i::Int, v::Val)
    word = get_elem(x, i, Val(:words), v)
    return !iszero(word)
end

# Return a one-argument predicate `i -> word_is_nonzero(x, i, v)`.
function word_is_nonzero(x::RawBigInt, v::Val)
    return let big = x
        idx -> word_is_nonzero(big, idx, v)
    end
end

"""
Return `true` if any of the `len` least significant words of `x` is nonzero,
`false` otherwise.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:words})
    for idx in 0:(len - 1)
        word_is_nonzero(x, idx, Val(:ascending)) && return true
    end
    return false
end

"""
Return `true` if any of the `len` least significant bits of the word of `x` at
zero-based index `i` is set. A `len` of zero gives `false` without reading the
word.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, i::Int, ::Val{:word})
    iszero(len) && return false
    word = get_elem(x, i, Val(:words), Val(:ascending))
    # Shift the upper bits out so that only the `len` lowest bits remain.
    return !iszero(word << (word_length(x) - len))
end

"""
Return `true` if any of the `len` least significant bits of `x` is set. A
non-positive `len` gives `false`.
"""
function tail_is_nonzero(x::RawBigInt, len::Int, ::Val{:bits})
    0 < len || return false
    full_words, extra_bits = split_bit_index(x, len)
    # Check the partially covered word first, then the fully covered words below it.
    found = tail_is_nonzero(x, extra_bits, full_words, Val(:word)) ||
        tail_is_nonzero(x, full_words, Val(:words))
    return found::Bool
end

"""
Return, as a `Bool`, the bit of `x` at zero-based index `i`, counting from the
least significant bit.
"""
function get_elem(x::Unsigned, i::Int, ::Val{:bits}, ::Val{:ascending})
    shifted = x >>> i
    return shifted % Bool
end

"""
Returns a `Bool` that is the `i`-th (zero-based index) bit of `x`, counting from
the least significant bit. An index outside `0:(elem_count(x, Val(:bits)) - 1)`
yields `false`.
"""
function get_elem(x::RawBigInt, i::Int, ::Val{:bits}, v::Val{:ascending})
    vb = Val(:bits)
    # Chained comparison restored in ASCII (`<=`) after the operator was lost.
    if 0 <= i < elem_count(x, vb)
        word_index, bit_index_in_word = split_bit_index(x, i)
        word = get_elem(x, word_index, Val(:words), v)
        get_elem(word, bit_index_in_word, vb, v)
    else
        false
    end::Bool
end

"""
Returns an integer of type `R`, consisting of the `len` most
significant bits of `x`. A non-positive `len` gives `zero(R)`.
"""
function truncated(::Type{R}, x::RawBigInt, len::Int) where {R<:Integer}
    ret = zero(R)
    if 0 < len
        word_count, bit_count_in_word = split_bit_index(x, len)
        k = word_length(x)
        # Words are read starting from the most significant one.
        vals = (Val(:words), Val(:descending))

        # Whole words covered by `len`.
        for w in 0:(word_count - 1)
            ret <<= k
            word = get_elem(x, w, vals...)
            ret |= R(word)
        end

        # Leftover bits: take them from the top of the next word.
        if !iszero(bit_count_in_word)
            ret <<= bit_count_in_word
            wrd = get_elem(x, word_count, vals...)
            ret |= R(wrd >>> (k - bit_count_in_word))
        end
    end
    ret::R
end

# Precomputed rounding information for truncating the big integer `n` to its
# `trunc_len` most significant bits.
struct RawBigIntRoundingIncrementHelper{T<:Unsigned}
    n::RawBigInt{T}
    trunc_len::Int

    # Bit at descending index `trunc_len - 1` (last bit kept).
    final_bit::Bool
    # Bit at descending index `trunc_len` (first bit dropped).
    round_bit::Bool

    function RawBigIntRoundingIncrementHelper{T}(n::RawBigInt{T}, len::Int) where {T<:Unsigned}
        by_bit = Val(:bits)
        msb_first = Val(:descending)
        last_kept = get_elem(n, len - 1, by_bit, msb_first)
        first_dropped = get_elem(n, len, by_bit, msb_first)
        return new{T}(n, len, last_kept, first_dropped)
    end
end

# Convenience constructor that takes the word type from `n`.
RawBigIntRoundingIncrementHelper(n::RawBigInt{T}, len::Int) where {T<:Unsigned} =
    RawBigIntRoundingIncrementHelper{T}(n, len)

# Query for the final bit: return the stored `final_bit`.
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.FinalBit)
    return h.final_bit
end

# Query for the round bit: return the stored `round_bit`.
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.RoundBit)
    return h.round_bit
end

# Query for the sticky bit: whether any bit below the round bit is set. It is
# computed on demand from the bits of `h.n` that lie past the truncated part and
# the round bit.
function (h::RawBigIntRoundingIncrementHelper)(::Rounding.StickyBit)
    big = h.n
    total_bits = elem_count(big, Val(:bits))
    return tail_is_nonzero(big, total_bits - h.trunc_len - 1, Val(:bits))
end
Loading

0 comments on commit ac607dc

Please sign in to comment.