bitonic sort for splitters in samplesort
patflick committed Apr 25, 2016
1 parent 2336c07 commit 09a2ca0
Showing 2 changed files with 43 additions and 100 deletions.
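In outline, this commit replaces samplesort's gather-to-rank-0 splitter selection with a fully distributed scheme built on bitonic sort. The following condensed sketch of the new arbitrary-decomposition path is pieced together from the diff below; T, comp, s and pick_local_samples() are placeholders, and the unqualified bitonic_sort is assumed to resolve as it does inside samplesort.hpp:

    // each process starts from its locally sorted data and picks samples
    std::vector<T> samples = pick_local_samples();     // proportional oversampling
    mxx::distribute_inplace(samples, comm);            // equalize counts across processes
    samples.resize(s);                                 // discard extras: exactly s per process
    bitonic_sort(samples.begin(), samples.end(), comp, comm); // global sort of all p*s samples
    // every process except the last contributes its largest sample
    T my_splitter = samples.back();
    std::vector<size_t> recv_sizes(comm.size(), 1);
    recv_sizes.back() = 0;
    std::vector<T> splitters = mxx::allgatherv(&my_splitter,
            (comm.rank() + 1 < comm.size()) ? 1 : 0, recv_sizes, comm);
    // => the same p-1 sorted splitters on every process, no root bottleneck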
include/mxx/bitonicsort.hpp (2 changes: 1 addition & 1 deletion)
@@ -154,7 +154,7 @@ template <typename _Iterator, typename _Compare>
 void bitonic_sort(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm) {
     size_t np = std::distance(begin, end);
     if (!mxx::all_same(np)) {
-        throw std::runtime_error("bitonic sort is only valid for the same number of elements on each process.");
+        throw std::runtime_error("bitonic sort only works for the same number of elements on each process.");
     }
 
     if (!std::is_sorted(begin, end, comp)) {
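The reworded message spells out bitonic_sort's precondition: every process must pass exactly the same number of elements. A small usage sketch for unequal inputs, mirroring how samplesort.hpp below first equalizes and then trims its sample vectors (it assumes mxx::distribute_inplace and an mxx::allreduce overload taking the mxx::min functor; note that the trim discards elements, which is fine for samples but not for sorting real payload data):

    std::vector<int> data = local_input();    // placeholder; sizes differ per process
    mxx::distribute_inplace(data, comm);      // block-distribute: sizes now differ by at most 1
    size_t common = mxx::allreduce(data.size(), mxx::min<size_t>(), comm);
    data.resize(common);                      // trim everyone to the global minimum
    bitonic_sort(data.begin(), data.end(), std::less<int>(), comm);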
include/mxx/samplesort.hpp (141 changes: 42 additions & 99 deletions)
@@ -28,6 +28,7 @@
#include "collective.hpp"
#include "shift.hpp"
#include "distribution.hpp"
#include "bitonicsort.hpp"

#include <iterator>
#include <algorithm>
@@ -87,7 +88,7 @@ bool is_sorted(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& c

 template <typename _Iterator, typename _Compare>
 std::vector<typename std::iterator_traits<_Iterator>::value_type>
-sample_arbit_decomp(_Iterator begin, _Iterator end, _Compare comp, int s, const mxx::comm& comm, MPI_Datatype mpi_dt) {
+sample_arbit_decomp(_Iterator begin, _Iterator end, _Compare comp, size_t s, const mxx::comm& comm) {
     typedef typename std::iterator_traits<_Iterator>::value_type value_type;
     std::size_t local_size = std::distance(begin, end);

@@ -106,7 +107,10 @@ sample_arbit_decomp(_Iterator begin, _Iterator end, _Compare comp, int s, const
     else
         local_s = std::max<std::size_t>(((local_size*s*p)+total_size-1)/total_size, 1);
 
-    //. init samples
+    size_t n_splitters = mxx::allreduce(local_s, comm);
+    MXX_ASSERT(n_splitters >= p*s);
+
+    // init samples
     std::vector<value_type> local_splitters;
 
     // pick local samples
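The proportional sampling above sizes each process's contribution by its share of the data, local_s = ceil(local_size*s*p/total_size) with a minimum of 1. Since a sum of ceilings is at least the ceiling of the sum, the processes jointly always produce at least p*s samples, which is exactly what the new allreduce/MXX_ASSERT pair checks. A numeric illustration with made-up values:

    // p = 4, s = 3, total_size = 1000  =>  at least p*s = 12 samples needed
    // rank 0: local_size = 400  ->  local_s = ceil(400*3*4/1000) = ceil(4.8) = 5
    // rank 1: local_size = 400  ->  local_s = 5
    // rank 2: local_size = 150  ->  local_s = ceil(1.8) = 2
    // rank 3: local_size =  50  ->  local_s = ceil(0.6) = 1
    // n_splitters = 5+5+2+1 = 13 >= 12, so the assertion holds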
@@ -123,53 +127,36 @@ sample_arbit_decomp(_Iterator begin, _Iterator end, _Compare comp, int s, const
         }
     }
 
-    // 2. gather samples to `rank = 0`
-    // - TODO: rather call sample sort
-    //   recursively and implement a base case for samplesort which does
-    //   gather to rank=0, local sort and redistribute
-    std::vector<value_type> all_samples = mxx::gatherv(local_splitters, 0, comm);
-
-    // sort and pick p-1 samples on master
-    if (comm.rank() == 0) {
-        // 3. local sort on master
-        std::sort(all_samples.begin(), all_samples.end(), comp);
+    // distribute elements equally
+    mxx::distribute_inplace(local_splitters, comm);
 
-        // 4. pick p-1 splitters and broadcast them
-        if (local_splitters.size() != (size_t) p-1)
-        {
-            local_splitters.resize(p-1);
-        }
-        // split into `p` pieces and choose the `p-1` splitting elements
-        _Iterator pos = all_samples.begin();
-        for (std::size_t i = 0; i < local_splitters.size(); ++i)
-        {
-            std::size_t bucket_size = (p*s) / p + (i < static_cast<std::size_t>((p*s) % p) ? 1 : 0);
-            // pick last element of each bucket
-            local_splitters[i] = *(pos + (bucket_size-1));
-            pos += bucket_size;
-        }
+    // discard extra splitters, to make it even
+    if (local_splitters.size() != s) {
+        MXX_ASSERT(local_splitters.size() > s);
+        local_splitters.resize(s);
     }
-
-    // size splitters for receiving
-    if (local_splitters.size() != (size_t)p-1) {
-        local_splitters.resize(p-1);
-    }
+    // sort splitters using parallel bitonic sort
+    bitonic_sort(local_splitters.begin(), local_splitters.end(), comp, comm);
 
-    // 4. broadcast and receive final splitters
-    MPI_Bcast(&local_splitters[0], local_splitters.size(), mpi_dt, 0, comm);
+    // select the last element on each process but the last
+    value_type my_splitter = local_splitters.back();
+    // allgather splitters (from all but the last processor)
+    std::vector<size_t> recv_sizes(comm.size(), 1);
+    recv_sizes.back() = 0;
+    std::vector<value_type> result_splitters = mxx::allgatherv(&my_splitter, (comm.rank() + 1 < comm.size()) ? 1 : 0, recv_sizes, comm);
 
-    return local_splitters;
+    // return resulting splitters
+    return result_splitters;
 }
 
 
 template <typename _Iterator, typename _Compare>
 std::vector<typename std::iterator_traits<_Iterator>::value_type>
-sample_block_decomp(_Iterator begin, _Iterator end, _Compare comp, int s, const mxx::comm& comm, MPI_Datatype mpi_dt)
-{
+sample_block_decomp(_Iterator begin, _Iterator end, _Compare comp, size_t s, const mxx::comm& comm) {
     typedef typename std::iterator_traits<_Iterator>::value_type value_type;
     std::size_t local_size = std::distance(begin, end);
     MXX_ASSERT(local_size > 0);
     int p = comm.size();
 
     // 1. samples
     // - pick `s` samples equally spaced such that `s` samples define `s+1`
@@ -184,47 +171,18 @@ sample_block_decomp(_Iterator begin, _Iterator end, _Compare comp, int s, const
         ++pos;
     }
 
-    // 2. gather samples to `rank = 0`
-    // - TODO: rather call sample sort
-    //   recursively and implement a base case for samplesort which does
-    //   gather to rank=0, local sort and redistribute
-    if (comm.rank() == 0) {
-        std::vector<value_type> all_samples(p*s);
-        MPI_Gather(&local_splitters[0], s, mpi_dt,
-                   &all_samples[0], s, mpi_dt, 0, comm);
-
-        // 3. local sort on master
-        std::sort(all_samples.begin(), all_samples.end(), comp);
-
-        // 4. pick p-1 splitters and broadcast them
-        if (local_splitters.size() != (size_t) p-1) {
-            local_splitters.resize(p-1);
-        }
-        // split into `p` pieces and choose the `p-1` splitting elements
-        _Iterator pos = all_samples.begin();
-        for (std::size_t i = 0; i < local_splitters.size(); ++i)
-        {
-            std::size_t bucket_size = (p*s) / p + (i < static_cast<std::size_t>((p*s) % p) ? 1 : 0);
-            // pick last element of each bucket
-            local_splitters[i] = *(pos + (bucket_size-1));
-            pos += bucket_size;
-        }
-    }
-    else
-    {
-        // simply send
-        MPI_Gather(&local_splitters[0], s, mpi_dt, NULL, 0, mpi_dt, 0, comm);
-
-        // resize splitters for receiving
-        if (local_splitters.size() != (size_t) p-1) {
-            local_splitters.resize(p-1);
-        }
-    }
+    // sort splitters using parallel bitonic sort
+    bitonic_sort(local_splitters.begin(), local_splitters.end(), comp, comm);
 
-    // 4. broadcast and receive final splitters
-    MPI_Bcast(&local_splitters[0], local_splitters.size(), mpi_dt, 0, comm);
+    // select the last element on each process but the last
+    value_type my_splitter = local_splitters.back();
+    // allgather splitters (from all but the last processor)
+    std::vector<size_t> recv_sizes(comm.size(), 1);
+    recv_sizes.back() = 0;
+    std::vector<value_type> result_splitters = mxx::allgatherv(&my_splitter, (comm.rank() + 1 < comm.size()) ? 1 : 0, recv_sizes, comm);
 
-    return local_splitters;
+    // return resulting splitters
+    return result_splitters;
 }
 
 template <typename _Iterator, typename _Compare>
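Both sampling paths now finish identically: after the bitonic sort, each rank holds exactly s samples and the concatenation across ranks is globally sorted, so the last sample on rank i is the (i+1)*s-th smallest overall. With s = p-1 these are precisely the evenly spaced order statistics classic sample sort wants, and the allgatherv with one element per rank (zero from the last) hands every process the same p-1 splitters. A worked illustration with hypothetical values:

    // p = 4, s = 3: twelve globally sorted samples spread across the ranks
    //   rank 0: x1 x2 x3 | rank 1: x4 x5 x6 | rank 2: x7 x8 x9 | rank 3: x10 x11 x12
    // recv_sizes = {1, 1, 1, 0}: ranks 0..2 contribute x3, x6, x9; rank 3 sends
    // nothing, since x12 is the global maximum and could never split off a bucket
    // => every rank receives the splitters {x3, x6, x9}, already sorted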
@@ -329,7 +287,7 @@ std::vector<size_t> stable_split(_Iterator begin, _Iterator end, _Compare comp,


 template<typename _Iterator, typename _Compare, bool _Stable = false>
-void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_dt, const mxx::comm& comm) {
+void samplesort(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm) {
     // get value type of underlying data
     typedef typename std::iterator_traits<_Iterator>::value_type value_type;

@@ -353,8 +311,6 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
SS_TIMER_END_SECTION("local_sort");


// number of samples
int s = p-1;
// local size
std::size_t local_size = std::distance(begin, end);

@@ -364,10 +320,10 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
     bool _AssumeBlockDecomp = mxx::all_of(local_size == mypart.local_size(), comm);
 
     // sample sort
-    // 1. pick `s` samples on each processor
-    // 2. gather to `rank=0`
-    // 3. local sort on master
-    // 4. broadcast the p-1 final splitters
+    // 1. local sort
+    // 2. pick `s` samples regularly spaced on each processor
+    // 3. bitonic sort samples
+    // 4. allgather the last sample of each process -> splitters
     // 5. locally find splitter positions in data
     //    (if an identical splitter appears twice, then split evenly)
     //    => send_counts
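Step 5 is where the splitters are consumed: each process locates the p-1 splitters in its locally sorted range to derive how many elements go to each destination. A simplified sketch of that step (illustrative only, not the library's implementation, which per the comment above also splits runs of identical splitters evenly across destinations):

    std::vector<size_t> send_counts(p);
    auto prev = begin;
    for (int i = 0; i + 1 < p; ++i) {
        // first position where element > splitter i: end of bucket i
        auto next = std::upper_bound(prev, end, local_splitters[i], comp);
        send_counts[i] = std::distance(prev, next);
        prev = next;
    }
    send_counts[p - 1] = std::distance(prev, end);  // remainder goes to the last bucket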
@@ -381,10 +337,12 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
     // get splitters, using the method depending on whether the input consists
     // of arbitrary decompositions or not
     std::vector<value_type> local_splitters;
+    // number of samples
+    size_t s = p-1;
     if(_AssumeBlockDecomp)
-        local_splitters = sample_block_decomp(begin, end, comp, s, comm, mpi_dt);
+        local_splitters = sample_block_decomp(begin, end, comp, s, comm);
     else
-        local_splitters = sample_arbit_decomp(begin, end, comp, s, comm, mpi_dt);
+        local_splitters = sample_arbit_decomp(begin, end, comp, s, comm);
     SS_TIMER_END_SECTION("get_splitters");
 
     // 5. locally find splitter positions in data
@@ -400,7 +358,6 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_

     std::vector<size_t> recv_counts = mxx::all2all(send_counts, comm);
     std::size_t recv_n = std::accumulate(recv_counts.begin(), recv_counts.end(), static_cast<size_t>(0));
-    // TODO: use different approach if there are less than p local elements
     MXX_ASSERT(!_AssumeBlockDecomp || (local_size <= (size_t)p || recv_n <= 2* local_size));
     std::vector<value_type> recv_elements(recv_n);
     // TODO: use collective with iterators [begin,end) instead of pointers!
@@ -482,20 +439,6 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
SS_TIMER_END_SECTION("fix_partition");
}

template<typename _Iterator, typename _Compare, bool _Stable = false>
void samplesort(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm)
{
// get value type of underlying data
typedef typename std::iterator_traits<_Iterator>::value_type value_type;

// get MPI type
mxx::datatype dt = mxx::get_datatype<value_type>();
MPI_Datatype mpi_dt = dt.type();

// sort
impl::samplesort<_Iterator, _Compare, _Stable>(begin, end, comp, mpi_dt, comm);
}

} // namespace impl
} // namespace mxx

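The deleted overload existed only to look up the MPI datatype before forwarding to the implementation. Since the mxx collectives used above (distribute_inplace, allgatherv, all2all) derive the datatype internally from the iterator's value type, that hop is unnecessary and the MPI_Datatype parameter disappears from the whole call chain. A hypothetical call site, with It, Cmp, v and cmp as placeholders:

    // before this commit:
    //   mxx::datatype dt = mxx::get_datatype<typename std::iterator_traits<It>::value_type>();
    //   impl::samplesort<It, Cmp, false>(v.begin(), v.end(), cmp, dt.type(), comm);
    // after it, the datatype is derived inside the collectives:
    impl::samplesort<It, Cmp, false>(v.begin(), v.end(), cmp, comm);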