Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define mapfoldl/foldl for static arrays #750

Merged
merged 2 commits into from
Mar 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module StaticArrays
import Base: @_inline_meta, @_propagate_inbounds_meta, @_pure_meta, @propagate_inbounds, @pure

import Base: getindex, setindex!, size, similar, vec, show, length, convert, promote_op,
promote_rule, map, map!, reduce, mapreduce, broadcast,
promote_rule, map, map!, reduce, mapreduce, foldl, mapfoldl, broadcast,
broadcast!, conj, hcat, vcat, ones, zeros, one, reshape, fill, fill!, inv,
iszero, sum, prod, count, any, all, minimum, maximum, extrema,
copy, read, read!, write, reverse
Expand Down
149 changes: 78 additions & 71 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
"""
_InitialValue

A singleton type for representing "universal" initial value (identity element).

The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
version of it by

op′(::_InitialValue, x) = x
op′(acc, x) = op(acc, x)

This is just a conceptually useful model to have in mind and we don't actually
define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
action.

(It is related to that you can always turn a semigroup without an identity into
a monoid by "adjoining" an element that acts as the identity.)
"""
struct _InitialValue end

@inline _first(a1, as...) = a1

################
Expand Down Expand Up @@ -86,28 +106,21 @@ end
## mapreduce ##
###############

@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:,kw...)
_mapreduce(f, op, dims, kw.data, same_size(a, b...), a, b...)
@inline function mapreduce(f, op, a::StaticArray, b::StaticArray...; dims=:, init = _InitialValue())
_mapreduce(f, op, dims, init, same_size(a, b...), a, b...)
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{()},
::Size{S}, a::StaticArray...) where {S}
@inline _mapreduce(args::Vararg{Any,N}) where N = _mapfoldl(args...)

@generated function _mapfoldl(f, op, dims::Colon, init, ::Size{S}, a::StaticArray...) where {S}
tmp = [:(a[$j][1]) for j ∈ 1:length(a)]
expr = :(f($(tmp...)))
for i ∈ 2:prod(S)
tmp = [:(a[$j][$i]) for j ∈ 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
return quote
@_inline_meta
@inbounds return $expr
if init === _InitialValue
expr = :(Base.reduce_first(op, $expr))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base.reduce_first wasn't called before so sum(SA[true]) returned true; now it returns 1 :: Int. This is consistent with sum([true]).

else
expr = :(op(init, $expr))
end
end

@generated function _mapreduce(f, op, dims::Colon, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray...) where {S}
expr = :(nt.init)
for i ∈ 1:prod(S)
for i ∈ 2:prod(S)
tmp = [:(a[$j][$i]) for j ∈ 1:length(a)]
expr = :(op($expr, f($(tmp...))))
end
Expand All @@ -117,24 +130,24 @@ end
end
end

@inline function _mapreduce(f, op, D::Int, nt::NamedTuple, sz::Size{S}, a::StaticArray) where {S}
@inline function _mapreduce(f, op, D::Int, init, sz::Size{S}, a::StaticArray) where {S}
# Body of this function is split because constant propagation (at least
# as of Julia 1.2) can't always correctly propagate here and
# as a result the function is not type stable and very slow.
# This makes it at least fast for three dimensions but people should use
# for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
if D == 1
return _mapreduce(f, op, Val(1), nt, sz, a)
return _mapreduce(f, op, Val(1), init, sz, a)
elseif D == 2
return _mapreduce(f, op, Val(2), nt, sz, a)
return _mapreduce(f, op, Val(2), init, sz, a)
elseif D == 3
return _mapreduce(f, op, Val(3), nt, sz, a)
return _mapreduce(f, op, Val(3), init, sz, a)
else
return _mapreduce(f, op, Val(D), nt, sz, a)
return _mapreduce(f, op, Val(D), init, sz, a)
end
end

@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{()},
@generated function _mapfoldl(f, op, dims::Val{D}, init,
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)
Expand All @@ -143,32 +156,12 @@ end
itr = [1:n for n ∈ Snew]
for i ∈ Base.product(itr...)
expr = :(f(a[$(i...)]))
for k = 2:S[D]
ik = collect(i)
ik[D] = k
expr = :(op($expr, f(a[$(ik...)])))
if init === _InitialValue
expr = :(Base.reduce_first(op, $expr))
else
expr = :(op(init, $expr))
end

exprs[i...] = expr
end

return quote
@_inline_meta
@inbounds elements = tuple($(exprs...))
@inbounds return similar_type(a, eltype(elements), Size($Snew))(elements)
end
end

@generated function _mapreduce(f, op, dims::Val{D}, nt::NamedTuple{(:init,)},
::Size{S}, a::StaticArray) where {S,D}
N = length(S)
Snew = ([n==D ? 1 : S[n] for n = 1:N]...,)

exprs = Array{Expr}(undef, Snew)
itr = [1:n for n = Snew]
for i ∈ Base.product(itr...)
expr = :(nt.init)
for k = 1:S[D]
for k = 2:S[D]
ik = collect(i)
ik[D] = k
expr = :(op($expr, f(a[$(ik...)])))
Expand All @@ -188,20 +181,37 @@ end
## reduce ##
############

@inline reduce(op, a::StaticArray; dims=:, kw...) = _reduce(op, a, dims, kw.data)
@inline reduce(op, a::StaticArray; dims = :, init = _InitialValue()) =
_reduce(op, a, dims, init)

# disambiguation
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
Base._typed_vcat(mapreduce(eltype, promote_type, A), A)
reduce(::typeof(vcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
_reduce(vcat, A, :, NamedTuple())
_reduce(vcat, A, :, _InitialValue())

reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:AbstractVecOrMat}) =
Base._typed_hcat(mapreduce(eltype, promote_type, A), A)
reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
_reduce(hcat, A, :, NamedTuple())
_reduce(hcat, A, :, _InitialValue())

@inline _reduce(op, a::StaticArray, dims, init = _InitialValue()) =
_mapreduce(identity, op, dims, init, Size(a), a)

@inline _reduce(op, a::StaticArray, dims, kw::NamedTuple=NamedTuple()) = _mapreduce(identity, op, dims, kw, Size(a), a)
################
## (map)foldl ##
################

# Using `where {R}` to force specialization. See:
# https://docs.julialang.org/en/v1.5-dev/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing-1
# https://github.com/JuliaLang/julia/pull/33917

@inline mapfoldl(f::F, op::R, a::StaticArray; init = _InitialValue()) where {F,R} =
_mapfoldl(f, op, :, init, Size(a), a)
@inline foldl(op::R, a::StaticArray; init = _InitialValue()) where {R} =
_foldl(op, a, :, init)
@inline _foldl(op::R, a, dims, init = _InitialValue()) where {R} =
_mapfoldl(identity, op, dims, init, Size(a), a)
c42f marked this conversation as resolved.
Show resolved Hide resolved

#######################
## related functions ##
Expand All @@ -227,37 +237,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
@inline iszero(a::StaticArray{<:Tuple,T}) where {T} = reduce((x,y) -> x && iszero(y), a, init=true)

@inline sum(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(+, a, dims)
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) # avoid ambiguity
@inline sum(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a)
@inline sum(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) # avoid ambiguity

@inline prod(a::StaticArray{<:Tuple,T}; dims=:) where {T} = _reduce(*, a, dims)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, NamedTuple(), Size(a), a)
@inline prod(f, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)
@inline prod(f::Union{Function, Type}, a::StaticArray{<:Tuple,T}; dims=:) where {T} = _mapreduce(f, *, dims, _InitialValue(), Size(a), a)

@inline count(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(+, a, dims)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, NamedTuple(), Size(a), a)
@inline count(f, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, +, dims, _InitialValue(), Size(a), a)

@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, (init=true,)) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, (init=true,), Size(a), a)
@inline all(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(&, a, dims, true) # non-branching versions
@inline all(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, &, dims, true, Size(a), a)

@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, (init=false,)) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, (init=false,), Size(a), a) # (benchmarking needed)
@inline any(a::StaticArray{<:Tuple,Bool}; dims=:) = _reduce(|, a, dims, false) # (benchmarking needed)
@inline any(f::Function, a::StaticArray; dims=:) = _mapreduce(x->f(x)::Bool, |, dims, false, Size(a), a) # (benchmarking needed)

@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, (init=false,), Size(a), a)
@inline Base.in(x, a::StaticArray) = _mapreduce(==(x), |, :, false, Size(a), a)

_mean_denom(a, dims::Colon) = length(a)
_mean_denom(a, dims::Int) = size(a, dims)
_mean_denom(a, ::Val{D}) where {D} = size(a, D)
_mean_denom(a, ::Type{Val{D}}) where {D} = size(a, D)

@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, NamedTuple(), Size(a), a) / _mean_denom(a, dims)
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)

@inline minimum(a::StaticArray; dims=:) = _reduce(min, a, dims) # base has mapreduce(idenity, scalarmin, a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, NamedTuple(), Size(a), a)
@inline minimum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, min, dims, _InitialValue(), Size(a), a)

@inline maximum(a::StaticArray; dims=:) = _reduce(max, a, dims) # base has mapreduce(idenity, scalarmax, a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, NamedTuple(), Size(a), a)
@inline maximum(f::Function, a::StaticArray; dims=:) = _mapreduce(f, max, dims, _InitialValue(), Size(a), a)

# Diff is slightly different
@inline diff(a::StaticArray; dims) = _diff(Size(a), a, dims)
Expand Down Expand Up @@ -286,8 +296,6 @@ end
end
end

struct _InitialValue end

_maybe_val(dims::Integer) = Val(Int(dims))
_maybe_val(dims) = dims
_valof(::Val{D}) where D = D
Expand All @@ -299,19 +307,18 @@ _valof(::Val{D}) where D = D
_accumulate(op, a, _maybe_val(dims), init)

@inline function _accumulate(op::F, a::StaticArray, dims::Union{Val,Colon}, init) where {F}
# Adjoin the initial value to `op`:
# Adjoin the initial value to `op` (one-line version of `Base.BottomRF`):
rf(x, y) = x isa _InitialValue ? Base.reduce_first(op, y) : op(x, y)

if isempty(a)
T = return_type(rf, Tuple{typeof(init), eltype(a)})
return similar_type(a, T)()
end

# StaticArrays' `reduce` is `foldl`:
results = _reduce(
results = _foldl(
a,
dims,
(init = (similar_type(a, Union{}, Size(0))(), init),),
(similar_type(a, Union{}, Size(0))(), init),
) do (ys, acc), x
y = rf(acc, x)
# Not using `push(ys, y)` here since we need to widen element type as
Expand Down
9 changes: 9 additions & 0 deletions test/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ using Statistics: mean
@test mapreduce(x->x^2, max, sa; dims=2, init=-1.) == SMatrix{I,1}(mapreduce(x->x^2, max, a, dims=2, init=-1.))
end

@testset "[map]foldl" begin
a = rand(4,3)
v1 = [2,4,6,8]; sv1 = SVector{4}(v1)
@test foldl(+, sv1) === foldl(+, v1)
@test foldl(+, sv1; init=0) === foldl(+, v1; init=0)
@test mapfoldl(-, +, sv1) === mapfoldl(-, +, v1)
@test mapfoldl(-, +, sv1; init=0) === mapfoldl(-, +, v1, init=0)
end

@testset "implemented by [map]reduce and [map]reducedim" begin
I, J, K = 2, 2, 2
OSArray = SArray{Tuple{I,J,K}} # original
Expand Down