Skip to content

Commit

Permalink
Make read!(::IO,::StridedArray) call unsafe_read if possible (Jul…
Browse files Browse the repository at this point in the history
…iaLang#42593)

This PR makes `read!(s::IO,a::StridedArray)` call `unsafe_read`, if `a`
has continous memory layout.
Since most of the checks are type-based, I think this does speed up,
esspecially when `lock = true`.

BTW, the above dispatch for `read!(s::IO, a::Array{UInt8})` in L762
seems unnecessary, should we drop it?

---------

Co-authored-by: Jameson Nash <[email protected]>
  • Loading branch information
N5N3 and vtjnash committed Oct 25, 2023
1 parent a1ccf53 commit 1705fe8
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 32 deletions.
103 changes: 71 additions & 32 deletions base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -794,38 +794,43 @@ function write(s::IO, A::AbstractArray)
error("`write` is not supported on non-isbits arrays")
end
nb = 0
r = Ref{eltype(A)}()
for a in A
nb += write(s, a)
r[] = a
nb += @noinline unsafe_write(s, r, sizeof(r)) # r must be heap-allocated
end
return nb
end

function write(s::IO, a::Array)
if isbitstype(eltype(a))
return GC.@preserve a unsafe_write(s, pointer(a), sizeof(a))
else
function write(s::IO, A::StridedArray)
if !isbitstype(eltype(A))
error("`write` is not supported on non-isbits arrays")
end
end

function write(s::IO, a::SubArray{T,N,<:Array}) where {T,N}
if !isbitstype(T) || !isa(a, StridedArray)
return invoke(write, Tuple{IO, AbstractArray}, s, a)
_checkcontiguous(Bool, A) &&
return GC.@preserve A unsafe_write(s, pointer(A), elsize(A) * length(A))
sz::Dims = size(A)
st::Dims = strides(A)
msz, mst, n = merge_adjacent_dim(sz, st)
mst == 1 || return invoke(write, Tuple{IO, AbstractArray}, s, A)
n == ndims(A) &&
return GC.@preserve A unsafe_write(s, pointer(A), elsize(A) * length(A))
sz′, st′ = tail(sz), tail(st)
while n > 1
sz′ = (tail(sz′)..., 1)
st′ = (tail(st′)..., 0)
n -= 1
end
elsz = elsize(a)
colsz = size(a,1) * elsz
GC.@preserve a if stride(a,1) != 1
for idxs in CartesianIndices(size(a))
unsafe_write(s, pointer(a, idxs), elsz)
end
return elsz * length(a)
elseif N <= 1
return unsafe_write(s, pointer(a, 1), colsz)
else
for colstart in CartesianIndices((1, size(a)[2:end]...))
unsafe_write(s, pointer(a, colstart), colsz)
GC.@preserve A begin
nb = 0
iter = CartesianIndices(sz′)
for I in iter
p = pointer(A)
for i in 1:length(sz′)
p += elsize(A) * st′[i] * (I[i] - 1)
end
nb += unsafe_write(s, p, elsize(A) * msz)
end
return colsz * trailingsize(a,2)
return nb
end
end

Expand Down Expand Up @@ -866,20 +871,54 @@ end
read(s::IO, ::Type{Bool}) = (read(s, UInt8) != 0)
read(s::IO, ::Type{Ptr{T}}) where {T} = convert(Ptr{T}, read(s, UInt))

function read!(s::IO, a::Array{UInt8})
GC.@preserve a unsafe_read(s, pointer(a), sizeof(a))
return a
function read!(s::IO, A::AbstractArray{T}) where {T}
if isbitstype(T) && _checkcontiguous(Bool, A)
GC.@preserve A unsafe_read(s, pointer(A), elsize(A) * length(A))
else
if isbitstype(T)
r = Ref{T}()
for i in eachindex(A)
@noinline unsafe_read(s, r, sizeof(r)) # r must be heap-allocated
A[i] = r[]
end
else
for i in eachindex(A)
A[i] = read(s, T)
end
end
end
return A
end

function read!(s::IO, a::AbstractArray{T}) where T
if isbitstype(T) && (a isa Array || a isa FastContiguousSubArray{T,<:Any,<:Array{T}})
GC.@preserve a unsafe_read(s, pointer(a), sizeof(a))
function read!(s::IO, A::StridedArray{T}) where {T}
if !isbitstype(T) || _checkcontiguous(Bool, A)
return invoke(read!, Tuple{IO, AbstractArray}, s, A)
end
sz::Dims = size(A)
st::Dims = strides(A)
msz, mst, n = merge_adjacent_dim(sz, st)
mst == 1 || return invoke(read!, Tuple{IO, AbstractArray}, s, A)
if n == ndims(A)
GC.@preserve A unsafe_read(s, pointer(A), elsize(A) * length(A))
else
for i in eachindex(a)
a[i] = read(s, T)
sz′, st′ = tail(sz), tail(st)
while n > 1
sz′ = (tail(sz′)..., 1)
st′ = (tail(st′)..., 0)
n -= 1
end
GC.@preserve A begin
iter = CartesianIndices(sz′)
for I in iter
p = pointer(A)
for i in 1:length(sz′)
p += elsize(A) * st′[i] * (I[i] - 1)
end
unsafe_read(s, p, elsize(A) * msz)
end
end
end
return a
return A
end

function read(io::IO, ::Type{Char})
Expand Down
18 changes: 18 additions & 0 deletions test/iostream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,24 @@ end
end
end

@testset "read!/write(::IO, A::StridedArray)" begin
s1 = reshape(view(rand(UInt8, 16), 1:16), 2, 2, 2, 2)
s2 = view(s1, 1:2, 1:2, 1:2, 1:2)
s3 = view(s1, 1:2, 1:2, 1, 1:2)
mktemp() do path, io
b = Vector{UInt8}(undef, 17)
for s::StridedArray in (s3, s1, s2)
@test write(io, s) == length(s)
seek(io, 0)
@test readbytes!(io, b) == length(s)
seek(io, 0)
@test view(b, 1:length(s)) == vec(s)
@test read!(io, fill!(deepcopy(s), 0)) == s
seek(io, 0)
end
end
end

@test Base.open_flags(read=false, write=true, append=false) == (read=false, write=true, create=true, truncate=true, append=false)

@testset "issue #30978" begin
Expand Down

0 comments on commit 1705fe8

Please sign in to comment.