Skip to content

Commit

Permalink
Make StridedReinterpretArray's get/setindex pointer based. (Julia…
Browse files Browse the repository at this point in the history
…Lang#44186)

This PR makes `StridedReinterpretArray`'s `get/setindex` purely pointer based if its root storage is a `Array`/`Memory`.
The generated IR would be simpler and (hopefully) easier to optimize.

TODO: LLVM's LV dislikes GC preserved `MemoryRef`, reinterpreted `Array`s might block auto vectorization.

---------

Co-authored-by: Gabriel Baraldi <[email protected]>
Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
3 people committed Nov 8, 2023
1 parent 8f8b9ca commit 1972432
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 109 deletions.
62 changes: 44 additions & 18 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,23 +352,32 @@ has_offset_axes(a::ReinterpretArray) = has_offset_axes(a.parent)
elsize(::Type{<:ReinterpretArray{T}}) where {T} = sizeof(T)
cconvert(::Type{Ptr{T}}, a::ReinterpretArray{T,N,S} where N) where {T,S} = cconvert(Ptr{S}, a.parent)

@inline @propagate_inbounds function getindex(a::NonReshapedReinterpretArray{T,0,S}) where {T,S}
@propagate_inbounds function getindex(a::NonReshapedReinterpretArray{T,0,S}) where {T,S}
if isprimitivetype(T) && isprimitivetype(S)
reinterpret(T, a.parent[])
else
a[firstindex(a)]
end
end

@inline @propagate_inbounds getindex(a::ReinterpretArray) = a[firstindex(a)]
check_ptr_indexable(a::ReinterpretArray, sz = elsize(a)) = check_ptr_indexable(parent(a), sz)
check_ptr_indexable(a::ReshapedArray, sz) = check_ptr_indexable(parent(a), sz)
check_ptr_indexable(a::FastContiguousSubArray, sz) = check_ptr_indexable(parent(a), sz)
check_ptr_indexable(a::Array, sz) = sizeof(eltype(a)) !== sz
check_ptr_indexable(a::Memory, sz) = true
check_ptr_indexable(a::AbstractArray, sz) = false

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
@propagate_inbounds getindex(a::ReinterpretArray) = a[firstindex(a)]

@propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, inds::Vararg{Int, N}) where {T,N,S}
check_readable(a)
check_ptr_indexable(a) && return _getindex_ptr(a, inds...)
_getindex_ra(a, inds[1], tail(inds))
end

@inline @propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, i::Int) where {T,N,S}
@propagate_inbounds function getindex(a::ReinterpretArray{T,N,S}, i::Int) where {T,N,S}
check_readable(a)
check_ptr_indexable(a) && return _getindex_ptr(a, i)
if isa(IndexStyle(a), IndexLinear)
return _getindex_ra(a, i, ())
end
Expand All @@ -378,16 +387,22 @@ end
isempty(inds) ? _getindex_ra(a, 1, ()) : _getindex_ra(a, inds[1], tail(inds))
end

@inline @propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
@propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
check_readable(a)
s = Ref{S}(a.parent[ind.j])
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
return unsafe_load(tptr, ind.i)
end
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
GC.@preserve s return unsafe_load(tptr, ind.i)
end

@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
@inline function _getindex_ptr(a::ReinterpretArray{T}, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
ap = cconvert(Ptr{T}, a)
p = unsafe_convert(Ptr{T}, ap) + sizeof(T) * (li - 1)
GC.@preserve ap return unsafe_load(p)
end

@propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if issingletontype(T) # singleton types
Expand Down Expand Up @@ -443,7 +458,7 @@ end
end
end

@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
@propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if issingletontype(T) # singleton types
Expand Down Expand Up @@ -490,31 +505,33 @@ end
end
end

@inline @propagate_inbounds function setindex!(a::NonReshapedReinterpretArray{T,0,S}, v) where {T,S}
@propagate_inbounds function setindex!(a::NonReshapedReinterpretArray{T,0,S}, v) where {T,S}
if isprimitivetype(S) && isprimitivetype(T)
a.parent[] = reinterpret(S, v)
return a
end
setindex!(a, v, firstindex(a))
end

@inline @propagate_inbounds setindex!(a::ReinterpretArray, v) = setindex!(a, v, firstindex(a))
@propagate_inbounds setindex!(a::ReinterpretArray, v) = setindex!(a, v, firstindex(a))

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
@propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, inds::Vararg{Int, N}) where {T,N,S}
check_writable(a)
check_ptr_indexable(a) && return _setindex_ptr!(a, v, inds...)
_setindex_ra!(a, v, inds[1], tail(inds))
end

@inline @propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, i::Int) where {T,N,S}
@propagate_inbounds function setindex!(a::ReinterpretArray{T,N,S}, v, i::Int) where {T,N,S}
check_writable(a)
check_ptr_indexable(a) && return _setindex_ptr!(a, v, i)
if isa(IndexStyle(a), IndexLinear)
return _setindex_ra!(a, v, i, ())
end
inds = _to_subscript_indices(a, i)
_setindex_ra!(a, v, inds[1], tail(inds))
end

@inline @propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
@propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
check_writable(a)
v = convert(T, v)::T
s = Ref{S}(a.parent[ind.j])
Expand All @@ -526,7 +543,16 @@ end
return a
end

@inline @propagate_inbounds function _setindex_ra!(a::NonReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
@inline function _setindex_ptr!(a::ReinterpretArray{T}, v, inds...) where {T}
@boundscheck checkbounds(a, inds...)
li = _to_linear_index(a, inds...)
ap = cconvert(Ptr{T}, a)
p = unsafe_convert(Ptr{T}, ap) + sizeof(T) * (li - 1)
GC.@preserve ap unsafe_store!(p, v)
return a
end

@propagate_inbounds function _setindex_ra!(a::NonReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
Expand Down Expand Up @@ -599,7 +625,7 @@ end
return a
end

@inline @propagate_inbounds function _setindex_ra!(a::ReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
@propagate_inbounds function _setindex_ra!(a::ReshapedReinterpretArray{T,N,S}, v, i1::Int, tailinds::TT) where {T,N,S,TT}
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
Expand Down
Loading

0 comments on commit 1972432

Please sign in to comment.