From b90b7f9124048dfb0bb45c96df6152f45780cfa7 Mon Sep 17 00:00:00 2001 From: Ankit Srivastava Date: Fri, 27 Mar 2020 17:49:11 -0400 Subject: [PATCH] Fixed an issue with stable_distribute and stable_distribute_inplace when the total size is less than the number of processes. Also added a test that was failing before the fix. --- include/mxx/distribution.hpp | 19 +++++++++++-------- test/test_distribution.cpp | 10 ++++++++++ 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/include/mxx/distribution.hpp b/include/mxx/distribution.hpp index fda9fb5..f24026d 100644 --- a/include/mxx/distribution.hpp +++ b/include/mxx/distribution.hpp @@ -111,15 +111,18 @@ void stable_distribute(_InIterator begin, _InIterator end, _OutIterator out, con size_t prefix = mxx::exscan(local_size, comm); // calculate where to send elements + // if there are any elements to send std::vector send_counts(comm.size(), 0); - blk_dist part(total_size, comm.size(), comm.rank()); - int first_p = part.rank_of(prefix); - size_t left_to_send = local_size; - for (; left_to_send > 0 && first_p < comm.size(); ++first_p) { - size_t nsend = std::min(part.iprefix_size(first_p) - prefix, left_to_send); - send_counts[first_p] = nsend; - left_to_send -= nsend; - prefix += nsend; + if (local_size > 0) { + blk_dist part(total_size, comm.size(), comm.rank()); + int first_p = part.rank_of(prefix); + size_t left_to_send = local_size; + for (; left_to_send > 0 && first_p < comm.size(); ++first_p) { + size_t nsend = std::min(part.iprefix_size(first_p) - prefix, left_to_send); + send_counts[first_p] = nsend; + left_to_send -= nsend; + prefix += nsend; + } } std::vector recv_counts = mxx::all2all(send_counts, comm); // TODO: accept iterators in mxx::all2all? diff --git a/test/test_distribution.cpp b/test/test_distribution.cpp index 051e218..85bf0af 100644 --- a/test/test_distribution.cpp +++ b/test/test_distribution.cpp @@ -143,6 +143,16 @@ TEST(MxxDistribution, DistributeVector) { test_distribute>(size, gen, c); test_stable_distribute>(size, gen, c); + + // create a distribution of total size smaller than + // the total number of processes and zero elements on the last process + size = (c.rank() % 2 == 0) ? 1 : 0; + if (c.is_last()) { + size = 0; + } + // XXX: test_distribute fails with an assertion error in mxx::sort + // test_distribute>(size, gen, c); + test_stable_distribute>(size, gen, c); }