Skip to content

Commit

Permalink
stream sync_cout interface + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
patflick committed Oct 7, 2015
1 parent 7fd0b7d commit e7922d7
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 30 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ development, prototyping, and deployment.
- Google Test based `MPI` unit testing framework
- Parallel sorting with an API similar to `std::sort` (`mxx::sort`)

### Planned
### Planned / TODO

- Parallel random number engines (for use with `C++11` standard library distributions)
- More parallel (standard) algorithms
Expand All @@ -44,6 +44,8 @@ development, prototyping, and deployment.
- Implementing and tuning different sorting algorithms
- Communicator classes for different topologies
- `mxx::env` similar to `boost::mpi::env` for wrapping `MPI_Init` and `MPI_Finalize`
- Increase test coverage:
![codecov.io](https://codecov.io/github/patflick/mxx/branch.svg?branch=master)

### Status

Expand Down
1 change: 0 additions & 1 deletion include/mxx/big_collective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

#include <vector>

#include "comm.hpp"
#include "datatypes.hpp"
#include "shift.hpp" // FIXME: include only `requests`

Expand Down
4 changes: 3 additions & 1 deletion include/mxx/collective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,8 +752,10 @@ void gatherv(const T* data, size_t size, T* out, const std::vector<size_t>& recv
*/
template <typename T>
std::vector<T> gatherv(const T* data, size_t size, const std::vector<size_t>& recv_sizes, int root, const mxx::comm& comm = mxx::comm()) {
    // Only the root rank receives data, so only the root allocates receive
    // space; on every other rank the result stays empty.
    std::vector<T> result;
    if (comm.rank() == root) {
        // Sum with a `size_t` initial value: with a plain `0` std::accumulate
        // would add up in `int` and overflow for totals beyond INT_MAX.
        const size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
        result.resize(total_size);
    }
    // `data()` is well-defined on an empty vector, unlike `&result[0]`.
    gatherv(data, size, result.data(), recv_sizes, root, comm);
    return result;
}
Expand Down
12 changes: 12 additions & 0 deletions include/mxx/comm_def.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
#include "shift.hpp"
#include "collective.hpp"
#include "reduction.hpp"
#include "stream.hpp"

#ifdef MXX_STREAM_DONE
#ifdef MXX_SHIFT_DONE
#ifdef MXX_COLLECTIVE_DONE
#ifdef MXX_REDUCTION_DONE
Expand All @@ -46,6 +49,14 @@ inline void comm::with_subset(bool cond, Func f) const {
}
}

// Returns a synchronized output stream for this communicator by forwarding
// to the free function `mxx::sync_cout` with its default root rank.
// The returned stream refers to this communicator, and its destructor
// performs a collective operation on it.
inline mxx::sync_ostream comm::sync_cout() const {
return mxx::sync_cout(*this);
}

// Returns a synchronized error stream for this communicator by forwarding
// to the free function `mxx::sync_cerr` with its default root rank.
// The returned stream refers to this communicator, and its destructor
// performs a collective operation on it.
inline mxx::sync_ostream comm::sync_cerr() const {
return mxx::sync_cerr(*this);
}

inline comm comm::split_shared() const {
comm o;
#if MPI_VERSION >= 3
Expand Down Expand Up @@ -110,3 +121,4 @@ inline comm comm::split_shared() const {
#endif
#endif
#endif
#endif
12 changes: 12 additions & 0 deletions include/mxx/comm_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ namespace mxx {
static constexpr int any_tag = MPI_ANY_TAG;
static constexpr int any_source = MPI_ANY_SOURCE;


// forward declaration for streams
template <typename CharT, class Traits = std::char_traits<CharT> >
class sync_basic_ostream;

class comm {
public:
/// Default constructor defaults to COMM_WORLD
Expand Down Expand Up @@ -162,6 +167,13 @@ class comm {
}
}

// returns synchronized stream object for this communicator.
// The stream object's destructor contains a collective operation
// for synchronized cout/cerr output
sync_basic_ostream<char> sync_cout() const;
sync_basic_ostream<char> sync_cerr() const;


/**
* @brief Returns a new communicator which is the reverse of this.
*
Expand Down
29 changes: 12 additions & 17 deletions include/mxx/samplesort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,44 +22,38 @@
#ifndef MXX_SAMPLESORT_HPP
#define MXX_SAMPLESORT_HPP

#include <mpi.h>

#include <assert.h>
#include "comm.hpp"
#include "partition.hpp"
#include "datatypes.hpp"
#include "collective.hpp"
#include "shift.hpp"
#include "distribution.hpp"

#include <iterator>
#include <algorithm>
#include <vector>
#include <limits>

#ifdef __GNUC__
#ifndef __clang__
// for multiway-merge
// TODO: impelement own in case it is not GNU C++
// TODO: implement multiway merging ourselves
#include <parallel/multiway_merge.h>
#include <parallel/merge.h>
#define MXX_USE_GCC_MULTIWAY_MERGE
#endif
#endif

#include "partition.hpp"
#include "datatypes.hpp"
#include "collective.hpp"
#include "shift.hpp"
#include "distribution.hpp"
#include "timer.hpp"


#define SS_ENABLE_TIMER 0
#if SS_ENABLE_TIMER
#include "timer.hpp"
#define SS_TIMER_START(comm) mxx::section_timer timer(std::cerr, comm, 0);
#define SS_TIMER_END_SECTION(str) timer.end_section(str);
#else
#define SS_TIMER_START(comm)
#define SS_TIMER_END_SECTION(str)
#endif

#define MEASURE_LOAD_BALANCE 0

namespace mxx {
namespace impl {

Expand Down Expand Up @@ -349,6 +343,7 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
else
std::sort(begin, end, comp);

// sequential case: we're done
if (p == 1)
return;

Expand Down Expand Up @@ -378,10 +373,10 @@ void samplesort(_Iterator begin, _Iterator end, _Compare comp, MPI_Datatype mpi_
// => send_counts
// 6. distribute send_counts with all2all to get recv_counts
// 7. allocate enough space (may be more than previously allocated) for receiving
// 8. all2all
// 9. local reordering
// 8. all2allv
// 9. local reordering (multiway-merge or again std::sort)
// A. equalizing distribution into original size (e.g.,block decomposition)
// by elements to neighbors
// by sending elements to neighbors

// get splitters, using the method depending on whether the input consists
// of arbitrary decompositions or not
Expand Down
45 changes: 35 additions & 10 deletions include/mxx/sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
* @brief Implements the interface for parallel sorting.
*
* TODO:
* - [ ] fix stable sort
* - [ ] fix sort on non GCC compilers
* - [x] fix stable sort
* - [x] fix sort on non GCC compilers
* - [ ] radix sort
* - [ ] fix sorting of samples
* - [ ] implement and try out different parallel sorting algorithms
Expand All @@ -41,23 +41,40 @@ void sort(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm =
impl::samplesort<_Iterator, _Compare, false>(begin, end, comp, comm);
}

// Parallel sort of the distributed range [begin, end) over `comm`, ordering
// elements with `std::less` of the iterator's value type. Uses the non-stable
// variant of the sample sort.
template <typename _Iterator>
void sort(_Iterator begin, _Iterator end, const mxx::comm& comm = mxx::comm()) {
    using value_type = typename std::iterator_traits<_Iterator>::value_type;
    using less_type = std::less<value_type>;
    impl::samplesort<_Iterator, less_type, false>(begin, end, less_type(), comm);
}

// Parallel stable sort of the distributed range [begin, end) over `comm`,
// ordering elements with the user supplied comparator `comp`.
template<typename _Iterator, typename _Compare>
void stable_sort(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm = mxx::comm()) {
    // select the stable flavour of the sample sort
    constexpr bool use_stable = true;
    impl::samplesort<_Iterator, _Compare, use_stable>(begin, end, comp, comm);
}

// Parallel stable sort of the distributed range [begin, end) over `comm`,
// ordering elements with `std::less` of the iterator's value type.
template <typename _Iterator>
void stable_sort(_Iterator begin, _Iterator end, const mxx::comm& comm = mxx::comm()) {
    using value_type = typename std::iterator_traits<_Iterator>::value_type;
    using less_type = std::less<value_type>;
    // `true` selects the stable flavour of the sample sort
    impl::samplesort<_Iterator, less_type, true>(begin, end, less_type(), comm);
}

// Checks whether the distributed range [begin, end) is sorted with respect
// to `comp` across the communicator `comm`; the work is delegated to
// `impl::is_sorted`. Takes an `mxx::comm` like the other functions in this
// header.
template<typename _Iterator, typename _Compare>
bool is_sorted(_Iterator begin, _Iterator end, _Compare comp, const mxx::comm& comm = mxx::comm()) {
    return impl::is_sorted(begin, end, comp, comm);
}

// Checks whether the distributed range [begin, end) is sorted across `comm`
// with respect to `std::less` of the iterator's value type.
template<typename _Iterator>
bool is_sorted(_Iterator begin, _Iterator end, const mxx::comm& comm = mxx::comm()) {
    using value_type = typename std::iterator_traits<_Iterator>::value_type;
    return impl::is_sorted(begin, end, std::less<value_type>(), comm);
}

// assumes input is sorted, removes duplicates in global range
template <typename ForwardIt, typename BinaryPredicate>
ForwardIt unique(ForwardIt begin, ForwardIt end, BinaryPredicate eq, const mxx::comm& comm) {
typedef typename std::iterator_traits<ForwardIt>::value_type T;
ForwardIt dest = begin;
template <typename Iterator, typename BinaryPredicate>
Iterator unique(Iterator begin, Iterator end, BinaryPredicate eq, const mxx::comm& comm = mxx::comm()) {
typedef typename std::iterator_traits<Iterator>::value_type T;
Iterator dest = begin;
mxx::comm c = comm.split(begin != end);
comm.with_subset(begin != end, [&](const mxx::comm& c) {
size_t n = std::distance(begin, end);
Expand All @@ -67,8 +84,11 @@ ForwardIt unique(ForwardIt begin, ForwardIt end, BinaryPredicate eq, const mxx::
T prev = mxx::right_shift(last, c);

// skip elements which are equal to the last one on the previous processor
while (eq(prev, *begin))
++begin;
if (c.rank() > 0)
while (begin != end && eq(prev, *begin))
++begin;
if (begin == end)
return;
*dest = *begin;

// remove duplicates
Expand All @@ -80,6 +100,11 @@ ForwardIt unique(ForwardIt begin, ForwardIt end, BinaryPredicate eq, const mxx::
return dest;
}

// Convenience overload of `unique` that compares elements with
// `std::equal_to` of the iterator's value type.
template <typename Iterator>
Iterator unique(Iterator begin, Iterator end, const mxx::comm& comm = mxx::comm()) {
    typedef typename std::iterator_traits<Iterator>::value_type value_type;
    std::equal_to<value_type> equal;
    return unique(begin, end, equal, comm);
}

#include "comm_def.hpp"

} // namespace mxx
Expand Down
110 changes: 110 additions & 0 deletions include/mxx/stream.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright 2015 Georgia Institute of Technology
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


#include "comm_fwd.hpp"
#include "collective.hpp"
#ifdef MXX_COLLECTIVE_DONE


#ifndef MXX_STREAM_HPP
#define MXX_STREAM_HPP

#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

namespace mxx {


// Output stream that buffers everything written to it locally on each
// process. When `sync_flush()` is called (the destructor calls it as well),
// the buffered data of all processes is collectively sent to rank `root`
// and appended there, rank by rank, to the wrapped stream object.
//
// The communicator is held by reference and therefore has to outlive the
// stream. Because flushing is a collective operation, all processes of the
// communicator have to flush (or destroy) their stream together.
template <typename CharT, class Traits>
class sync_basic_ostream : public std::basic_ostream<CharT, Traits> {
public:
    typedef std::basic_ostream<CharT, Traits> base_stream;
protected:
    typedef std::basic_stringbuf<CharT, Traits, std::allocator<CharT>> strbuf_t;
    typedef std::basic_string<CharT, Traits> string_t;

    /// communicator used for gathering the buffered output
    const mxx::comm& m_comm;
    /// stream receiving the output on `root`, `nullptr` on all other ranks
    base_stream* m_stream;
    /// rank which writes the gathered output
    int m_root;
    /// local buffer collecting everything written to this stream
    std::unique_ptr<strbuf_t> m_buf;
public:

    // Constructor for the root rank: `stream` receives the gathered output.
    // The base class is explicitly constructed without a buffer, since
    // `std::basic_ostream` has no standard default constructor; the local
    // buffer is attached once it exists (which also clears the error state).
    sync_basic_ostream(const mxx::comm& comm, int root, base_stream& stream)
        : base_stream(nullptr), m_comm(comm), m_stream(&stream), m_root(root), m_buf(new strbuf_t()) {
        MXX_ASSERT(0 <= root && root < comm.size());
        this->rdbuf(m_buf.get());
    }

    // Constructor for all non-root ranks, which only buffer and send.
    sync_basic_ostream(const mxx::comm& comm, int root)
        : base_stream(nullptr), m_comm(comm), m_stream(nullptr), m_root(root), m_buf(new strbuf_t()) {
        MXX_ASSERT(0 <= root && root < comm.size());
        MXX_ASSERT(root != comm.rank()); // the root node can't have a null stream
        this->rdbuf(m_buf.get());
    }

    // Move constructor: takes over the buffer (and its content) of `o`.
    sync_basic_ostream(sync_basic_ostream&& o)
        : base_stream(nullptr), m_comm(o.m_comm), m_stream(o.m_stream), m_root(o.m_root),
          m_buf(std::move(o.m_buf)) {
        // Detach the moved-from stream from the buffer it no longer owns.
        // Setting a null buffer leaves `o` in the `badbit` state.
        o.rdbuf(nullptr);
        this->rdbuf(m_buf.get());
    }

    sync_basic_ostream(const sync_basic_ostream& o) = delete;

    // Collectively sends the locally buffered data to rank `root`, writes
    // it there to the wrapped stream and clears the local buffer.
    void sync_flush() {
        // a moved-from stream owns no buffer and has nothing to synchronize
        if (!m_buf)
            return;
        // communicate all data to rank `root`
        string_t str = m_buf->str();
        std::vector<size_t> recv_counts = mxx::gather(str.size(), m_root, m_comm);
        std::vector<CharT> strings = mxx::gatherv(str.data(), str.size(), recv_counts, m_root, m_comm);
        // on `root`: output all strings
        if (m_comm.rank() == m_root) {
            typename std::vector<CharT>::iterator begin = strings.begin();
            for (int i = 0; i < m_comm.size(); ++i) {
                string_t recv_string(begin, begin + recv_counts[i]);
                *m_stream << recv_string;
                begin += recv_counts[i];
            }
        }
        // clear buffer content (an empty string of the stream's own
        // character type, so this also works for `CharT` other than `char`)
        m_buf->str(string_t());
    }

    // sync upon destruction
    virtual ~sync_basic_ostream() {
        sync_flush();
    }
};

using sync_ostream = sync_basic_ostream<char>;

// Creates a synchronized stream for `comm` whose gathered output is written
// to `std::cout` on rank `root`. All other ranks get a stream without an
// attached output stream.
inline sync_ostream sync_cout(const mxx::comm& comm, int root = 0) {
    if (comm.rank() == root)
        return sync_ostream(comm, root, std::cout);
    return sync_ostream(comm, root);
}

// Creates a synchronized stream for `comm` whose gathered output is written
// to `std::cerr` on rank `root`. All other ranks get a stream without an
// attached output stream.
inline sync_ostream sync_cerr(const mxx::comm& comm, int root = 0) {
    if (comm.rank() == root)
        return sync_ostream(comm, root, std::cerr);
    return sync_ostream(comm, root);
}

} // namespace mxx


#define MXX_STREAM_DONE
#include "comm_def.hpp"

#endif // MXX_STREAM_HPP
#endif
20 changes: 20 additions & 0 deletions test/test_sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <gtest/gtest.h>
#include <mxx/sort.hpp>
#include <mxx/shift.hpp>
#include <mxx/stream.hpp>

#include <vector>
#include <iostream>
Expand Down Expand Up @@ -114,3 +115,22 @@ TEST(MxxSort, StableSort) {
};
ASSERT_TRUE(mxx::is_sorted(vec.begin(), vec.end(), full_cmp, c));
}

// Every rank contributes ten copies of each value 0..9. After a global sort
// and a global `unique`, the elements remaining across all ranks are expected
// to be exactly 0..9 in ascending order.
TEST(MxxSort, Unique) {
    mxx::comm c;
    std::srand(13*c.rank());

    // fill the local vector with the repeating pattern 0, 1, ..., 9
    std::vector<int> vec(100);
    int counter = 0;
    std::generate(vec.begin(), vec.end(), [&counter]() {
        const int value = counter % 10;
        ++counter;
        return value;
    });

    mxx::sort(vec.begin(), vec.end());
    std::vector<int>::iterator newend = mxx::unique(vec.begin(), vec.end());
    std::vector<int> unique_els(vec.begin(), newend);

    mxx::sync_cout(c) << "[Rank " << c.rank() << "]: Found " << unique_els.size() << " unique elements: " << unique_els << std::endl;

    // collect the surviving elements of all ranks and compare against 0..9
    std::vector<int> all = mxx::allgatherv(unique_els);
    ASSERT_EQ(10ul, all.size());
    for (size_t pos = 0; pos < all.size(); ++pos) {
        ASSERT_EQ(static_cast<int>(pos), all[pos]);
    }
}

0 comments on commit e7922d7

Please sign in to comment.