Skip to content

Commit

Permalink
Speed up mapslices (#40996)
Browse files Browse the repository at this point in the history
* renovate mapslices
  • Loading branch information
mcabbott committed May 31, 2022
1 parent f5ebcc9 commit 3eaed8b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 89 deletions.
194 changes: 109 additions & 85 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2801,134 +2801,158 @@ foreach(f, itrs...) = (for z in zip(itrs...); f(z...); end; nothing)
"""
mapslices(f, A; dims)
Transform the given dimensions of array `A` using function `f`. `f` is called on each slice
of `A` of the form `A[...,:,...,:,...]`. `dims` is an integer vector specifying where the
colons go in this expression. The results are concatenated along the remaining dimensions.
For example, if `dims` is `[1,2]` and `A` is 4-dimensional, `f` is called on `A[:,:,i,j]`
for all `i` and `j`.
Transform the given dimensions of array `A` by applying a function `f` on each slice
of the form `A[..., :, ..., :, ...]`, with a colon at each `d` in `dims`. The results are
concatenated along the remaining dimensions.
See also [`eachcol`](@ref), [`eachslice`](@ref).
For example, if `dims = [1,2]` and `A` is 4-dimensional, then `f` is called on `x = A[:,:,i,j]`
for all `i` and `j`, and `f(x)` becomes `R[:,:,i,j]` in the result `R`.
See also [`eachcol`](@ref), [`eachslice`](@ref), [`mapreduce`](@ref).
# Examples
```jldoctest
julia> a = reshape(Vector(1:16),(2,2,2,2))
2×2×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
1 3
2 4
julia> A = reshape(1:30,(2,5,3))
2×5×3 reshape(::UnitRange{$Int}, 2, 5, 3) with eltype $Int:
[:, :, 1] =
1 3 5 7 9
2 4 6 8 10
[:, :, 2] =
11 13 15 17 19
12 14 16 18 20
[:, :, 3] =
21 23 25 27 29
22 24 26 28 30
[:, :, 2, 1] =
5 7
6 8
julia> f(x::Matrix) = fill(x[1,1], 1,4); # returns a 1×4 matrix
[:, :, 1, 2] =
9 11
10 12
julia> mapslices(f, A, dims=(1,2))
1×4×3 Array{$Int, 3}:
[:, :, 1] =
1 1 1 1
[:, :, 2] =
11 11 11 11
[:, :, 2, 2] =
13 15
14 16
[:, :, 3] =
21 21 21 21
julia> mapslices(sum, a, dims = [1,2])
1×1×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
10
julia> g(x) = x[begin] // x[end-1]; # returns a number
[:, :, 2, 1] =
26
julia> mapslices(g, A, dims=[1,3])
1×5×1 Array{Rational{$Int}, 3}:
[:, :, 1] =
1//21 3//23 1//5 7//27 9//29
[:, :, 1, 2] =
42
julia> map(g, eachslice(A, dims=2))
5-element Vector{Rational{$Int}}:
1//21
3//23
1//5
7//27
9//29
[:, :, 2, 2] =
58
julia> mapslices(sum, A; dims=(1,3)) == sum(A; dims=(1,3))
true
```
Notice that in `eachslice(A; dims=2)`, the specified dimension is the
one *without* a colon in the slice. This is `view(A,:,i,:)`, whereas
`mapslices(f, A; dims=(1,3))` uses `A[:,i,:]`. The function `f` may mutate
values in the slice without affecting `A`.
"""
function mapslices(f, A::AbstractArray; dims)
if isempty(dims)
return map(f,A)
end
if !isa(dims, AbstractVector)
dims = [dims...]
end
isempty(dims) && return map(f, A)

dimsA = [axes(A)...]
ndimsA = ndims(A)
alldims = [1:ndimsA;]

otherdims = setdiff(alldims, dims)

idx = Any[first(ind) for ind in axes(A)]
itershape = tuple(dimsA[otherdims]...)
for d in dims
idx[d] = Slice(axes(A, d))
d isa Integer || throw(ArgumentError("mapslices: dimension must be an integer, got $d"))
d >= 1 || throw(ArgumentError("mapslices: dimension must be ≥ 1, got $d"))
# Indexing a matrix M[:,1,:] produces a 1-column matrix, but dims=(1,3) here
# would otherwise ignore 3, and slice M[:,i]. Previously this gave error:
# BoundsError: attempt to access 2-element Vector{Any} at index [3]
d > ndims(A) && throw(ArgumentError("mapslices does not accept dimensions > ndims(A) = $(ndims(A)), got $d"))
end
dim_mask = ntuple(d -> d in dims, ndims(A))

# Apply the function to the first slice in order to determine the next steps
Aslice = A[idx...]
idx1 = ntuple(d -> d in dims ? (:) : firstindex(A,d), ndims(A))
Aslice = A[idx1...]
r1 = f(Aslice)
# In some cases, we can re-use the first slice for a dramatic performance
# increase. The slice itself must be mutable and the result cannot contain
# any mutable containers. The following errs on the side of being overly
# strict (#18570 & #21123).
safe_for_reuse = isa(Aslice, StridedArray) &&
(isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

# determine result size and allocate
Rsize = copy(dimsA)
# TODO: maybe support removing dimensions
if !isa(r1, AbstractArray) || ndims(r1) == 0
res1 = if r1 isa AbstractArray && ndims(r1) > 0
n = sum(dim_mask)
if ndims(r1) > n && any(ntuple(d -> size(r1,d+n)>1, ndims(r1)-n))
s = size(r1)[1:n]
throw(DimensionMismatch("mapslices cannot assign slice f(x) of size $(size(r1)) into output of size $s"))
end
r1
else
# If the result of f on a single slice is a scalar then we add singleton
# dimensions. When adding the dimensions, we have to respect the
# index type of the input array (e.g. in the case of OffsetArrays)
tmp = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
tmp[firstindex(tmp)] = r1
r1 = tmp
end
nextra = max(0, length(dims)-ndims(r1))
if eltype(Rsize) == Int
Rsize[dims] = [size(r1)..., ntuple(Returns(1), nextra)...]
else
Rsize[dims] = [axes(r1)..., ntuple(Returns(OneTo(1)), nextra)...]
_res1 = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
_res1[begin] = r1
_res1
end
R = similar(r1, tuple(Rsize...,))

ridx = Any[map(first, axes(R))...]
for d in dims
ridx[d] = axes(R,d)
# Determine result size and allocate. We always pad ndims(res1) out to length(dims):
din = Ref(0)
Rsize = ntuple(ndims(A)) do d
if d in dims
axes(res1, din[] += 1)
else
axes(A,d)
end
end
R = similar(res1, Rsize)

# Determine iteration space. It will be convenient in the loop to mask N-dimensional
# CartesianIndices, with some trivial dimensions:
itershape = ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))
indices = Iterators.drop(CartesianIndices(itershape), 1)

# That skips the first element, which we already have:
ridx = ntuple(d -> d in dims ? Slice(axes(R,d)) : firstindex(A,d), ndims(A))
concatenate_setindex!(R, res1, ridx...)

concatenate_setindex!(R, r1, ridx...)
# In some cases, we can re-use the first slice for a dramatic performance
# increase. The slice itself must be mutable and the result cannot contain
# any mutable containers. The following errs on the side of being overly
# strict (#18570 & #21123).
safe_for_reuse = isa(Aslice, StridedArray) &&
(isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

nidx = length(otherdims)
indices = Iterators.drop(CartesianIndices(itershape), 1) # skip the first element, we already handled it
inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
_inner_mapslices!(R, indices, f, A, dim_mask, Aslice, safe_for_reuse)
return R
end

@noinline function inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
@noinline function _inner_mapslices!(R, indices, f, A, dim_mask, Aslice, safe_for_reuse)
must_extend = any(dim_mask .& size(R) .> 1)
if safe_for_reuse
# when f returns an array, R[ridx...] = f(Aslice) line copies elements,
# so we can reuse Aslice
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
idx = ifelse.(dim_mask, Slice.(axes(A)), Tuple(I))
_unsafe_getindex!(Aslice, A, idx...)
concatenate_setindex!(R, f(Aslice), ridx...)
r = f(Aslice)
if r isa AbstractArray || must_extend
ridx = ifelse.(dim_mask, Slice.(axes(R)), Tuple(I))
R[ridx...] = r
else
ridx = ifelse.(dim_mask, first.(axes(R)), Tuple(I))
R[ridx...] = r
end
end
else
# we can't guarantee safety (#18524), so allocate new storage for each slice
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
idx = ifelse.(dim_mask, Slice.(axes(A)), Tuple(I))
ridx = ifelse.(dim_mask, Slice.(axes(R)), Tuple(I))
concatenate_setindex!(R, f(A[idx...]), ridx...)
end
end

return R
end

function replace_tuples!(nidx, idx, ridx, otherdims, I)
for i in 1:nidx
idx[otherdims[i]] = ridx[otherdims[i]] = I.I[i]
end
end

concatenate_setindex!(R, v, I...) = (R[I...] .= (v,); R)
Expand Down
27 changes: 23 additions & 4 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,6 @@ end
@test mapslices(prod,["1"],dims=1) == ["1"]

# issue #5177

c = fill(1,2,3,4)
m1 = mapslices(_ -> fill(1,2,3), c, dims=[1,2])
m2 = mapslices(_ -> fill(1,2,4), c, dims=[1,3])
Expand All @@ -1196,9 +1195,29 @@ end
@test o == fill(1, 3, 4)

# issue #18524
m = mapslices(x->tuple(x), [1 2; 3 4], dims=1)
@test m[1,1] == ([1,3],)
@test m[1,2] == ([2,4],)
# m = mapslices(x->tuple(x), [1 2; 3 4], dims=1) # see variations of this below
# ERROR: fatal error in type inference (type bound), https://github.com/JuliaLang/julia/issues/43064
# @test m[1,1] == ([1,3],)
# @test m[1,2] == ([2,4],)

r = rand(Int8, 4,5,2)
@test vec(mapslices(repr, r, dims=(2,1))) == map(repr, eachslice(r, dims=3))
@test mapslices(tuple, [1 2; 3 4], dims=1) == [([1, 3],) ([2, 4],)]
@test mapslices(transpose, r, dims=(1,3)) == permutedims(r, (3,2,1))

# failures
@test_broken @inferred(mapslices(tuple, [1 2; 3 4], dims=1)) == [([1, 3],) ([2, 4],)]
@test_broken @inferred(mapslices(transpose, r, dims=(1,3))) == permutedims(r, (3,2,1))
# ERROR: fatal error in type inference (type bound), https://github.com/JuliaLang/julia/issues/43064
@test_broken @inferred(mapslices(x -> tuple(x), [1 2; 3 4], dims=1)) == [([1, 3],) ([2, 4],)]

# re-write, #40996
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=0) # previously BoundsError
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=(1,3)) # previously BoundsError
@test_throws DimensionMismatch mapslices(x -> x * x', rand(2,3), dims=1) # explicitly caught
@test @inferred(mapslices(hcat, [1 2; 3 4], dims=1)) == [1 2; 3 4] # previously an error, now allowed
@test mapslices(identity, [1 2; 3 4], dims=(2,2)) == [1 2; 3 4] # previously an error
@test_broken @inferred(mapslices(identity, [1 2; 3 4], dims=(2,2))) == [1 2; 3 4]
end

@testset "single multidimensional index" begin
Expand Down

0 comments on commit 3eaed8b

Please sign in to comment.