Skip to content

Commit

Permalink
Define necessary methods to make tests pass without hitting scalar
Browse files Browse the repository at this point in the history
indexing and some other cleanup.
  • Loading branch information
andreasnoack committed Jul 28, 2018
1 parent f509f59 commit 24a150b
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 58 deletions.
2 changes: 1 addition & 1 deletion src/DistributedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using LinearAlgebra

import Base: +, -, *, div, mod, rem, &, |, xor
import Base.Callable
import LinearAlgebra: axpy!, dot, norm,
import LinearAlgebra: axpy!, dot, norm

import Primes
import Primes: factor
Expand Down
3 changes: 2 additions & 1 deletion src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ end
Get the vector of processes storing pieces of DArray `d`.
"""
Distributed.procs(d::DArray) = d.pids
Distributed.procs(d::DArray) = d.pids
Distributed.procs(d::SubDArray) = procs(parent(d))

"""
localpart(A)
Expand Down
52 changes: 42 additions & 10 deletions src/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,41 @@ function localindices(d::DArray)
return d.indices[lpidx]
end

# find which piece holds index (I...)
locate(d::DArray, I::Int...) =
ntuple(i -> searchsortedlast(d.cuts[i], I[i]), ndims(d))
# Equality
"""
    ==(d::DArray, a::AbstractArray)

Return `false` when `size(d)` and `size(a)` differ. Otherwise, for every process in
`procs(d)`, compare `localpart(d)` with `a[localindices(d)...]` converted to the chunk
type `A`, and return `all` of the per-process results.
"""
function Base.:(==)(d::DArray{<:Any,<:Any,A}, a::AbstractArray) where A
    size(d) == size(a) || return false
    chunkresults = asyncmap(procs(d)) do pid
        # The closure is executed on `pid`, so `localpart` and `localindices`
        # are evaluated there.
        remotecall_fetch(pid) do
            return localpart(d) == A(a[localindices(d)...])
        end
    end
    return all(chunkresults)
end
# Remaining `==` combinations. A `SubDArray` operand is materialised with `copy`
# before comparing, and a plain `AbstractArray` on the left is handled by swapping
# the operands so that the distributed array comes first.
Base.:(==)(sd::SubDArray, arr::AbstractArray) = copy(sd) == arr
Base.:(==)(arr::AbstractArray, da::DArray) = da == arr
Base.:(==)(arr::AbstractArray, sd::SubDArray) = sd == arr
# `invoke` selects the `(DArray, AbstractArray)` method explicitly for two `DArray`s.
Base.:(==)(lhs::DArray, rhs::DArray) =
    invoke(==, Tuple{DArray, AbstractArray}, lhs, rhs)
Base.:(==)(lhs::SubDArray, rhs::DArray) = copy(lhs) == rhs
Base.:(==)(lhs::DArray, rhs::SubDArray) = lhs == copy(rhs)
Base.:(==)(lhs::SubDArray, rhs::SubDArray) = copy(lhs) == copy(rhs)

"""
locate(d::DArray, I::Int...)
Determine the index of `procs(d)` that holds element `I`.
"""
function locate(d::DArray, I::Int...)
    # One entry per dimension: the position of `I[dim]` within `d.cuts[dim]`.
    return ntuple(ndims(d)) do dim
        cuts = d.cuts[dim]
        pos = searchsortedlast(cuts, I[dim])
        # A position at or beyond the last cut is rejected.
        pos < length(cuts) || throw(ArgumentError("element not contained in array"))
        pos
    end
end

chunk(d::DArray{T,N,A}, i...) where {T,N,A} = remotecall_fetch(localpart, d.pids[i...], d)::A

Expand Down Expand Up @@ -479,15 +511,15 @@ end
function (::Type{Array{S,N}})(s::SubDArray{T,N}) where {S,T,N}
I = s.indices
d = s.parent
if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S
if isa(I,Tuple{Vararg{UnitRange{Int}}}) && S<:T && T<:S && !isempty(s)
l = locate(d, map(first, I)...)
if isequal(d.indices[l...], I)
# SubDArray corresponds to a chunk
return chunk(d, l...)
end
end
a = Array{S}(undef, size(s))
a[[1:size(a,i) for i=1:N]...] .= s
a[[1:size(a,i) for i=1:N]...] = s
return a
end

Expand Down Expand Up @@ -540,15 +572,15 @@ end

function Base.getindex(d::DArray, i::Int)
_scalarindexingallowed()
return getindex_tuple(d, CartesianIndices(d)[i])
return getindex_tuple(d, Tuple(CartesianIndices(d)[i]))
end
function Base.getindex(d::DArray, i::Int...)
_scalarindexingallowed()
return getindex_tuple(d, i)
end

Base.getindex(d::DArray) = d[1]
Base.getindex(d::DArray, I::Union{Int,UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) = view(d, I...)
Base.getindex(d::SubOrDArray, I::Union{Int,UnitRange{Int},Colon,Vector{Int},StepRange{Int,Int}}...) = view(d, I...)

function Base.isassigned(D::DArray, i::Integer...)
try
Expand All @@ -564,15 +596,15 @@ function Base.isassigned(D::DArray, i::Integer...)
end


Base.copyto!(dest::SubOrDArray, src::SubOrDArray) = begin
function Base.copyto!(dest::SubOrDArray, src::AbstractArray)
asyncmap(procs(dest)) do p
remotecall_fetch(p) do
localpart(dest)[:] = src[localindices(dest)...]
ldest = localpart(dest)
ldest[:] = Array(view(src, localindices(dest)...))
end
end
return dest
end
Base.copy!(dest::SubOrDArray, src::SubOrDArray) = copyto!(dest, src)

function Base.deepcopy(src::DArray)
dest = similar(src)
Expand Down
82 changes: 56 additions & 26 deletions src/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
function Base.copy(D::Adjoint{T,<:DArray{T,2}}) where T
function Base.copy(Dadj::Adjoint{T,<:DArray{T,2}}) where T
D = parent(Dadj)
DArray(reverse(size(D)), procs(D)) do I
lp = Array{T}(undef, map(length, I))
rp = convert(Array, D[reverse(I)...])
adjoint!(lp, rp)
end
end

function Base.copy(D::Transpose{T,<:DArray{T,2}}) where T
function Base.copy(Dtr::Transpose{T,<:DArray{T,2}}) where T
D = parent(Dtr)
DArray(reverse(size(D)), procs(D)) do I
lp = Array{T}(undef, map(length, I))
rp = convert(Array, D[reverse(I)...])
Expand Down Expand Up @@ -49,7 +51,7 @@ function dot(x::DVector, y::DVector)
return reduce(+, results)
end

function norm(x::DVector, p::Real = 2)
function norm(x::DArray, p::Real = 2)
results = []
@sync begin
for pp in procs(x)
Expand Down Expand Up @@ -83,7 +85,7 @@ function add!(dest, src, scale = one(dest[1]))
return dest
end

function A_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVector)
function mul!(y::DVector, A::DMatrix, x::AbstractVector, α::Number = 1, β::Number = 0)

# error checks
if size(A, 2) != length(x)
Expand All @@ -106,11 +108,14 @@ function A_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVec

# Scale y if necessary
if β != one(β)
@sync for p in y.pids
if β != zero(β)
@async remotecall_fetch(y -> (rmul!(localpart(y), β); nothing), p, y)
else
@async remotecall_fetch(y -> (fill!(localpart(y), 0); nothing), p, y)
asyncmap(procs(y)) do p
remotecall_fetch(p) do
if !iszero(β)
rmul!(localpart(y), β)
else
fill!(localpart(y), 0)
end
return nothing
end
end
end
Expand All @@ -127,7 +132,9 @@ function A_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVec
return y
end

function Ac_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVector)
function mul!(y::DVector, adjA::Adjoint{<:Number,<:DMatrix}, x::AbstractVector, α::Number = 1, β::Number = 0)

A = parent(adjA)

# error checks
if size(A, 1) != length(x)
Expand All @@ -148,11 +155,14 @@ function Ac_mul_B!(α::Number, A::DMatrix, x::AbstractVector, β::Number, y::DVe

# Scale y if necessary
if β != one(β)
@sync for p in y.pids
if β != zero(β)
@async remotecall_fetch(() -> (rmul!(localpart(y), β); nothing), p)
else
@async remotecall_fetch(() -> (fill!(localpart(y), 0); nothing), p)
asyncmap(procs(y)) do p
remotecall_fetch(p) do
if !iszero(β)
rmul!(localpart(y), β)
else
fill!(localpart(y), 0)
end
return nothing
end
end
end
Expand Down Expand Up @@ -189,7 +199,7 @@ function LinearAlgebra.rmul!(DA::DMatrix, D::Diagonal)
end

# Level 3
function _matmatmul!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix, tA)
function _matmatmul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number, β::Number, tA)
# error checks
Ad1, Ad2 = (tA == 'N') ? (1,2) : (2,1)
mA, nA = (size(A, Ad1), size(A, Ad2))
Expand Down Expand Up @@ -254,40 +264,60 @@ function _matmatmul!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::D
return C
end

A_mul_B!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'N')
Ac_mul_B!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'C')
At_mul_B!(α::Number, A::DMatrix, B::AbstractMatrix, β::Number, C::DMatrix) = _matmatmul!(α, A, B, β, C, 'T')
At_mul_B!(C::DMatrix, A::DMatrix, B::AbstractMatrix) = At_mul_B!(one(eltype(C)), A, B, zero(eltype(C)), C)
# Matrix-matrix `mul!` methods. Each one forwards to `_matmatmul!`; a lazy `Adjoint` or
# `Transpose` wrapper is removed with `parent` and replaced by the flag 'C' or 'T',
# while a plain `DMatrix` is passed with 'N'.
function mul!(C::DMatrix, A::DMatrix, B::AbstractMatrix, α::Number = 1, β::Number = 0)
    return _matmatmul!(C, A, B, α, β, 'N')
end
function mul!(
    C::DMatrix,
    A::Adjoint{<:Number,<:DMatrix},
    B::AbstractMatrix,
    α::Number = 1,
    β::Number = 0,
)
    return _matmatmul!(C, parent(A), B, α, β, 'C')
end
function mul!(
    C::DMatrix,
    A::Transpose{<:Number,<:DMatrix},
    B::AbstractMatrix,
    α::Number = 1,
    β::Number = 0,
)
    return _matmatmul!(C, parent(A), B, α, β, 'T')
end

# Element-level model of a matrix product (a sum of products). It is only passed to
# `Base.promote_op` to obtain the element type of a result array. It is defined as a
# named function instead of a non-`const` global bound to an anonymous function, so
# that the binding is constant and its type is known where it is used.
_matmul_op(t, s) = t*s + t*s

function Base.:*(A::DMatrix, x::AbstractVector)
T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
y = DArray(I -> Array{T}(undef, map(length, I)), (size(A, 1),), procs(A)[:,1], (size(procs(A), 1),))
return A_mul_B!(one(T), A, x, zero(T), y)
return mul!(y, A, x)
end
function Base.:*(A::DMatrix, B::AbstractMatrix)
T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
C = DArray(I -> Array{T}(undef, map(length, I)),
(size(A, 1), size(B, 2)),
procs(A)[:,1:min(size(procs(A), 2), size(procs(B), 2))],
(size(procs(A), 1), min(size(procs(A), 2), size(procs(B), 2))))
return A_mul_B!(one(T), A, B, zero(T), C)
return mul!(C, A, B)
end

# Adjoint matrix times vector: allocate an uninitialised distributed vector of length
# `size(A, 2)` on the first row of `procs(A)` and let `mul!` fill it.
function Base.:*(adjA::Adjoint{<:Any,<:DMatrix}, x::AbstractVector)
    A = parent(adjA)
    T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
    allocate = I -> Array{T}(undef, map(length, I))
    y = DArray(allocate, (size(A, 2),), procs(A)[1, :], (size(procs(A), 2),))
    return mul!(y, adjA, x)
end
# Adjoint matrix times matrix: allocate an uninitialised `size(A, 2)` by `size(B, 2)`
# distributed result and let `mul!` fill it.
function Base.:*(adjA::Adjoint{<:Any,<:DMatrix}, B::AbstractMatrix)
    A = parent(adjA)
    T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
    pidsA = procs(A)
    # Number of process rows of `A` that are used for the result.
    nused = min(size(pidsA, 1), size(procs(B), 2))
    C = DArray(
        I -> Array{T}(undef, map(length, I)),
        (size(A, 2), size(B, 2)),
        pidsA[1:nused, :],
        (size(pidsA, 2), nused),
    )
    return mul!(C, adjA, B)
end

function Ac_mul_B(A::DMatrix, x::AbstractVector)
function Base.:*(trA::Transpose{<:Any,<:DMatrix}, x::AbstractVector)
A = parent(trA)
T = Base.promote_op(_matmul_op, eltype(A), eltype(x))
y = DArray(I -> Array{T}(undef, map(length, I)),
(size(A, 2),),
procs(A)[1,:],
(size(procs(A), 2),))
return Ac_mul_B!(one(T), A, x, zero(T), y)
return mul!(y, trA, x)
end
function Ac_mul_B(A::DMatrix, B::AbstractMatrix)
function Base.:*(trA::Transpose{<:Any,<:DMatrix}, B::AbstractMatrix)
A = parent(trA)
T = Base.promote_op(_matmul_op, eltype(A), eltype(B))
C = DArray(I -> Array{T}(undef, map(length, I)), (size(A, 2),
size(B, 2)),
procs(A)[1:min(size(procs(A), 1), size(procs(B), 2)),:],
(size(procs(A), 2), min(size(procs(A), 1), size(procs(B), 2))))
return Ac_mul_B!(one(T), A, B, zero(T), C)
return mul!(C, trA, B)
end
35 changes: 31 additions & 4 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import SparseArrays: nnz

Base.map(f, d0::DArray, ds::AbstractArray...) = broadcast(f, d0, ds...)

function Base.map!(f::F, dest::DArray, src::DArray) where {F}
function Base.map!(f::F, dest::DArray, src::DArray{<:Any,<:Any,A}) where {F,A}
asyncmap(procs(dest)) do p
remotecall_fetch(p) do
map!(f, localpart(dest), src[localindices(dest)...])
map!(f, localpart(dest), A(view(src, localindices(dest)...)))
return nothing
end
end
Expand Down Expand Up @@ -53,7 +53,7 @@ rewrite_local(x) = x

function Base.reduce(f, d::DArray)
results = asyncmap(procs(d)) do p
remotecall_fetch(p, f, d) do (f, d)
remotecall_fetch(p) do
return reduce(f, localpart(d))
end
end
Expand Down Expand Up @@ -122,12 +122,39 @@ function Base.mapreducedim!(f, op, R::DArray, A::DArray)
end
region = tuple(collect(1:ndims(A))[[size(R)...] .!= [size(A)...]]...)
if isempty(region)
return copy!(R, A)
return copyto!(R, A)
end
B = mapreducedim_within(f, op, A, region)
return mapreducedim_between!(identity, op, R, B, region)
end

# `all` over a whole DArray: evaluate `all(f, localpart(A))` on every process in
# `procs(A)` and combine the per-process answers.
function Base._all(f, A::DArray, ::Colon)
    partial = asyncmap(procs(A)) do pid
        remotecall_fetch(() -> all(f, localpart(A)), pid)
    end
    return all(partial)
end

# `any` over a whole DArray: evaluate `any(f, localpart(A))` on every process in
# `procs(A)` and combine the per-process answers.
function Base._any(f, A::DArray, ::Colon)
    partial = asyncmap(procs(A)) do pid
        remotecall_fetch(() -> any(f, localpart(A)), pid)
    end
    return any(partial)
end

# Count over a DArray: evaluate `count(f, localpart(A))` on every process in
# `procs(A)` and add up the per-process counts.
function Base.count(f, A::DArray)
    partial = asyncmap(procs(A)) do pid
        remotecall_fetch(() -> count(f, localpart(A)), pid)
    end
    return sum(partial)
end

function nnz(A::DArray)
B = asyncmap(A.pids) do p
remotecall_fetch(nnzlocalpart, p, A)
Expand Down
4 changes: 2 additions & 2 deletions src/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ function Serialization.serialize(S::AbstractSerializer, d::DArray{T,N,A}) where
# Only send the ident for participating workers - we expect the DArray to exist in the
# remote registry. DO NOT send the localpart.
destpid = worker_id_from_socket(S.io)
serialize_type(S, typeof(d))
Serialization.serialize_type(S, typeof(d))
if (destpid in d.pids) || (destpid == d.id[1])
serialize(S, (true, d.id)) # (id_only, id)
else
Expand Down Expand Up @@ -64,7 +64,7 @@ function Serialization.serialize(S::AbstractSerializer, s::DestinationSerializer
pid = worker_id_from_socket(S.io)
pididx = findfirst(isequal(pid), s.pids)
@assert pididx !== nothing
serialize_type(S, typeof(s))
Serialization.serialize_type(S, typeof(s))
serialize(S, s.generate(pididx))
end

Expand Down
Loading

0 comments on commit 24a150b

Please sign in to comment.