Skip to content

Commit

Permalink
Merge pull request #27 from asrivast28/StableDistributeFix
Browse files Browse the repository at this point in the history
Fix in stable_distribute function.
  • Loading branch information
patflick committed Jun 5, 2021
2 parents e863f21 + f100016 commit 435e04d
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions include/mxx/distribution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ void stable_distribute(_InIterator begin, _InIterator end, _OutIterator out, con
// calculate where to send elements
// if there are any elements to send
std::vector<size_t> send_counts(comm.size(), 0);
blk_dist part(total_size, comm);
if (local_size > 0) {
blk_dist part(total_size, comm.size(), comm.rank());
int first_p = part.rank_of(prefix);
size_t left_to_send = local_size;
for (; left_to_send > 0 && first_p < comm.size(); ++first_p) {
Expand All @@ -126,7 +126,11 @@ void stable_distribute(_InIterator begin, _InIterator end, _OutIterator out, con
}
std::vector<size_t> recv_counts = mxx::all2all(send_counts, comm);
// TODO: accept iterators in mxx::all2all?
mxx::all2allv(&(*begin), send_counts, &(*out), recv_counts, comm);
using SendType = typename std::iterator_traits<_InIterator>::value_type;
using RecvType = typename std::iterator_traits<_OutIterator>::value_type;
const SendType* send_buf = (local_size > 0) ? &(*begin) : nullptr;
RecvType* recv_buf = (part.local_size() > 0) ? &(*out) : nullptr;
mxx::all2allv(send_buf, send_counts, recv_buf, recv_counts, comm);
}
}

Expand Down

0 comments on commit 435e04d

Please sign in to comment.