Skip to content

Commit

Permalink
SubArray: avoid invalid elimination of singleton indices (JuliaLang#5…
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Feb 8, 2024
1 parent 2bd4cf8 commit 4d0a469
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 9 deletions.
14 changes: 8 additions & 6 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,11 @@ _maybe_reshape_parent(A::AbstractArray, ::NTuple{1, Bool}) = reshape(A, Val(1))
_maybe_reshape_parent(A::AbstractArray{<:Any,1}, ::NTuple{1, Bool}) = reshape(A, Val(1))
_maybe_reshape_parent(A::AbstractArray{<:Any,N}, ::NTuple{N, Bool}) where {N} = A
_maybe_reshape_parent(A::AbstractArray, ::NTuple{N, Bool}) where {N} = reshape(A, Val(N))
# The trailing singleton indices could be eliminated after bounds checking.
rm_singleton_indices(ndims::Tuple, J1, Js...) = (J1, rm_singleton_indices(IteratorsMD._splitrest(ndims, index_ndims(J1)), Js...)...)
rm_singleton_indices(::Tuple{}, ::ScalarIndex, Js...) = rm_singleton_indices((), Js...)
rm_singleton_indices(::Tuple) = ()

"""
view(A, inds...)
Expand Down Expand Up @@ -200,15 +205,12 @@ julia> view(2:5, 2:3) # returns a range as type is immutable
3:4
```
"""
function view(A::AbstractArray{<:Any,N}, I::Vararg{Any,M}) where {N,M}
function view(A::AbstractArray, I::Vararg{Any,M}) where {M}
@inline
J = map(i->unalias(A,i), to_indices(A, I))
@boundscheck checkbounds(A, J...)
if length(J) > ndims(A) && J[N+1:end] isa Tuple{Vararg{Int}}
# view([1,2,3], :, 1) does not need to reshape
return unsafe_view(A, J[1:N]...)
end
unsafe_view(_maybe_reshape_parent(A, index_ndims(J...)), J...)
J′ = rm_singleton_indices(ntuple(Returns(true), Val(ndims(A))), J...)
unsafe_view(_maybe_reshape_parent(A, index_ndims(J′...)), J′...)
end

# Ranges implement getindex to return recomputed ranges; use that for views, too (when possible)
Expand Down
13 changes: 10 additions & 3 deletions test/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -923,9 +923,9 @@ end

@testset "issue #41221: view(::Vector, :, 1)" begin
v = randn(3)
@test view(v,:,1) == v
@test parent(view(v,:,1)) === v
@test parent(view(v,2:3,1,1)) === v
@test @inferred(view(v,:,1)) == v
@test parent(@inferred(view(v,:,1))) === v
@test parent(@inferred(view(v,2:3,1,1))) === v
@test_throws BoundsError view(v,:,2)
@test_throws BoundsError view(v,:,1,2)

Expand All @@ -934,6 +934,13 @@ end
@test parent(view(m, 1:2, 3, 1, 1)) === m
end

@testset "issue #53209: avoid invalid elimination of singleton indices" begin
A = randn(4,5)
@test A[CartesianIndices(()), :, 3] == @inferred(view(A, CartesianIndices(()), :, 3))
@test parent(@inferred(view(A, :, 3, 1, CartesianIndices(()), 1))) === A
@test_throws BoundsError view(A, :, 3, 2, CartesianIndices(()), 1)
end

@testset "replace_in_print_matrix" begin
struct MyIdentity <: AbstractMatrix{Bool}
n :: Int
Expand Down

0 comments on commit 4d0a469

Please sign in to comment.