Skip to content

Commit

Permalink
Enable stride/strides to work on a ReinterpretArray with complex Stri…
Browse files Browse the repository at this point in the history
…ded parents. (JuliaLang#37414)
  • Loading branch information
BioTurboNick committed Sep 9, 2020
1 parent 18198b1 commit 01b29ec
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
13 changes: 12 additions & 1 deletion base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}
# a tuple containing 1 and a cumulative product of the first N-1 sizes
# this definition is also used for StridedReshapedArray and StridedReinterpretedArray
# which have the same memory storage as Array
function stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, i::Int)
stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, i::Int) = _stride(a, i)

function stride(a::ReinterpretArray, i::Int)
a.parent isa StridedArray || ArgumentError("Parent must be strided.") |> throw
return _stride(a, i)
end

function _stride(a, i)
if i > ndims(a)
return length(a)
end
Expand All @@ -73,6 +80,10 @@ function stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray
return s
end

function strides(a::ReinterpretArray)
a.parent isa StridedArray || ArgumentError("Parent must be strided.") |> throw
size_to_strides(1, size(a)...)
end
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)

function check_readable(a::ReinterpretArray{T, N, S} where N) where {T,S}
Expand Down
8 changes: 8 additions & 0 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ let A = collect(reshape(1:20, 5, 4))
@test reshape(R, :) isa StridedArray
end

# and ensure a reinterpret array containing a strided array can have strides computed
let A = view(reinterpret(Int16, collect(reshape(UnitRange{Int64}(1, 20), 5, 4))), :, 1:2)
R = reinterpret(Int32, A)
@test strides(R) == (1, 10)
@test stride(R, 1) == 1
@test stride(R, 2) == 10
end

@testset "strides" begin
a = rand(10)
b = view(a,2:2:10)
Expand Down

0 comments on commit 01b29ec

Please sign in to comment.