Skip to content

Commit

Permalink
Fixed an issue with stable_distribute and stable_distribute_inplace
Browse files Browse the repository at this point in the history
when the total size is less than the number of processes.
Also added a test that was failing before the fix.
  • Loading branch information
asrivast28 committed Mar 28, 2020
1 parent 1c0ab9f commit b90b7f9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
19 changes: 11 additions & 8 deletions include/mxx/distribution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,18 @@ void stable_distribute(_InIterator begin, _InIterator end, _OutIterator out, con
size_t prefix = mxx::exscan(local_size, comm);

// calculate where to send elements
// if there are any elements to send
std::vector<size_t> send_counts(comm.size(), 0);
blk_dist part(total_size, comm.size(), comm.rank());
int first_p = part.rank_of(prefix);
size_t left_to_send = local_size;
for (; left_to_send > 0 && first_p < comm.size(); ++first_p) {
size_t nsend = std::min<size_t>(part.iprefix_size(first_p) - prefix, left_to_send);
send_counts[first_p] = nsend;
left_to_send -= nsend;
prefix += nsend;
if (local_size > 0) {
blk_dist part(total_size, comm.size(), comm.rank());
int first_p = part.rank_of(prefix);
size_t left_to_send = local_size;
for (; left_to_send > 0 && first_p < comm.size(); ++first_p) {
size_t nsend = std::min<size_t>(part.iprefix_size(first_p) - prefix, left_to_send);
send_counts[first_p] = nsend;
left_to_send -= nsend;
prefix += nsend;
}
}
std::vector<size_t> recv_counts = mxx::all2all(send_counts, comm);
// TODO: accept iterators in mxx::all2all?
Expand Down
10 changes: 10 additions & 0 deletions test/test_distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ TEST(MxxDistribution, DistributeVector) {

test_distribute<std::vector<int>>(size, gen, c);
test_stable_distribute<std::vector<int>>(size, gen, c);

// create a distribution of total size smaller than
// the total number of processes and zero elements on the last process
size = (c.rank() % 2 == 0) ? 1 : 0;
if (c.is_last()) {
size = 0;
}
// XXX: test_distribute fails with an assertion error in mxx::sort
// test_distribute<std::vector<int>>(size, gen, c);
test_stable_distribute<std::vector<int>>(size, gen, c);
}


Expand Down

0 comments on commit b90b7f9

Please sign in to comment.