Skip to content

Commit

Permalink
Fix eachindex for more than one ReshapedReinterpretArray (JuliaLang#3…
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Oct 13, 2020
1 parent c24a932 commit 819693a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
9 changes: 7 additions & 2 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,12 @@ struct SCartesianIndices2{K,R<:AbstractUnitRange{Int}} <: AbstractMatrix{SCartes
end
SCartesianIndices2{K}(indices2::AbstractUnitRange{Int}) where {K} = (@assert K::Int > 1; SCartesianIndices2{K,typeof(indices2)}(indices2))

eachindex(::IndexSCartesian2{K}, A::AbstractArray) where {K} = SCartesianIndices2{K}(eachindex(IndexLinear(), parent(A)))
eachindex(::IndexSCartesian2{K}, A::ReshapedReinterpretArray) where {K} = SCartesianIndices2{K}(eachindex(IndexLinear(), parent(A)))
@inline function eachindex(style::IndexSCartesian2{K}, A::AbstractArray, B::AbstractArray...) where {K}
iter = eachindex(style, A)
Base._all_match_first(C->eachindex(style, C), iter, B...) || Base.throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
return iter
end

size(iter::SCartesianIndices2{K}) where K = (K, length(iter.indices2))
axes(iter::SCartesianIndices2{K}) where K = (Base.OneTo(K), iter.indices2)
Expand Down Expand Up @@ -265,7 +270,7 @@ function _setindex!(::IndexSCartesian2, A::AbstractArray{T,N}, v, ind::SCartesia
J = _ind2sub(tail(axes(A)), ind.j)
setindex!(A, v, ind.i, J...)
end

eachindex(style::IndexSCartesian2, A::AbstractArray) = eachindex(style, parent(A))

## AbstractArray interface

Expand Down
7 changes: 7 additions & 0 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ Azr = reinterpret(reshape, UInt8, Az)
W = WrapperArray(Azr)
copyto!(W, fill(0x01, 3, 2, 2))
@test all(isequal((0x01, 0x01, 0x01)), Az)
@test eachindex(W, W) == eachindex(W)

# ensure that reinterpret arrays aren't erroneously classified as strided
let A = reshape(1:20, 5, 4)
Expand Down Expand Up @@ -238,6 +239,12 @@ let a = fill(1.0, 5, 3)
r[goodinds...] = -5
@test r[goodinds...] == -5
end

ar = [(1,2), (3,4)]
arr = reinterpret(reshape, Int, ar)
@test @inferred(IndexStyle(arr)) == Base.IndexSCartesian2{2}()
@test @inferred(eachindex(arr)) == Base.SCartesianIndices2{2}(Base.OneTo(2))
@test @inferred(eachindex(arr, arr)) == Base.SCartesianIndices2{2}(Base.OneTo(2))
end

# Error on reinterprets that would expose padding
Expand Down

0 comments on commit 819693a

Please sign in to comment.