Skip to content

Commit

Permalink
implement count and count! using mapreduce (#34048)
Browse files Browse the repository at this point in the history
This creates the same calling interface for `count` as for other
mapreduce-type functions like e.g. `sum`, namely allowing the `dims`
keyword.
The implementation itself is shorter than before without sacrificing
performance.
More detailed documentation for `count` was added too.
  • Loading branch information
stev47 authored May 2, 2020
1 parent dddda07 commit 301db97
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 14 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ New library features
will acquire locks for safe multi-threaded access. Setting it to `false` provides better
performance when only one thread will access the file.
* The introspection macros (`@which`, `@code_typed`, etc.) now work with `do`-block syntax ([#35283]) and with dot syntax ([#35522]).
* `count` now accepts the `dims` keyword.
* new in-place `count!` function similar to `sum!`.

Standard library changes
------------------------
Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,7 @@ export
any,
firstindex,
collect,
count!,
count,
delete!,
deleteat!,
Expand Down
18 changes: 4 additions & 14 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,8 @@ end

## count

_bool(f::Function) = x->f(x)::Bool

"""
count(p, itr) -> Integer
count(itr) -> Integer
Expand All @@ -853,22 +855,10 @@ julia> count([true, false, true, true])
3
```
"""
function count(pred, itr)
n = 0
for x in itr
n += pred(x)::Bool
end
return n
end
function count(pred, a::AbstractArrayOrBroadcasted)
n = 0
for i in eachindex(a)
@inbounds n += pred(a[i])::Bool
end
return n
end
count(itr) = count(identity, itr)

count(f, itr) = mapreduce(_bool(f), add_sum, itr, init=0)

function count(::typeof(identity), x::Array{Bool})
n = 0
chunks = length(x) ÷ sizeof(UInt)
Expand Down
61 changes: 61 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,67 @@ julia> reduce(max, a, dims=1)
reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)

##### Specific reduction functions #####

"""
count([f=identity,] A::AbstractArray; dims=:)
Count the number of elements in `A` for which `f` returns `true` over the given
dimensions.
!!! compat "Julia 1.5"
`dims` keyword was added in Julia 1.5.
# Examples
```jldoctest
julia> A = [1 2; 3 4]
2×2 Array{Int64,2}:
1 2
3 4
julia> count(<=(2), A, dims=1)
1×2 Array{Int64,2}:
1 1
julia> count(<=(2), A, dims=2)
2×1 Array{Int64,2}:
2
0
```
"""
count(A::AbstractArrayOrBroadcasted; dims=:) = count(identity, A, dims=dims)
count(f, A::AbstractArrayOrBroadcasted; dims=:) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)

"""
count!([f=identity,] r, A; init=true)
Count the number of elements in `A` for which `f` returns `true` over the
singleton dimensions of `r`, writing the result into `r` in-place.
If `init` is `true`, values in `r` are initialized to zero.
!!! compat "Julia 1.5"
inplace `count!` was added in Julia 1.5.
# Examples
```jldoctest
julia> A = [1 2; 3 4]
2×2 Array{Int64,2}:
1 2
3 4
julia> count!(<=(2), [1 1], A)
1×2 Array{Int64,2}:
1 1
julia> count!(<=(2), [1; 1], A)
2-element Array{Int64,1}:
2
0
```
"""
count!(r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) = count!(identity, r, A; init=init)
count!(f, r::AbstractArray, A::AbstractArrayOrBroadcasted; init::Bool=true) =
mapreducedim!(_bool(f), add_sum, initarray!(r, add_sum, init, A), A)

"""
sum(A::AbstractArray; dims)
Expand Down
21 changes: 21 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ safe_sum(A::Array{T}, region) where {T} = safe_mapslices(sum, A, region)
safe_prod(A::Array{T}, region) where {T} = safe_mapslices(prod, A, region)
safe_maximum(A::Array{T}, region) where {T} = safe_mapslices(maximum, A, region)
safe_minimum(A::Array{T}, region) where {T} = safe_mapslices(minimum, A, region)
safe_count(A::AbstractArray{T}, region) where {T} = safe_mapslices(count, A, region)
safe_sumabs(A::Array{T}, region) where {T} = safe_mapslices(sum, abs.(A), region)
safe_sumabs2(A::Array{T}, region) where {T} = safe_mapslices(sum, abs2.(A), region)
safe_maxabs(A::Array{T}, region) where {T} = safe_mapslices(maximum, abs.(A), region)
Expand All @@ -21,15 +22,21 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
1, 2, 3, 4, 5, (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4),
(1, 2, 3), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)]
Areduc = rand(3, 4, 5, 6)
Breduc = rand(Bool, 3, 4, 5, 6)
@assert axes(Areduc) == axes(Breduc)

r = fill(NaN, map(length, Base.reduced_indices(axes(Areduc), region)))
@test sum!(r, Areduc) safe_sum(Areduc, region)
@test prod!(r, Areduc) safe_prod(Areduc, region)
@test maximum!(r, Areduc) safe_maximum(Areduc, region)
@test minimum!(r, Areduc) safe_minimum(Areduc, region)
@test count!(r, Breduc) safe_count(Breduc, region)

@test sum!(abs, r, Areduc) safe_sumabs(Areduc, region)
@test sum!(abs2, r, Areduc) safe_sumabs2(Areduc, region)
@test maximum!(abs, r, Areduc) safe_maxabs(Areduc, region)
@test minimum!(abs, r, Areduc) safe_minabs(Areduc, region)
@test count!(!, r, Breduc) safe_count(.!Breduc, region)

# With init=false
r2 = similar(r)
Expand All @@ -41,6 +48,9 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test maximum!(r, Areduc, init=false) fill!(r2, 1.8)
fill!(r, -0.2)
@test minimum!(r, Areduc, init=false) fill!(r2, -0.2)
fill!(r, 1)
@test count!(r, Breduc, init=false) safe_count(Breduc, region) .+ 1

fill!(r, 8.1)
@test sum!(abs, r, Areduc, init=false) safe_sumabs(Areduc, region) .+ 8.1
fill!(r, 8.1)
Expand All @@ -49,15 +59,20 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test maximum!(abs, r, Areduc, init=false) fill!(r2, 1.5)
fill!(r, -1.5)
@test minimum!(abs, r, Areduc, init=false) fill!(r2, -1.5)
fill!(r, 1)
@test count!(!, r, Breduc, init=false) safe_count(.!Breduc, region) .+ 1

@test @inferred(sum(Areduc, dims=region)) safe_sum(Areduc, region)
@test @inferred(prod(Areduc, dims=region)) safe_prod(Areduc, region)
@test @inferred(maximum(Areduc, dims=region)) safe_maximum(Areduc, region)
@test @inferred(minimum(Areduc, dims=region)) safe_minimum(Areduc, region)
@test @inferred(count(Breduc, dims=region)) safe_count(Breduc, region)

@test @inferred(sum(abs, Areduc, dims=region)) safe_sumabs(Areduc, region)
@test @inferred(sum(abs2, Areduc, dims=region)) safe_sumabs2(Areduc, region)
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
@test @inferred(count(!, Breduc, dims=region)) safe_count(.!Breduc, region)
end

# Test reduction along first dimension; this is special-cased for
Expand Down Expand Up @@ -416,3 +431,9 @@ end

@test sum([Variable(:x), Variable(:y)], dims=1) == [AffExpr([Variable(:x), Variable(:y)])]
end

# count
@testset "count: throw on non-bool types" begin
@test_throws TypeError count([1], dims=1)
@test_throws TypeError count!([1], [1])
end

2 comments on commit 301db97

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily benchmark build, I will reply here when finished:

@nanosoldier runbenchmarks(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here. cc @ararslan

Please sign in to comment.