diff --git a/src/reduce.jl b/src/reduce.jl
index 50dedf542..7905c133c 100644
--- a/src/reduce.jl
+++ b/src/reduce.jl
@@ -170,3 +170,47 @@ Base.stdm(A::DataArray, m::Number; corrected::Bool=true, skipna::Bool=false) =
 
 Base.std(A::DataArray; corrected::Bool=true, mean=nothing, skipna::Bool=false) =
     sqrt(var(A; corrected=corrected, mean=mean, skipna=skipna))
+
+## weighted mean
+
+# Weighted mean of `a` using the weights in `w`, for weights stored in a plain
+# (non-DataArray) vector.
+#
+# With `skipna=false` the result is `NA` whenever `a` contains a missing value.
+# With `skipna=true`, missing entries of `a` are dropped together with their
+# weights before the weighted sum is normalized.
+#
+# Throws `DimensionMismatch` if `a` and `w` differ in length, so that a mismatch
+# is reported the same way whether or not `a` contains missing values.
+function Base.mean{W,V}(a::DataArray, w::WeightVec{W,V}; skipna::Bool=false)
+    length(a) == length(w.values) ||
+        throw(DimensionMismatch("Inconsistent array lengths."))
+    if skipna
+        # NA entries of `a` carry over into the product, so `v.na` marks the
+        # weights that must be left out of the normalizing sum.
+        v = a .* w.values
+        sum(v; skipna=true) / sum(DataArray(w.values, v.na); skipna=true)
+    else
+        anyna(a) ? NA : mean(a.data, w)
+    end
+end
+
+# Weighted mean of `a` for weights that are themselves stored in a DataArray and
+# may therefore contain missing values.
+#
+# With `skipna=false` the result is `NA` whenever `a` or the weights contain a
+# missing value. With `skipna=true`, every position where either the value or
+# its weight is missing is dropped.
+#
+# Throws `DimensionMismatch` if `a` and `w` differ in length.
+function Base.mean{W,V<:DataArray}(a::DataArray, w::WeightVec{W,V}; skipna::Bool=false)
+    length(a) == length(w.values) ||
+        throw(DimensionMismatch("Inconsistent array lengths."))
+    if skipna
+        # `v.na` covers positions missing in either `a` or the weights.
+        v = a .* w.values
+        sum(v; skipna=true) / sum(DataArray(w.values.data, v.na); skipna=true)
+    else
+        (anyna(a) || anyna(w.values)) ? NA : wsum(a.data, w.values.data) / w.sum
+    end
+end
diff --git a/test/reduce.jl b/test/reduce.jl
index 8f6ad8c96..958a4e6bc 100644
--- a/test/reduce.jl
+++ b/test/reduce.jl
@@ -1,5 +1,5 @@
 module TestReduce
-using DataArrays, Base.Test
+using DataArrays, Base.Test, StatsBase
 
 srand(1337)
 
@@ -132,4 +132,25 @@ end
 @test !reduce(&, @data([false, NA]))
 @test reduce(|, @data([true, NA]))
 @test isna(reduce(|, @data([false, NA])))
+
+# weighted mean
+da1 = DataArray(randn(128))
+da2 = DataArray(randn(128))
+@same_behavior mean(da1, weights(da2)) mean(da1.data, weights(da2.data))
+@same_behavior mean(da1, weights(da2.data)) mean(da1.data, weights(da2.data))
+@same_behavior mean(da1, weights(da2); skipna=true) mean(da1.data, weights(da2.data))
+@same_behavior mean(da1, weights(da2.data); skipna=true) mean(da1.data, weights(da2.data))
+
+da1[1:3:end] = NA
+@same_behavior mean(da1, weights(da2); skipna=true) mean(dropna(da1), weights(da2.data[!da1.na]))
+@same_behavior mean(da1, weights(da2.data); skipna=true) mean(dropna(da1), weights(da2.data[!da1.na]))
+
+da2[1:2:end] = NA
+keep = !da1.na & !da2.na
+@test isna(mean(da1, weights(da2)))
+@same_behavior mean(da1, weights(da2); skipna=true) mean(da1.data[keep], weights(da2.data[keep]))
+
+# mismatched lengths are rejected regardless of `skipna`
+@test_throws DimensionMismatch mean(da1, weights(da2.data[1:2]))
+@test_throws DimensionMismatch mean(da1, weights(da2.data[1:2]); skipna=true)
 end