From db2709f50ad16861008fc06ff583986c26b89cf8 Mon Sep 17 00:00:00 2001 From: Andreas Noack Date: Wed, 15 Aug 2018 15:21:34 +0200 Subject: [PATCH] Fix mean with dimension argument --- src/DistributedArrays.jl | 1 + src/mapreduce.jl | 3 +++ test/darray.jl | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/DistributedArrays.jl b/src/DistributedArrays.jl index 755855c..344ab81 100644 --- a/src/DistributedArrays.jl +++ b/src/DistributedArrays.jl @@ -5,6 +5,7 @@ module DistributedArrays using Distributed using Serialization using LinearAlgebra +using Statistics import Base: +, -, *, div, mod, rem, &, |, xor import Base.Callable diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 90c65f4..bb291c3 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -116,6 +116,7 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray) return mapreducedim_between!(identity, op, R, B, region) end +## Some special cases function Base._all(f, A::DArray, ::Colon) B = asyncmap(procs(A)) do p remotecall_fetch(p) do @@ -159,6 +160,8 @@ function Base.extrema(d::DArray) return reduce((t,s) -> (min(t[1], s[1]), max(t[2], s[2])), r) end +Statistics._mean(A::DArray, region) = sum(A, dims = region) ./ prod((size(A, i) for i in region)) + # Unary vector functions (-)(D::DArray) = map(-, D) diff --git a/test/darray.jl b/test/darray.jl index 9939c68..e7b18bb 100644 --- a/test/darray.jl +++ b/test/darray.jl @@ -315,7 +315,7 @@ check_leaks() @testset "test statistical functions on DArrays" begin dims = (20,20,20) DA = drandn(dims) - A = convert(Array, DA) + A = Array(DA) @testset "test $f for dimension $dms" for f in (mean, ), dms in (1, 2, 3, (1,2), (1,3), (2,3), (1,2,3)) # std is pending implementation