Skip to content

Commit

Permalink
Overload mean to take a function alongwith a dimension (#31576)
Browse files Browse the repository at this point in the history
  • Loading branch information
eulerkochy authored and ararslan committed Apr 16, 2019
1 parent b2b35e9 commit 4f59976
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 2 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ New library functions
Standard library changes
------------------------


#### LinearAlgebra


Expand All @@ -36,6 +35,7 @@ Standard library changes

#### Statistics

* `mean` now accepts both a function argument and a `dims` keyword ([#31576]).

#### Miscellaneous

Expand Down
27 changes: 26 additions & 1 deletion stdlib/Statistics/src/Statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,32 @@ function mean(f, itr)
end
return total/count
end
mean(f, A::AbstractArray) = sum(f, A) / length(A)

"""
mean(f::Function, A::AbstractArray; dims)
Apply the function `f` to each element of array `A` and take the mean over dimensions `dims`.
!!! compat "Julia 1.3"
This method requires at least Julia 1.3.
```jldoctest
julia> mean(√, [1, 2, 3])
1.3820881233139908
julia> mean([√1, √2, √3])
1.3820881233139908
julia> mean(√, [1 2 3; 4 5 6], dims=2)
2×1 Array{Float64,2}:
1.3820881233139908
2.2285192400943226
```
"""
mean(f, A::AbstractArray; dims=:) = _mean(f, A, dims)

_mean(f, A::AbstractArray, ::Colon) = sum(f, A) / length(A)
_mean(f, A::AbstractArray, dims) = sum(f, A, dims=dims) / mapreduce(i -> size(A, i), *, unique(dims); init=1)

"""
mean!(r, v)
Expand Down
5 changes: 5 additions & 0 deletions stdlib/Statistics/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ end
@test mean([1,2,3]) == 2.
@test mean([0 1 2; 4 5 6], dims=1) == [2. 3. 4.]
@test mean([1 2 3; 4 5 6], dims=1) == [2.5 3.5 4.5]
@test mean(-, [1 2 3 ; 4 5 6], dims=1) == [-2.5 -3.5 -4.5]
@test mean(-, [1 2 3 ; 4 5 6], dims=2) == transpose([-2.0 -5.0])
@test mean(-, [1 2 3 ; 4 5 6], dims=(1, 2)) == -3.5 .* ones(1, 1)
@test mean(-, [1 2 3 ; 4 5 6], dims=(1, 1)) == [-2.5 -3.5 -4.5]
@test mean(-, [1 2 3 ; 4 5 6], dims=()) == Float64[-1 -2 -3 ; -4 -5 -6]
@test mean(i->i+1, 0:2) === 2.
@test mean(isodd, [3]) === 1.
@test mean(x->3x, (1,1)) === 3.
Expand Down

0 comments on commit 4f59976

Please sign in to comment.