Skip to content

Commit

Permalink
Added a way to specify if the MPI operation used in global scans is c…
Browse files Browse the repository at this point in the history
…ommutative.
  • Loading branch information
asrivast28 committed Jan 6, 2020
1 parent 14b4f0d commit 14408b4
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions include/mxx/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,23 +174,23 @@ struct get_builtin_op<T, const T&(*) (const T&, const T&)> {
* @note This assumes that the operator is commutative.
*
* @tparam T The input and ouput datatype of the binary operator.
* @tparam IsCommutative Whether or not the operation is commutative (default = true).
*/
template <typename T, bool IsCommutative = true>
template <typename T>
class custom_op {
public:

/**
* @brief Creates a custom operator given a functor and the associated
* `MPI_Datatype`.
*
* @tparam Func Type of the functor, can be a function pointer, lambda
* function, or std::function or any object with a
* `T operator(T& x, T& y)` member.
* @param func The instance of the functor.
* @tparam Func Type of the functor, can be a function pointer, lambda
* function, or std::function or any object with a
* `T operator(T& x, T& y)` member.
* @param func The instance of the functor.
* @param commutative Whether or not the operation is commutative (default = true).
*/
template <typename Func>
custom_op(Func func) : m_builtin(false) {
custom_op(Func func, const bool commutative = true) : m_builtin(false) {
if (mxx::is_builtin_type<T>::value) {
// check if the operator is MPI built-in (in case the type
// is also a MPI built-in type)
Expand All @@ -215,7 +215,7 @@ class custom_op {
MPI_Type_dup(dt.type(), &m_type_copy);
attr_map<int, func_t>::set(m_type_copy, 1347, m_user_func);
// create op
MPI_Op_create(&custom_op::mpi_user_function, IsCommutative, &m_op);
MPI_Op_create(&custom_op::mpi_user_function, commutative, &m_op);
}
}

Expand Down Expand Up @@ -731,7 +731,7 @@ inline std::vector<T> local_scan(const std::vector<T>& in) {
// global scans

template <typename InIterator, typename OutIterator, typename Func>
void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm()) {
void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
OutIterator o = out;
size_t n = std::distance(begin, end);
// create subcommunicator for those processes which contain elements
Expand All @@ -744,7 +744,7 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c
T sum = T();
if (n > 0)
sum = *(out+(n-1));
T presum = exscan(sum, func, nonzero_comm);
T presum = exscan(sum, func, nonzero_comm, commutative);
// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
*o = func(presum, *o);
Expand All @@ -756,7 +756,7 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c

// inplace!
template <typename Iterator, typename Func>
inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm()) {
inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
Iterator o = begin;
size_t n = std::distance(begin, end);
mxx::comm nonzero_comm = comm.split(n > 0);
Expand All @@ -766,7 +766,7 @@ inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const m
// mxx::exscan
typedef typename std::iterator_traits<Iterator>::value_type T;
T sum = *(begin + (n-1));
T presum = exscan(sum, func, nonzero_comm);
T presum = exscan(sum, func, nonzero_comm, commutative);

// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
Expand All @@ -788,8 +788,8 @@ inline void global_scan_inplace(Iterator begin, Iterator end, const mxx::comm& c

// std::vector overloads
template <typename T, typename Func>
inline void global_scan_inplace(std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm()) {
global_scan_inplace(in.begin(), in.end(), func, comm);
inline void global_scan_inplace(std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
global_scan_inplace(in.begin(), in.end(), func, comm, commutative);
}

template <typename T>
Expand All @@ -798,9 +798,9 @@ inline void global_scan_inplace(std::vector<T>& in, const mxx::comm& comm = mxx:
}

template <typename T, typename Func>
inline std::vector<T> global_scan(const std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm()) {
inline std::vector<T> global_scan(const std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
std::vector<T> result(in.size());
global_scan(in.begin(), in.end(), result.begin(), func, comm);
global_scan(in.begin(), in.end(), result.begin(), func, comm, commutative);
return result;
}

Expand Down Expand Up @@ -838,9 +838,9 @@ inline std::vector<T> exscan_vec(const std::vector<T>& x, Func func, const mxx::
// single element

template <typename T, typename Func>
T exscan(const T& x, Func func, const mxx::comm& comm = mxx::comm()) {
T exscan(const T& x, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
// get op
mxx::custom_op<T> op(std::forward<Func>(func));
mxx::custom_op<T> op(std::forward<Func>(func), commutative);
// perform reduction
T result;
MPI_Exscan(const_cast<T*>(&x), &result, 1, op.get_type(), op.get_op(), comm);
Expand Down

0 comments on commit 14408b4

Please sign in to comment.