Skip to content

Commit

Permalink
faster repeat(::AbstractArray ...) (JuliaLang#35944)
Browse files Browse the repository at this point in the history
* faster repeat(::AbstractArray ...)

* Update base/abstractarraymath.jl

Co-authored-by: Matt Bauman <[email protected]>

* Update base/abstractarraymath.jl

Co-authored-by: Matt Bauman <[email protected]>

* Update base/abstractarraymath.jl

Co-authored-by: Matt Bauman <[email protected]>

* fix

* better repeat errors

Co-authored-by: Matt Bauman <[email protected]>
  • Loading branch information
jw3126 and mbauman committed Jun 15, 2020
1 parent 13b07fc commit 25dc696
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 94 deletions.
211 changes: 117 additions & 94 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -270,30 +270,8 @@ julia> repeat([1, 2, 3], 2, 3)
3 3 3
```
"""
repeat(a::AbstractArray, counts::Integer...) = repeat(a, outer = counts)

function repeat(a::AbstractVecOrMat, m::Integer, n::Integer=1)
o, p = size(a,1), size(a,2)
b = similar(a, o*m, p*n)
for j=1:n
d = (j-1)*p+1
R = d:d+p-1
for i=1:m
c = (i-1)*o+1
b[c:c+o-1, R] = a
end
end
return b
end

function repeat(a::AbstractVector, m::Integer)
o = length(a)
b = similar(a, o*m)
for i=1:m
c = (i-1)*o+1
b[c:c+o-1] = a
end
return b
function repeat(A::AbstractArray, counts...)
return _RepeatInnerOuter.repeat(A, outer=counts)
end

"""
Expand Down Expand Up @@ -330,89 +308,134 @@ julia> repeat([1 2; 3 4], inner=(2, 1), outer=(1, 3))
```
"""
function repeat(A::AbstractArray; inner = nothing, outer = nothing)
return _repeat_inner_outer(A, inner, outer)
return _RepeatInnerOuter.repeat(A, inner=inner, outer=outer)
end

module _RepeatInnerOuter

function repeat(arr; inner=nothing, outer=nothing)
check(arr, inner, outer)
arr, inner, outer = resolve(arr, inner, outer)
repeat_inner_outer(arr, inner, outer)
end

to_tuple(t::Tuple) = t
to_tuple(x::Integer) = (x,)
to_tuple(itr) = tuple(itr...)

function pad(a, b)
N = max(length(a), length(b))
Base.fill_to_length(a, 1, Val(N)), Base.fill_to_length(b, 1, Val(N))
end
function pad(a, b, c)
N = max(max(length(a), length(b)), length(c))
Base.fill_to_length(a, 1, Val(N)), Base.fill_to_length(b, 1, Val(N)), Base.fill_to_length(c, 1, Val(N))
end

function resolve(arr::AbstractArray{<:Any, N}, inner::NTuple{N, Any}, outer::NTuple{N,Any}) where {N}
arr, inner, outer
end
function resolve(arr, inner, outer)
dims, inner, outer = pad(size(arr), to_tuple(inner), to_tuple(outer))
reshape(arr, dims), inner, outer
end
function resolve(arr, inner::Nothing, outer::Nothing)
return arr, inner, outer
end
function resolve(arr, inner::Nothing, outer)
dims, outer = pad(size(arr), to_tuple(outer))
reshape(arr, dims), inner, outer
end
function resolve(arr, inner, outer::Nothing)
dims, inner = pad(size(arr), to_tuple(inner))
reshape(arr, dims), inner, outer
end

# we have optimized implementations of these cases above
_repeat_inner_outer(A::AbstractVecOrMat, ::Nothing, r::Union{Tuple{Integer},Tuple{Integer,Integer}}) = repeat(A, r...)
_repeat_inner_outer(A::AbstractVecOrMat, ::Nothing, r::Integer) = repeat(A, r)

_repeat_inner_outer(A, ::Nothing, ::Nothing) = A
_repeat_inner_outer(A, ::Nothing, outer) = _repeat(A, ntuple(n->1, Val(ndims(A))), rep_kw2tup(outer))
_repeat_inner_outer(A, inner, ::Nothing) = _repeat(A, rep_kw2tup(inner), ntuple(n->1, Val(ndims(A))))
_repeat_inner_outer(A, inner, outer) = _repeat(A, rep_kw2tup(inner), rep_kw2tup(outer))

rep_kw2tup(n::Integer) = (n,)
rep_kw2tup(v::AbstractArray{<:Integer}) = (v...,)
rep_kw2tup(t::Tuple) = t

rep_shapes(A, i, o) = _rshps((), (), size(A), i, o)

_rshps(shp, shp_i, ::Tuple{}, ::Tuple{}, ::Tuple{}) = (shp, shp_i)
@inline _rshps(shp, shp_i, ::Tuple{}, ::Tuple{}, o) =
_rshps((shp..., o[1]), (shp_i..., 1), (), (), tail(o))
@inline _rshps(shp, shp_i, ::Tuple{}, i, ::Tuple{}) = (n = i[1];
_rshps((shp..., n), (shp_i..., n), (), tail(i), ()))
@inline _rshps(shp, shp_i, ::Tuple{}, i, o) = (n = i[1];
_rshps((shp..., n * o[1]), (shp_i..., n), (), tail(i), tail(o)))
@inline _rshps(shp, shp_i, sz, i, o) = (n = sz[1] * i[1];
_rshps((shp..., n * o[1]), (shp_i..., n), tail(sz), tail(i), tail(o)))
_rshps(shp, shp_i, sz, ::Tuple{}, ::Tuple{}) =
(n = length(shp); N = n + length(sz); _reperr("inner", n, N))
_rshps(shp, shp_i, sz, ::Tuple{}, o) =
(n = length(shp); N = n + length(sz); _reperr("inner", n, N))
_rshps(shp, shp_i, sz, i, ::Tuple{}) =
(n = length(shp); N = n + length(sz); _reperr("outer", n, N))
_reperr(s, n, N) = throw(ArgumentError("number of " * s * " repetitions " *
"($n) cannot be less than number of dimensions of input ($N)"))

_negreperr(n) = throw(ArgumentError("number of $n repetitions" *
"cannot be negative"))

@noinline function _repeat(A::AbstractArray, inner, outer)
any(<(0), inner) && _negreperr("inner")
any(<(0), outer) && _negreperr("outer")

shape, inner_shape = rep_shapes(A, inner, outer)

R = similar(A, shape)
if any(iszero, shape)
return R
function check(arr, inner, outer)
if inner !== nothing
# TODO: Currently one based indexing is demanded for inner !== nothing,
# but not for outer !== nothing. Decide for something consistent.
Base.require_one_based_indexing(arr)
if any(<(0), inner)
throw(ArgumentError("no inner repetition count may be negative; got $inner"))
end
if length(inner) < ndims(arr)
throw(ArgumentError("number of inner repetitions ($(length(inner))) cannot be less than number of dimensions of input array ($(ndims(arr)))"))
end
end
if outer !== nothing
if any(<(0), outer)
throw(ArgumentError("no outer repetition count may be negative; got $outer"))
end
if (length(outer) < ndims(arr)) && (inner !== nothing)
throw(ArgumentError("number of outer repetitions ($(length(outer))) cannot be less than number of dimensions of input array ($(ndims(arr)))"))
end
end
end

# fill the first inner block
if all(isequal(1), inner)
idxs = (axes(A)..., ntuple(n->OneTo(1), ndims(R)-ndims(A))...) # keep dimension consistent
R[idxs...] = A
else
inner_indices = [1:n for n in inner]
for c in CartesianIndices(axes(A))
for i in 1:ndims(A)
n = inner[i]
inner_indices[i] = (1:n) .+ ((c[i] - 1) * n)
end
fill!(view(R, inner_indices...), A[c])
repeat_inner_outer(arr, inner::Nothing, outer::Nothing) = arr
repeat_inner_outer(arr, ::Nothing, outer) = repeat_outer(arr, outer)
repeat_inner_outer(arr, inner, ::Nothing) = repeat_inner(arr, inner)
repeat_inner_outer(arr, inner, outer) = repeat_outer(repeat_inner(arr, inner), outer)

function repeat_outer(a::AbstractMatrix, (m,n)::NTuple{2, Any})
o, p = size(a,1), size(a,2)
b = similar(a, o*m, p*n)
for j=1:n
d = (j-1)*p+1
R = d:d+p-1
for i=1:m
c = (i-1)*o+1
@inbounds b[c:c+o-1, R] = a
end
end
return b
end

# fill the outer blocks along each dimension
if all(isequal(1), outer)
return R
function repeat_outer(a::AbstractVector, (m,)::Tuple{Any})
o = length(a)
b = similar(a, o*m)
for i=1:m
c = (i-1)*o+1
@inbounds b[c:c+o-1] = a
end
src_indices = [1:n for n in inner_shape]
dest_indices = copy(src_indices)
for i in eachindex(outer)
B = view(R, src_indices...)
for j in 2:outer[i]
dest_indices[i] = dest_indices[i] .+ inner_shape[i]
R[dest_indices...] = B
return b
end

function repeat_outer(arr::AbstractArray{<:Any,N}, dims::NTuple{N,Any}) where {N}
insize = size(arr)
outsize = map(*, insize, dims)
out = similar(arr, outsize)
for I in CartesianIndices(arr)
for J in CartesianIndices(dims)
TIJ = map(Tuple(I), Tuple(J), insize) do i, j, d
i + d * (j-1)
end
IJ = CartesianIndex(TIJ)
@inbounds out[IJ] = arr[I]
end
src_indices[i] = dest_indices[i] = 1:shape[i]
end
return out
end

return R
function repeat_inner(arr, inner)
basedims = size(arr)
outsize = map(*, size(arr), inner)
out = similar(arr, outsize)
for I in CartesianIndices(arr)
for J in CartesianIndices(inner)
TIJ = map(Tuple(I), Tuple(J), inner) do i, j, d
(i-1) * d + j
end
IJ = CartesianIndex(TIJ)
@inbounds out[IJ] = arr[I]
end
end
return out
end

end#module

"""
eachrow(A::AbstractVecOrMat)
Expand Down
2 changes: 2 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,8 @@ end
3 4], inner=(2, 2), outer=(2,))
@test_throws ArgumentError repeat([1, 2], inner=(1, -1), outer=(1, -1))

@test_throws ArgumentError repeat(OffsetArray(rand(2), 1), inner=(2,))

A = reshape(1:8, 2, 2, 2)
R = repeat(A, inner = (1, 1, 2), outer = (1, 1, 1))
T = reshape([1:4; 1:4; 5:8; 5:8], 2, 2, 4)
Expand Down

0 comments on commit 25dc696

Please sign in to comment.