Skip to content

Commit

Permalink
Fix pointer calculation for SubArray with none-dense parent. (Julia…
Browse files Browse the repository at this point in the history
…Lang#51900)

And code clean for `first_index` and `compute_linindex`:
1. call `compute_linindex` directly in `first_index(::SlowSubArray)`.
(There's no need to calculate stride/offset.)
2. remove the uneeded `compute_linindex` dispatch
(`first(x::ScalarIndex) == x`)
  • Loading branch information
N5N3 committed Oct 28, 2023
1 parent f106bd9 commit 0a0bd00
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 42 deletions.
7 changes: 6 additions & 1 deletion base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
function unsafe_convert(::Type{Ptr{S}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {S,T,N,P}
parent = V.parent
p = cconvert(Ptr{T}, parent) # XXX: this should occur in cconvert, the result is not GC-rooted
return Ptr{S}(unsafe_convert(Ptr{T}, p) + (first_index(V)-1)*sizeof(T))
Δmem = if _checkcontiguous(Bool, parent)
(first_index(V) - firstindex(parent)) * elsize(parent)
else
_memory_offset(parent, map(first, V.indices)...)
end
return Ptr{S}(unsafe_convert(Ptr{T}, p) + Δmem)
end

_checkcontiguous(::Type{Bool}, A::AbstractArray) = false
Expand Down
20 changes: 2 additions & 18 deletions base/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,12 +416,8 @@ iscontiguous(A::SubArray) = iscontiguous(typeof(A))
iscontiguous(::Type{<:SubArray}) = false
iscontiguous(::Type{<:FastContiguousSubArray}) = true

first_index(V::FastSubArray) = V.offset1 + V.stride1 # cached for fast linear SubArrays
function first_index(V::SubArray)
P, I = parent(V), V.indices
s1 = compute_stride1(P, I)
s1 + compute_offset1(P, s1, I)
end
first_index(V::FastSubArray) = V.offset1 + V.stride1 * firstindex(V) # cached for fast linear SubArrays
first_index(V::SubArray) = compute_linindex(parent(V), V.indices)

# Computing the first index simply steps through the indices, accumulating the
# sum of index each multiplied by the parent's stride.
Expand All @@ -447,11 +443,6 @@ function compute_linindex(parent, I::NTuple{N,Any}) where N
IP = fill_to_length(axes(parent), OneTo(1), Val(N))
compute_linindex(first(LinearIndices(parent)), 1, IP, I)
end
function compute_linindex(f, s, IP::Tuple, I::Tuple{ScalarIndex, Vararg{Any}})
@inline
Δi = I[1]-first(IP[1])
compute_linindex(f + Δi*s, s*length(IP[1]), tail(IP), tail(I))
end
function compute_linindex(f, s, IP::Tuple, I::Tuple{Any, Vararg{Any}})
@inline
Δi = first(I[1])-first(IP[1])
Expand All @@ -466,13 +457,6 @@ find_extended_inds(::ScalarIndex, I...) = (@inline; find_extended_inds(I...))
find_extended_inds(i1, I...) = (@inline; (i1, find_extended_inds(I...)...))
find_extended_inds() = ()

# cconvert(::Type{<:Ptr}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {T,N,P} = V
function unsafe_convert(::Type{Ptr{S}}, V::SubArray{T,N,P,<:Tuple{Vararg{RangeIndex}}}) where {S,T,N,P}
parent = V.parent
p = cconvert(Ptr{T}, parent) # XXX: this should occur in cconvert, the result is not GC-rooted
return Ptr{S}(unsafe_convert(Ptr{T}, p) + _memory_offset(parent, map(first, V.indices)...))
end

pointer(V::FastSubArray, i::Int) = pointer(V.parent, V.offset1 + V.stride1*i)
pointer(V::FastContiguousSubArray, i::Int) = pointer(V.parent, V.offset1 + i)

Expand Down
56 changes: 33 additions & 23 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1759,47 +1759,50 @@ module IRUtils
include("compiler/irutils.jl")
end

@testset "strides for ReshapedArray" begin
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
function check_pointer_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
# Test pointer via value check.
first(A) === Base.unsafe_load(pointer(A)) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end

@testset "strides for ReshapedArray" begin
# Type-based contiguous Check
a = vec(reinterpret(reshape, Int16, reshape(view(reinterpret(Int32, randn(10)), 2:11), 5, :)))
f(a) = only(strides(a));
@test IRUtils.fully_eliminated(f, Base.typesof(a)) && f(a) == 1
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test check_strides(vec(a))
@test check_pointer_strides(vec(a))
b = view(parent(a), 1:9, 1:10)
@test_throws "Input is not strided." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test check_strides(reshape(a, 3, 4, 5))
@test check_strides(reshape(a, 5, 6, 2))
@test check_pointer_strides(reshape(a, 3, 4, 5))
@test check_pointer_strides(reshape(a, 5, 6, 2))
b = view(parent(a), 60n:-n:1)
@test check_strides(reshape(b, 3, 4, 5))
@test check_strides(reshape(b, 5, 6, 2))
@test check_pointer_strides(reshape(b, 3, 4, 5))
@test check_pointer_strides(reshape(b, 5, 6, 2))
end
# StridedVector like parent
a = randn(10, 10, 10)
b = view(a, 1:10, 1:1, 5:5)
@test check_strides(reshape(b, 2, 5))
@test check_pointer_strides(reshape(b, 2, 5))
# Other StridedArray parent
a = view(randn(10,10), 1:9, 1:10)
@test check_strides(reshape(a,3,3,2,5))
@test check_strides(reshape(a,3,3,5,2))
@test check_strides(reshape(a,9,5,2))
@test check_strides(reshape(a,3,3,10))
@test check_strides(reshape(a,1,3,1,3,1,5,1,2))
@test check_strides(reshape(a,3,3,5,1,1,2,1,1))
@test check_pointer_strides(reshape(a,3,3,2,5))
@test check_pointer_strides(reshape(a,3,3,5,2))
@test check_pointer_strides(reshape(a,9,5,2))
@test check_pointer_strides(reshape(a,3,3,10))
@test check_pointer_strides(reshape(a,1,3,1,3,1,5,1,2))
@test check_pointer_strides(reshape(a,3,3,5,1,1,2,1,1))
@test_throws "Input is not strided." strides(reshape(a,3,6,5))
@test_throws "Input is not strided." strides(reshape(a,3,2,3,5))
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
Expand All @@ -1812,7 +1815,14 @@ end
@test @inferred(strides(a)) == (1, 1, 1)
# Dense parent (but not StridedArray)
A = reinterpret(Int8, reinterpret(reshape, Int16, rand(Int8, 2, 3, 3)))
@test check_strides(reshape(A, 3, 2, 3))
@test check_pointer_strides(reshape(A, 3, 2, 3))
end

@testset "pointer for SubArray with none-dense parent." begin
a = view(Matrix(reshape(0x01:0xc8, 20, :)), 1:2:20, :)
b = reshape(a, 20, :)
@test check_pointer_strides(view(b, 2:11, 1:5))
@test check_pointer_strides(view(b, reshape(2:11, 2, :), 1:5))
end

@testset "stride for 0 dims array #44087" begin
Expand Down
6 changes: 6 additions & 0 deletions test/subarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -800,3 +800,9 @@ end
V = view(OneElVec(6, 2), 1:5)
@test sprint(show, "text/plain", V) == "$(summary(V)):\n\n 1\n\n\n"
end

@testset "Base.first_index for offset indices" begin
a = Vector(1:10)
b = view(a, Base.IdentityUnitRange(4:7))
@test first(b) == a[Base.first_index(b)]
end

0 comments on commit 0a0bd00

Please sign in to comment.