Skip to content

Commit

Permalink
simplify reinterpret array code (#43955)
Browse files Browse the repository at this point in the history
Avoid one of the memcpy calls, when possible.
  • Loading branch information
vtjnash committed Feb 15, 2022
1 parent 3897667 commit 07f0fdb
Showing 1 changed file with 89 additions and 83 deletions.
172 changes: 89 additions & 83 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,11 @@ end

@inline @propagate_inbounds function getindex(a::ReshapedReinterpretArray{T,N,S}, ind::SCartesianIndex2) where {T,N,S}
check_readable(a)
n = sizeof(S) ÷ sizeof(T)
t = Ref{NTuple{n,T}}()
s = Ref{S}(a.parent[ind.j])
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
_memcpy!(tptr, sptr, sizeof(S))
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
return unsafe_load(tptr, ind.i)
end
return t[][ind.i]
end

@inline _memcpy!(dst, src, n) = ccall(:memcpy, Cvoid, (Ptr{UInt8}, Ptr{UInt8}, Csize_t), dst, src, n)
Expand All @@ -386,29 +382,37 @@ end
else
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}()
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
# Optimizations that avoid branches
if sizeof(T) % sizeof(S) == 0
# T is bigger than S and contains an integer number of them
n = sizeof(T) ÷ sizeof(S)
# Optimizations that avoid branches
if sizeof(T) % sizeof(S) == 0
# T is bigger than S and contains an integer number of them
n = sizeof(T) ÷ sizeof(S)
t = Ref{T}()
GC.@preserve t begin
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
for i = 1:n
s[] = a.parent[ind_start + i, tailinds...]
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
s = a.parent[ind_start + i, tailinds...]
unsafe_store!(sptr, s, i)
end
elseif sizeof(S) % sizeof(T) == 0
# S is bigger than T and contains an integer number of them
s[] = a.parent[ind_start + 1, tailinds...]
_memcpy!(tptr, sptr + sidx, sizeof(T))
else
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
end
return t[]
elseif sizeof(S) % sizeof(T) == 0
# S is bigger than T and contains an integer number of them
s = Ref{S}(a.parent[ind_start + 1, tailinds...])
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
return unsafe_load(tptr + sidx)
end
else
i = 1
nbytes_copied = 0
# This is a bit complicated to deal with partial elements
# at both the start and the end. LLVM will fold as appropriate,
# once it knows the data layout
s = Ref{S}()
t = Ref{T}()
GC.@preserve s t begin
sptr = Ptr{S}(unsafe_convert(Ref{S}, s))
tptr = Ptr{T}(unsafe_convert(Ref{T}, t))
while nbytes_copied < sizeof(T)
s[] = a.parent[ind_start + i, tailinds...]
nb = min(sizeof(S) - sidx, sizeof(T)-nbytes_copied)
Expand All @@ -418,8 +422,8 @@ end
i += 1
end
end
return t[]
end
return t[]
end
end

Expand All @@ -435,44 +439,39 @@ end
@boundscheck checkbounds(a, i1, tailinds...)
if sizeof(T) >= sizeof(S)
t = Ref{T}()
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
GC.@preserve t begin
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
if sizeof(T) > sizeof(S)
# Extra dimension in the parent array
n = sizeof(T) ÷ sizeof(S)
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
offset = n * (i1 - firstindex(a))
for i = 1:n
s[] = a.parent[i + offset]
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
s = a.parent[i + offset]
unsafe_store!(sptr, s, i)
end
else
for i = 1:n
s[] = a.parent[i, i1, tailinds...]
_memcpy!(tptr + (i-1)*sizeof(S), sptr, sizeof(S))
s = a.parent[i, i1, tailinds...]
unsafe_store!(sptr, s, i)
end
end
else
# No extra dimension
s[] = a.parent[i1, tailinds...]
_memcpy!(tptr, sptr, sizeof(S))
s = a.parent[i1, tailinds...]
unsafe_store!(sptr, s)
end
end
return t[]
end
# S is bigger than T and contains an integer number of them
n = sizeof(S) ÷ sizeof(T)
t = Ref{NTuple{n,T}}()
# n = sizeof(S) ÷ sizeof(T)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
s[] = a.parent[tailinds...]
_memcpy!(tptr, sptr, sizeof(S))
return unsafe_load(tptr, i1)
end
return t[][i1]
end

@inline @propagate_inbounds function setindex!(a::NonReshapedReinterpretArray{T,0,S}, v) where {T,S}
Expand Down Expand Up @@ -502,12 +501,10 @@ end
@inline @propagate_inbounds function setindex!(a::ReshapedReinterpretArray{T,N,S}, v, ind::SCartesianIndex2) where {T,N,S}
check_writable(a)
v = convert(T, v)::T
t = Ref{T}(v)
s = Ref{S}(a.parent[ind.j])
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
_memcpy!(sptr + (ind.i-1)*sizeof(T), tptr, sizeof(T))
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
unsafe_store!(tptr, v, ind.i)
end
a.parent[ind.j] = s[]
return a
Expand All @@ -526,25 +523,32 @@ end
else
@boundscheck checkbounds(a, i1, tailinds...)
ind_start, sidx = divrem((i1-1)*sizeof(T), sizeof(S))
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
# Optimizations that avoid branches
if sizeof(T) % sizeof(S) == 0
# T is bigger than S and contains an integer number of them
# Optimizations that avoid branches
if sizeof(T) % sizeof(S) == 0
# T is bigger than S and contains an integer number of them
t = Ref{T}(v)
GC.@preserve t begin
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
n = sizeof(T) ÷ sizeof(S)
for i = 0:n-1
_memcpy!(sptr, tptr + i*sizeof(S), sizeof(S))
a.parent[ind_start + i + 1, tailinds...] = s[]
for i = 1:n
s = unsafe_load(sptr, i)
a.parent[ind_start + i, tailinds...] = s
end
elseif sizeof(S) % sizeof(T) == 0
# S is bigger than T and contains an integer number of them
s[] = a.parent[ind_start + 1, tailinds...]
_memcpy!(sptr + sidx, tptr, sizeof(T))
end
elseif sizeof(S) % sizeof(T) == 0
# S is bigger than T and contains an integer number of them
s = Ref{S}(a.parent[ind_start + 1, tailinds...])
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
unsafe_store!(tptr + sidx, v)
a.parent[ind_start + 1, tailinds...] = s[]
else
end
else
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
nbytes_copied = 0
i = 1
# Deal with any partial elements at the start. We'll have to copy in the
Expand Down Expand Up @@ -591,36 +595,38 @@ end
end
end
@boundscheck checkbounds(a, i1, tailinds...)
t = Ref{T}(v)
s = Ref{S}()
GC.@preserve t s begin
tptr = Ptr{UInt8}(unsafe_convert(Ref{T}, t))
sptr = Ptr{UInt8}(unsafe_convert(Ref{S}, s))
if sizeof(T) >= sizeof(S)
if sizeof(T) >= sizeof(S)
t = Ref{T}(v)
GC.@preserve t begin
sptr = Ptr{S}(unsafe_convert(Ref{T}, t))
if sizeof(T) > sizeof(S)
# Extra dimension in the parent array
n = sizeof(T) ÷ sizeof(S)
if isempty(tailinds) && IndexStyle(a.parent) === IndexLinear()
offset = n * (i1 - firstindex(a))
for i = 1:n
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
a.parent[i + offset] = s[]
s = unsafe_load(sptr, i)
a.parent[i + offset] = s
end
else
for i = 1:n
_memcpy!(sptr, tptr + (i-1)*sizeof(S), sizeof(S))
a.parent[i, i1, tailinds...] = s[]
s = unsafe_load(sptr, i)
a.parent[i, i1, tailinds...] = s
end
end
else
else # sizeof(T) == sizeof(S)
# No extra dimension
_memcpy!(sptr, tptr, sizeof(S))
a.parent[i1, tailinds...] = s[]
s = unsafe_load(sptr)
a.parent[i1, tailinds...] = s
end
else
# S is bigger than T and contains an integer number of them
end
else
# S is bigger than T and contains an integer number of them
s = Ref{S}()
GC.@preserve s begin
tptr = Ptr{T}(unsafe_convert(Ref{S}, s))
s[] = a.parent[tailinds...]
_memcpy!(sptr + (i1-1)*sizeof(T), tptr, sizeof(T))
unsafe_store!(tptr, v, i1)
a.parent[tailinds...] = s[]
end
end
Expand Down

0 comments on commit 07f0fdb

Please sign in to comment.