Skip to content

Commit

Permalink
faster prod(::Array{BigInt}) (JuliaLang#41014)
Browse files Browse the repository at this point in the history
Use a first pass on the array to compute the size of the result.
  • Loading branch information
rfourquet authored Jun 10, 2021
1 parent ecab430 commit 8f57f88
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
25 changes: 18 additions & 7 deletions base/gmp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ export BigInt
import .Base: *, +, -, /, <, <<, >>, >>>, <=, ==, >, >=, ^, (~), (&), (|), xor, nand, nor,
binomial, cmp, convert, div, divrem, factorial, cld, fld, gcd, gcdx, lcm, mod,
ndigits, promote_rule, rem, show, isqrt, string, powermod,
sum, trailing_zeros, trailing_ones, count_ones, tryparse_internal,
sum, prod, trailing_zeros, trailing_ones, count_ones, tryparse_internal,
bin, oct, dec, hex, isequal, invmod, _prevpow2, _nextpow2, ndigits0zpb,
widen, signed, unsafe_trunc, trunc, iszero, isone, big, flipsign, signbit,
sign, hastypemax, isodd, iseven, digits!, hash, hash_integer
Expand Down Expand Up @@ -635,12 +635,23 @@ end
# n-ary `+` on BigInts: pack all operands into one tuple and defer to `sum`,
# so the whole chain is handled by a single reduction call.
function +(x::BigInt, y::BigInt, rest::BigInt...)
    return sum((x, y, rest...))
end
# Sum an array, or a non-empty tuple, of BigInts.
# The running total starts as a freshly constructed BigInt(0) and is threaded
# through `MPZ.add!` for each element in iteration order; an empty array
# therefore yields that initial zero.
function sum(arr::Union{AbstractArray{BigInt}, Tuple{BigInt, Vararg{BigInt}}})
    total = BigInt(0)
    for term in arr
        total = MPZ.add!(total, term)
    end
    return total
end
# Note: a similar implementation for `prod` won't be efficient:
# 1) the time complexity of the allocations is negligible compared to the multiplications
# 2) assuming arr contains similarly sized BigInts, the multiplications are much more
# performant when doing e.g. ((a1*a2)*(a3*a4))*(...) rather than a1*(a2*(a3*(...))),
# which is exactly what the default implementation of `prod` does, via `mapreduce`
# (which maybe could be slightly optimized for BigInt).

# Product of all elements of `arr`, returned as a new BigInt (1 for an empty
# array). The elements of `arr` are not modified: only the accumulator created
# here is passed as the destination of `MPZ.mul!`.
function prod(arr::AbstractArray{BigInt})
    # compute first the needed number of bits for the result,
    # to avoid re-allocations;
    # GMP will always request n+m limbs for the result in MPZ.mul!,
    # if the arguments have n and m limbs; so we add all the bits
    # taken by the array elements, and add BITS_PER_LIMB to that,
    # to account for the rounding to limbs in MPZ.mul!
    # (BITS_PER_LIMB-1 would typically be enough, to which we add
    # 1 for the initial multiplication by init=1 in foldl)
    nbits = BITS_PER_LIMB
    for x in arr
        # A zero factor decides the result. It also has no limb in use
        # (x.size == 0), so there is nothing meaningful to measure below and
        # its contribution to `nbits` could otherwise come out negative.
        iszero(x) && return zero(BigInt)
        nlimbs = abs(x.size)
        # Limbs are stored least-significant first, so the bit length of `x`
        # is determined by the leading zeros of the *last* limb in use,
        # not of the first one.
        lz = GC.@preserve x leading_zeros(unsafe_load(x.d, nlimbs))
        nbits += nlimbs * BITS_PER_LIMB - lz
    end
    init = BigInt(; nbits)
    MPZ.set_si!(init, 1)
    # fold into the pre-sized accumulator
    return foldl(MPZ.mul!, arr; init)
end

# Factorial of a BigInt. A negative argument short-circuits to BigInt(0);
# every other value is forwarded to `MPZ.fac_ui`.
function factorial(x::BigInt)
    isneg(x) && return BigInt(0)
    return MPZ.fac_ui(x)
end

Expand Down
3 changes: 3 additions & 0 deletions test/gmp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ let a, b
a = rand(1:100, 10000)
b = map(BigInt, a)
@test sum(a) == sum(b)
@test 0 == sum(BigInt[]) isa BigInt
@test prod(b) == foldl(*, b)
@test 1 == prod(BigInt[]) isa BigInt
end

@testset "Iterated arithmetic" begin
Expand Down

0 comments on commit 8f57f88

Please sign in to comment.