Skip to content

Commit

Permalink
Ensure that BitArrays are inferrable for arbitrary Integer (#37082)
Browse files Browse the repository at this point in the history
BitArray operations are frequently performed with partial inference
information (particularly from Pkg but others as well).
Currently the internal operations are a source of invalidation for
common arithmetic and bitwise operators. There seems to be no need
to live with this, as all the inference problems are fixed by
converting `Integer`s to `Int`s. This is done elsewhere already in
BitArray code, so it only seems to be continuing a process that is
already underway.

Because of the multiplicity of methods `BitArray{N}(a::Any)` is not
fully inferrable but there are at least limited victories,
and the real benefits are in the elimination of instabilities
in the internal operations.
  • Loading branch information
timholy committed Aug 26, 2020
1 parent d84cd4d commit b0ab29e
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 42 deletions.
2 changes: 1 addition & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ promote_eltype() = Bottom
promote_eltype(v1, vs...) = promote_type(eltype(v1), promote_eltype(vs...))

#TODO: ERROR CHECK
_cat(catdim::Integer) = Vector{Any}()
_cat(catdim::Int) = Vector{Any}()

typed_vcat(::Type{T}) where {T} = Vector{T}()
typed_hcat(::Type{T}) where {T} = Vector{T}()
Expand Down
107 changes: 66 additions & 41 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ julia> BitArray(undef, (3, 1))
```
"""
BitArray(::UndefInitializer, dims::Integer...) = BitArray(undef, map(Int,dims))
BitArray{N}(::UndefInitializer, dims::Integer...) where {N} = BitArray{N}(undef, map(Int,dims))
BitArray{N}(::UndefInitializer, dims::Integer...) where {N} = BitArray{N}(undef, map(Int,dims))::BitArray{N}
BitArray(::UndefInitializer, dims::NTuple{N,Integer}) where {N} = BitArray{N}(undef, map(Int, dims)...)
BitArray{N}(::UndefInitializer, dims::NTuple{N,Integer}) where {N} = BitArray{N}(undef, map(Int, dims)...)

Expand Down Expand Up @@ -119,11 +119,11 @@ const _msk64 = ~UInt64(0)
@inline _div64(l) = l >> 6
@inline _mod64(l) = l & 63
@inline _blsr(x)= x & (x-1) #zeros the last set bit. Has native instruction on many archs. needed in multidimensional.jl
@inline _msk_end(l::Integer) = _msk64 >>> _mod64(-l)
@inline _msk_end(l::Int) = _msk64 >>> _mod64(-l)
@inline _msk_end(B::BitArray) = _msk_end(length(B))
num_bit_chunks(n::Int) = _div64(n+63)

@inline get_chunks_id(i::Integer) = _div64(Int(i)-1)+1, _mod64(Int(i)-1)
@inline get_chunks_id(i::Int) = _div64(i-1)+1, _mod64(i-1)

function glue_src_bitchunks(src::Vector{UInt64}, k::Int, ks1::Int, msk_s0::UInt64, ls0::Int)
@inbounds begin
Expand All @@ -136,7 +136,7 @@ function glue_src_bitchunks(src::Vector{UInt64}, k::Int, ks1::Int, msk_s0::UInt6
return chunk
end

function copy_chunks!(dest::Vector{UInt64}, pos_d::Integer, src::Vector{UInt64}, pos_s::Integer, numbits::Integer)
function copy_chunks!(dest::Vector{UInt64}, pos_d::Int, src::Vector{UInt64}, pos_s::Int, numbits::Int)
numbits == 0 && return
if dest === src && pos_d > pos_s
return copy_chunks_rtol!(dest, pos_d, pos_s, numbits)
Expand Down Expand Up @@ -192,7 +192,7 @@ function copy_chunks!(dest::Vector{UInt64}, pos_d::Integer, src::Vector{UInt64},
return
end

function copy_chunks_rtol!(chunks::Vector{UInt64}, pos_d::Integer, pos_s::Integer, numbits::Integer)
function copy_chunks_rtol!(chunks::Vector{UInt64}, pos_d::Int, pos_s::Int, numbits::Int)
pos_d == pos_s && return
pos_d < pos_s && return copy_chunks!(chunks, pos_d, chunks, pos_s, numbits)

Expand Down Expand Up @@ -240,7 +240,7 @@ function copy_chunks_rtol!(chunks::Vector{UInt64}, pos_d::Integer, pos_s::Intege
end
end

function fill_chunks!(Bc::Array{UInt64}, x::Bool, pos::Integer, numbits::Integer)
function fill_chunks!(Bc::Array{UInt64}, x::Bool, pos::Int, numbits::Int)
numbits <= 0 && return
k0, l0 = get_chunks_id(pos)
k1, l1 = get_chunks_id(pos+numbits-1)
Expand Down Expand Up @@ -454,11 +454,13 @@ function copyto!(dest::BitArray, src::BitArray)
end

function unsafe_copyto!(dest::BitArray, doffs::Integer, src::Union{BitArray,Array}, soffs::Integer, n::Integer)
copy_to_bitarray_chunks!(dest.chunks, doffs, src, soffs, n)
copy_to_bitarray_chunks!(dest.chunks, Int(doffs), src, Int(soffs), Int(n))
return dest
end

function copyto!(dest::BitArray, doffs::Integer, src::Array, soffs::Integer, n::Integer)
copyto!(dest::BitArray, doffs::Integer, src::Array, soffs::Integer, n::Integer) =
_copyto_int!(dest, Int(doffs), src, Int(soffs), Int(n))
function _copyto_int!(dest::BitArray, doffs::Int, src::Array, soffs::Int, n::Int)
n == 0 && return dest
soffs < 1 && throw(BoundsError(src, soffs))
doffs < 1 && throw(BoundsError(dest, doffs))
Expand Down Expand Up @@ -500,7 +502,7 @@ end

BitArray(A::AbstractArray{<:Any,N}) where {N} = BitArray{N}(A)
function BitArray{N}(A::AbstractArray{T,N}) where N where T
B = BitArray(undef, size(A))
B = BitArray(undef, convert(Dims{N}, size(A)::Dims{N}))
Bc = B.chunks
l = length(B)
l == 0 && return B
Expand All @@ -509,14 +511,14 @@ function BitArray{N}(A::AbstractArray{T,N}) where N where T
for i = 1:length(Bc)-1
c = UInt64(0)
for j = 0:63
c |= (UInt64(convert(Bool, A[ind])) << j)
c |= (UInt64(convert(Bool, A[ind])::Bool) << j)
ind += 1
end
Bc[i] = c
end
c = UInt64(0)
for j = 0:_mod64(l-1)
c |= (UInt64(convert(Bool, A[ind])) << j)
c |= (UInt64(convert(Bool, A[ind])::Bool) << j)
ind += 1
end
Bc[end] = c
Expand All @@ -530,14 +532,14 @@ function BitArray{N}(A::Array{Bool,N}) where N
l = length(B)
l == 0 && return B
copy_to_bitarray_chunks!(Bc, 1, A, 1, l)
return B
return B::BitArray{N}
end

reinterpret(::Type{Bool}, B::BitArray, dims::NTuple{N,Int}) where {N} = reinterpret(B, dims)
reinterpret(B::BitArray, dims::NTuple{N,Int}) where {N} = reshape(B, dims)

if nameof(@__MODULE__) === :Base # avoid method overwrite
(::Type{T})(x::T) where {T<:BitArray} = copy(x)
(::Type{T})(x::T) where {T<:BitArray} = copy(x)::T
BitArray(x::BitArray) = copy(x)
end

Expand Down Expand Up @@ -570,6 +572,7 @@ julia> BitArray(x+y == 3 for x = 1:2 for y = 1:3)
```
"""
BitArray(itr) = gen_bitarray(IteratorSize(itr), itr)
BitArray{N}(itr) where N = gen_bitarrayN(BitArray{N}, IteratorSize(itr), itr)

convert(T::Type{<:BitArray}, a::AbstractArray) = a isa T ? a : T(a)

Expand Down Expand Up @@ -598,6 +601,14 @@ end

gen_bitarray(::IsInfinite, itr) = throw(ArgumentError("infinite-size iterable used in BitArray constructor"))

gen_bitarrayN(::Type{BitVector}, itsz, itr) = gen_bitarray(itsz, itr)
gen_bitarrayN(::Type{BitVector}, itsz::HasShape{1}, itr) = gen_bitarray(itsz, itr)
gen_bitarrayN(::Type{BitArray{N}}, itsz::HasShape{N}, itr) where N = gen_bitarray(itsz, itr)
# The first of these is just for ambiguity resolution
gen_bitarrayN(::Type{BitVector}, itsz::HasShape{N}, itr) where N = throw(DimensionMismatch("cannot create a $T from a $N-dimensional iterator"))
gen_bitarrayN(@nospecialize(T::Type), itsz::HasShape{N}, itr) where N = throw(DimensionMismatch("cannot create a $T from a $N-dimensional iterator"))
gen_bitarrayN(@nospecialize(T::Type), itsz, itr) = throw(DimensionMismatch("cannot create a $T from a generic iterator"))

# The aux functions gen_bitarray_from_itr and fill_bitarray_from_itr! both
# use a Vector{Bool} cache for performance reasons

Expand Down Expand Up @@ -768,7 +779,7 @@ function append!(B::BitVector, items::BitVector)
return B
end

append!(B::BitVector, items) = append!(B, BitArray(items))
append!(B::BitVector, items) = append!(B, BitVector(items))
append!(A::Vector{Bool}, items::BitVector) = append!(A, Array(items))

function prepend!(B::BitVector, items::BitVector)
Expand Down Expand Up @@ -796,7 +807,8 @@ function sizehint!(B::BitVector, sz::Integer)
return B
end

function resize!(B::BitVector, n::Integer)
resize!(B::BitVector, n::Integer) = _resize_int!(B, Int(n))
function _resize_int!(B::BitVector, n::Int)
n0 = length(B)
n == n0 && return B
n >= 0 || throw(BoundsError(B, n))
Expand All @@ -806,7 +818,7 @@ function resize!(B::BitVector, n::Integer)
end
Bc = B.chunks
k0 = length(Bc)
k1 = num_bit_chunks(Int(n))
k1 = num_bit_chunks(n)
if k1 > k0
_growend!(Bc, k1 - k0)
Bc[end] = UInt64(0)
Expand Down Expand Up @@ -872,7 +884,9 @@ function popfirst!(B::BitVector)
return item
end

function insert!(B::BitVector, i::Integer, item)
insert!(B::BitVector, i::Integer, item) = _insert_int!(B, Int(i), item)
function _insert_int!(B::BitVector, i::Int, item)
i = Int(i)
n = length(B)
1 <= i <= n+1 || throw(BoundsError(B, i))
item = convert(Bool, item)
Expand All @@ -899,7 +913,7 @@ function insert!(B::BitVector, i::Integer, item)
B
end

function _deleteat!(B::BitVector, i::Integer)
function _deleteat!(B::BitVector, i::Int)
k, j = get_chunks_id(i)

msk_bef = _msk64 >>> (63 - j)
Expand Down Expand Up @@ -933,6 +947,7 @@ function _deleteat!(B::BitVector, i::Integer)
end

function deleteat!(B::BitVector, i::Integer)
i = Int(i)
n = length(B)
1 <= i <= n || throw(BoundsError(B, i))

Expand Down Expand Up @@ -983,14 +998,14 @@ function deleteat!(B::BitVector, inds)
end
new_l -= 1
if i > q
copy_chunks!(Bc, p, Bc, q, i-q)
copy_chunks!(Bc, p, Bc, Int(q), Int(i-q))
p += i-q
end
q = i+1
y = iterate(inds, s)
end

q <= n && copy_chunks!(Bc, p, Bc, q, n-q+1)
q <= n && copy_chunks!(Bc, p, Bc, Int(q), Int(n-q+1))

delta_k = num_bit_chunks(new_l) - length(Bc)
delta_k < 0 && _deleteend!(Bc, -delta_k)
Expand All @@ -1005,6 +1020,7 @@ function deleteat!(B::BitVector, inds)
end

function splice!(B::BitVector, i::Integer)
i = Int(i)
n = length(B)
1 <= i <= n || throw(BoundsError(B, i))

Expand All @@ -1016,10 +1032,11 @@ end
const _default_bit_splice = BitVector()

function splice!(B::BitVector, r::Union{UnitRange{Int}, Integer}, ins::AbstractArray = _default_bit_splice)
_splice_int!(B, isa(r, UnitRange{Int}) ? r : Int(r), ins)
end
function _splice_int!(B::BitVector, r, ins)
n = length(B)
i_f = first(r)
i_l = last(r)

i_f, i_l = first(r), last(r)
1 <= i_f <= n+1 || throw(BoundsError(B, i_f))
i_l <= n || throw(BoundsError(B, n+1))

Expand Down Expand Up @@ -1146,8 +1163,9 @@ end

# TODO some of this could be optimized

function reverse(A::BitArray; dims::Integer)
nd = ndims(A); d = dims
reverse(A::BitArray; dims::Integer) = _reverse_int(A, Int(dims))
function _reverse_int(A::BitArray, d::Int)
nd = ndims(A)
1 d nd || throw(ArgumentError("dimension $d is not 1 ≤ $d$nd"))
sd = size(A, d)
sd == 1 && return copy(A)
Expand All @@ -1156,7 +1174,7 @@ function reverse(A::BitArray; dims::Integer)

nnd = 0
for i = 1:nd
nnd += Int(size(A,i)==1 || i==d)
nnd += size(A,i)==1 || i==d
end
if nnd == nd
# reverse along the only non-singleton dimension
Expand Down Expand Up @@ -1253,15 +1271,15 @@ function (<<)(B::BitVector, i::UInt)
n = length(B)
i == 0 && return copy(B)
A = falses(n)
i < n && copy_chunks!(A.chunks, 1, B.chunks, i+1, n-i)
i < n && copy_chunks!(A.chunks, 1, B.chunks, Int(i+1), Int(n-i))
return A
end

function (>>>)(B::BitVector, i::UInt)
n = length(B)
i == 0 && return copy(B)
A = falses(n)
i < n && copy_chunks!(A.chunks, i+1, B.chunks, 1, n-i)
i < n && copy_chunks!(A.chunks, Int(i+1), B.chunks, 1, Int(n-i))
return A
end

Expand Down Expand Up @@ -1348,7 +1366,9 @@ details and examples.
"""
(>>>)(B::BitVector, i::Int) = (i >=0 ? B >> unsigned(i) : B << unsigned(-i))

function circshift!(dest::BitVector, src::BitVector, i::Integer)
circshift!(dest::BitVector, src::BitVector, i::Integer) = _circshift_int!(dest, src, Int(i))
function _circshift_int!(dest::BitVector, src::BitVector, i::Int)
i = Int(i)
length(dest) == length(src) || throw(ArgumentError("destination and source should be of same size"))
n = length(dest)
i %= n
Expand Down Expand Up @@ -1400,16 +1420,16 @@ end

# returns the index of the next true element, or nothing if all false
function findnext(B::BitArray, start::Integer)
start = Int(start)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing
unsafe_bitfindnext(B.chunks, Int(start))
unsafe_bitfindnext(B.chunks, start)
end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl

# aux function: same as findnext(~B, start), but performed without temporaries
function findnextnot(B::BitArray, start::Integer)
start = Int(start)
function findnextnot(B::BitArray, start::Int)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing

Expand Down Expand Up @@ -1444,22 +1464,23 @@ findfirstnot(B::BitArray) = findnextnot(B,1)
function findnext(pred::Fix2{<:Union{typeof(isequal),typeof(==)},Bool},
B::BitArray, start::Integer)
v = pred.x
v == false && return findnextnot(B, start)
v == false && return findnextnot(B, Int(start))
v == true && return findnext(B, start)
return nothing
end
#findfirst(B::BitArray, v) = findnext(B, 1, v) ## defined in array.jl

# returns the index of the first element for which the function returns true
function findnext(testf::Function, B::BitArray, start::Integer)
findnext(testf::Function, B::BitArray, start::Integer) = _findnext_int(testf, B, Int(start))
function _findnext_int(testf::Function, B::BitArray, start::Int)
f0::Bool = testf(false)
f1::Bool = testf(true)
!f0 && f1 && return findnext(B, start)
f0 && !f1 && return findnextnot(B, start)

start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing
f0 && f1 && return Int(start)
f0 && f1 && return start
return nothing # last case: !f0 && !f1
end
#findfirst(testf::Function, B::BitArray) = findnext(testf, B, 1) ## defined in array.jl
Expand All @@ -1484,12 +1505,13 @@ end

# returns the index of the previous true element, or nothing if all false
function findprev(B::BitArray, start::Integer)
start = Int(start)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
unsafe_bitfindprev(B.chunks, Int(start))
unsafe_bitfindprev(B.chunks, start)
end

function findprevnot(B::BitArray, start::Integer)
function findprevnot(B::BitArray, start::Int)
start = Int(start)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
Expand Down Expand Up @@ -1518,22 +1540,23 @@ findlastnot(B::BitArray) = findprevnot(B, length(B))
function findprev(pred::Fix2{<:Union{typeof(isequal),typeof(==)},Bool},
B::BitArray, start::Integer)
v = pred.x
v == false && return findprevnot(B, start)
v == false && return findprevnot(B, Int(start))
v == true && return findprev(B, start)
return nothing
end
#findlast(B::BitArray, v) = findprev(B, 1, v) ## defined in array.jl

# returns the index of the previous element for which the function returns true
function findprev(testf::Function, B::BitArray, start::Integer)
findprev(testf::Function, B::BitArray, start::Integer) = _findprev_int(testf, B, Int(start))
function _findprev_int(testf::Function, B::BitArray, start::Int)
f0::Bool = testf(false)
f1::Bool = testf(true)
!f0 && f1 && return findprev(B, start)
f0 && !f1 && return findprevnot(B, start)

start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
f0 && f1 && return Int(start)
f0 && f1 && return start
return nothing # last case: !f0 && !f1
end
#findlast(testf::Function, B::BitArray) = findprev(testf, B, 1) ## defined in array.jl
Expand Down Expand Up @@ -1808,7 +1831,9 @@ function vcat(A::BitMatrix...)
end

# general case, specialized for BitArrays and Integers
function _cat(dims::Integer, X::Union{BitArray, Bool}...)
_cat(dims::Integer, X::Union{BitArray, Bool}...) = _cat(Int(dims)::Int, X...)
function _cat(dims::Int, X::Union{BitArray, Bool}...)
dims = Int(dims)
catdims = dims2cat(dims)
shape = cat_shape(catdims, map(cat_size, X))
A = falses(shape)
Expand Down
4 changes: 4 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ timesofar("utils")
# issue #24062
@test_throws InexactError BitArray([0, 1, 2, 3])
@test_throws MethodError BitArray([0, ""])

# construction with poor inference
f(c) = BitVector(c[1])
@test @inferred(f(AbstractVector[[0,1]])) == [false, true]
end

timesofar("constructors")
Expand Down

0 comments on commit b0ab29e

Please sign in to comment.