Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up mapslices #40996

Merged
merged 18 commits into from
May 31, 2022
194 changes: 109 additions & 85 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2778,134 +2778,158 @@ foreach(f, itrs...) = (for z in zip(itrs...); f(z...); end; nothing)
"""
mapslices(f, A; dims)

Transform the given dimensions of array `A` using function `f`. `f` is called on each slice
of `A` of the form `A[...,:,...,:,...]`. `dims` is an integer vector specifying where the
colons go in this expression. The results are concatenated along the remaining dimensions.
For example, if `dims` is `[1,2]` and `A` is 4-dimensional, `f` is called on `A[:,:,i,j]`
for all `i` and `j`.
Transform the given dimensions of array `A` by applying a function `f` on each slice
of the form `A[..., :, ..., :, ...]`, with a colon at each `d` in `dims`. The results are
concatenated along the remaining dimensions.

See also [`eachcol`](@ref), [`eachslice`](@ref).
For example, if `dims = [1,2]` and `A` is 4-dimensional, then `f` is called on `x = A[:,:,i,j]`
for all `i` and `j`, and `f(x)` becomes `R[:,:,i,j]` in the result `R`.

See also [`eachcol`](@ref), [`eachslice`](@ref), [`mapreduce`](@ref).

# Examples
```jldoctest
julia> a = reshape(Vector(1:16),(2,2,2,2))
2×2×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
1 3
2 4
julia> A = reshape(1:30,(2,5,3))
2×5×3 reshape(::UnitRange{$Int}, 2, 5, 3) with eltype $Int:
[:, :, 1] =
1 3 5 7 9
2 4 6 8 10

[:, :, 2] =
11 13 15 17 19
12 14 16 18 20

[:, :, 3] =
21 23 25 27 29
22 24 26 28 30

[:, :, 2, 1] =
5 7
6 8
julia> f(x::Matrix) = fill(x[1,1], 1,4); # returns a 1×4 matrix

[:, :, 1, 2] =
9 11
10 12
julia> mapslices(f, A, dims=(1,2))
1×4×3 Array{$Int, 3}:
[:, :, 1] =
1 1 1 1

[:, :, 2] =
11 11 11 11

[:, :, 2, 2] =
13 15
14 16
[:, :, 3] =
21 21 21 21

julia> mapslices(sum, a, dims = [1,2])
1×1×2×2 Array{Int64, 4}:
[:, :, 1, 1] =
10
julia> g(x) = x[begin] // x[end-1]; # returns a number

[:, :, 2, 1] =
26
julia> mapslices(g, A, dims=[1,3])
1×5×1 Array{Rational{$Int}, 3}:
[:, :, 1] =
1//21 3//23 1//5 7//27 9//29

[:, :, 1, 2] =
42
julia> map(g, eachslice(A, dims=2))
5-element Vector{Rational{$Int}}:
1//21
3//23
1//5
7//27
9//29

[:, :, 2, 2] =
58
julia> mapslices(sum, A; dims=(1,3)) == sum(A; dims=(1,3))
true
```

Notice that in `eachslice(A; dims=2)`, the specified dimension is the
one *without* a colon in the slice. This is `view(A,:,i,:)`, whereas
`mapslices(f, A; dims=(1,3))` uses `A[:,i,:]`. The function `f` may mutate
values in the slice without affecting `A`.
"""
function mapslices(f, A::AbstractArray; dims)
if isempty(dims)
return map(f,A)
end
if !isa(dims, AbstractVector)
dims = [dims...]
end
isempty(dims) && return map(f, A)

dimsA = [axes(A)...]
ndimsA = ndims(A)
alldims = [1:ndimsA;]

otherdims = setdiff(alldims, dims)

idx = Any[first(ind) for ind in axes(A)]
itershape = tuple(dimsA[otherdims]...)
for d in dims
idx[d] = Slice(axes(A, d))
d isa Integer || throw(ArgumentError("mapslices: dimension must be an integer, got $d"))
d >= 1 || throw(ArgumentError("mapslices: dimension must be ≥ 1, got $d"))
# Indexing a matrix M[:,1,:] produces a 1-column matrix, but dims=(1,3) here
# would otherwise ignore 3, and slice M[:,i]. Previously this gave error:
# BoundsError: attempt to access 2-element Vector{Any} at index [3]
d > ndims(A) && throw(ArgumentError("mapslices does not accept dimensions > ndims(A) = $(ndims(A)), got $d"))
end
dim_mask = ntuple(d -> d in dims, ndims(A))

# Apply the function to the first slice in order to determine the next steps
Aslice = A[idx...]
idx1 = ntuple(d -> d in dims ? (:) : firstindex(A,d), ndims(A))
Aslice = A[idx1...]
r1 = f(Aslice)
# In some cases, we can re-use the first slice for a dramatic performance
# increase. The slice itself must be mutable and the result cannot contain
# any mutable containers. The following errs on the side of being overly
# strict (#18570 & #21123).
safe_for_reuse = isa(Aslice, StridedArray) &&
(isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

# determine result size and allocate
Rsize = copy(dimsA)
# TODO: maybe support removing dimensions
if !isa(r1, AbstractArray) || ndims(r1) == 0
res1 = if r1 isa AbstractArray && ndims(r1) > 0
n = sum(dim_mask)
if ndims(r1) > n && any(ntuple(d -> size(r1,d+n)>1, ndims(r1)-n))
s = size(r1)[1:n]
throw(DimensionMismatch("mapslices cannot assign slice f(x) of size $(size(r1)) into output of size $s"))
end
r1
else
# If the result of f on a single slice is a scalar then we add singleton
# dimensions. When adding the dimensions, we have to respect the
# index type of the input array (e.g. in the case of OffsetArrays)
tmp = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
tmp[firstindex(tmp)] = r1
r1 = tmp
end
nextra = max(0, length(dims)-ndims(r1))
if eltype(Rsize) == Int
Rsize[dims] = [size(r1)..., ntuple(Returns(1), nextra)...]
else
Rsize[dims] = [axes(r1)..., ntuple(Returns(OneTo(1)), nextra)...]
_res1 = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
_res1[begin] = r1
_res1
end
R = similar(r1, tuple(Rsize...,))

ridx = Any[map(first, axes(R))...]
for d in dims
ridx[d] = axes(R,d)
# Determine result size and allocate. We always pad ndims(res1) out to length(dims):
din = Ref(0)
Rsize = ntuple(ndims(A)) do d
if d in dims
axes(res1, din[] += 1)
else
axes(A,d)
end
end
R = similar(res1, Rsize)

# Determine iteration space. It will be convenient in the loop to mask N-dimensional
# CartesianIndices, with some trivial dimensions:
itershape = ntuple(d -> d in dims ? Base.OneTo(1) : axes(A,d), ndims(A))
indices = Iterators.drop(CartesianIndices(itershape), 1)

# That skips the first element, which we already have:
ridx = ntuple(d -> d in dims ? Slice(axes(R,d)) : firstindex(A,d), ndims(A))
concatenate_setindex!(R, res1, ridx...)

concatenate_setindex!(R, r1, ridx...)
# In some cases, we can re-use the first slice for a dramatic performance
# increase. The slice itself must be mutable and the result cannot contain
# any mutable containers. The following errs on the side of being overly
# strict (#18570 & #21123).
safe_for_reuse = isa(Aslice, StridedArray) &&
(isa(r1, Number) || (isa(r1, AbstractArray) && eltype(r1) <: Number))

nidx = length(otherdims)
indices = Iterators.drop(CartesianIndices(itershape), 1) # skip the first element, we already handled it
inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
_inner_mapslices!(R, indices, f, A, dim_mask, Aslice, safe_for_reuse)
return R
end

@noinline function inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
@noinline function _inner_mapslices!(R, indices, f, A, dim_mask, Aslice, safe_for_reuse)
must_extend = any(dim_mask .& size(R) .> 1)
if safe_for_reuse
# when f returns an array, R[ridx...] = f(Aslice) line copies elements,
# so we can reuse Aslice
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
idx = ifelse.(dim_mask, Slice.(axes(A)), Tuple(I))
_unsafe_getindex!(Aslice, A, idx...)
concatenate_setindex!(R, f(Aslice), ridx...)
r = f(Aslice)
if r isa AbstractArray || must_extend
ridx = ifelse.(dim_mask, Slice.(axes(R)), Tuple(I))
R[ridx...] = r
else
ridx = ifelse.(dim_mask, first.(axes(R)), Tuple(I))
R[ridx...] = r
end
end
else
# we can't guarantee safety (#18524), so allocate new storage for each slice
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
idx = ifelse.(dim_mask, Slice.(axes(A)), Tuple(I))
ridx = ifelse.(dim_mask, Slice.(axes(R)), Tuple(I))
concatenate_setindex!(R, f(A[idx...]), ridx...)
end
end

return R
end

function replace_tuples!(nidx, idx, ridx, otherdims, I)
for i in 1:nidx
idx[otherdims[i]] = ridx[otherdims[i]] = I.I[i]
end
end

concatenate_setindex!(R, v, I...) = (R[I...] .= (v,); R)
Expand Down
27 changes: 23 additions & 4 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1173,7 +1173,6 @@ end
@test mapslices(prod,["1"],dims=1) == ["1"]

# issue #5177

c = fill(1,2,3,4)
m1 = mapslices(_ -> fill(1,2,3), c, dims=[1,2])
m2 = mapslices(_ -> fill(1,2,4), c, dims=[1,3])
Expand All @@ -1196,9 +1195,29 @@ end
@test o == fill(1, 3, 4)

# issue #18524
m = mapslices(x->tuple(x), [1 2; 3 4], dims=1)
@test m[1,1] == ([1,3],)
@test m[1,2] == ([2,4],)
# m = mapslices(x->tuple(x), [1 2; 3 4], dims=1) # see variations of this below
# ERROR: fatal error in type inference (type bound), https://github.com/JuliaLang/julia/issues/43064
# @test m[1,1] == ([1,3],)
# @test m[1,2] == ([2,4],)

r = rand(Int8, 4,5,2)
@test vec(mapslices(repr, r, dims=(2,1))) == map(repr, eachslice(r, dims=3))
@test mapslices(tuple, [1 2; 3 4], dims=1) == [([1, 3],) ([2, 4],)]
@test mapslices(transpose, r, dims=(1,3)) == permutedims(r, (3,2,1))

# failures
@test_broken @inferred(mapslices(tuple, [1 2; 3 4], dims=1)) == [([1, 3],) ([2, 4],)]
@test_broken @inferred(mapslices(transpose, r, dims=(1,3))) == permutedims(r, (3,2,1))
# ERROR: fatal error in type inference (type bound), https://github.com/JuliaLang/julia/issues/43064
@test_broken @inferred(mapslices(x -> tuple(x), [1 2; 3 4], dims=1)) == [([1, 3],) ([2, 4],)]

# re-write, #40996
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=0) # previously BoundsError
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=(1,3)) # previously BoundsError
@test_throws DimensionMismatch mapslices(x -> x * x', rand(2,3), dims=1) # explicitly caught
@test @inferred(mapslices(hcat, [1 2; 3 4], dims=1)) == [1 2; 3 4] # previously an error, now allowed
@test mapslices(identity, [1 2; 3 4], dims=(2,2)) == [1 2; 3 4] # previously an error
@test_broken @inferred(mapslices(identity, [1 2; 3 4], dims=(2,2))) == [1 2; 3 4]
end

@testset "single multidimensional index" begin
Expand Down