diff --git a/include/mxx/reduction.hpp b/include/mxx/reduction.hpp index 24bac16..7fb7bb5 100644 --- a/include/mxx/reduction.hpp +++ b/include/mxx/reduction.hpp @@ -174,9 +174,8 @@ struct get_builtin_op { * @note This assumes that the operator is commutative. * * @tparam T The input and ouput datatype of the binary operator. - * @tparam IsCommutative Whether or not the operation is commutative (default = true). */ -template +template class custom_op { public: @@ -184,13 +183,14 @@ class custom_op { * @brief Creates a custom operator given a functor and the associated * `MPI_Datatype`. * - * @tparam Func Type of the functor, can be a function pointer, lambda - * function, or std::function or any object with a - * `T operator(T& x, T& y)` member. - * @param func The instance of the functor. + * @tparam Func Type of the functor, can be a function pointer, lambda + * function, or std::function or any object with a + * `T operator(T& x, T& y)` member. + * @param func The instance of the functor. + * @param commutative Whether or not the operation is commutative (default = true). */ template - custom_op(Func func) : m_builtin(false) { + custom_op(Func func, const bool commutative = true) : m_builtin(false) { if (mxx::is_builtin_type::value) { // check if the operator is MPI built-in (in case the type // is also a MPI built-in type) @@ -215,7 +215,7 @@ class custom_op { MPI_Type_dup(dt.type(), &m_type_copy); attr_map::set(m_type_copy, 1347, m_user_func); // create op - MPI_Op_create(&custom_op::mpi_user_function, IsCommutative, &m_op); + MPI_Op_create(&custom_op::mpi_user_function, commutative, &m_op); } } @@ -731,7 +731,7 @@ inline std::vector local_scan(const std::vector& in) { // global scans template -void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm()) { +void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { OutIterator o = out; size_t n = std::distance(begin, end); // create subcommunicator for those processes which contain elements @@ -744,7 +744,7 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c T sum = T(); if (n > 0) sum = *(out+(n-1)); - T presum = exscan(sum, func, nonzero_comm); + T presum = exscan(sum, func, nonzero_comm, commutative); // accumulate previous sum on all local elements for (size_t i = 0; i < n; ++i) { *o = func(presum, *o); @@ -756,7 +756,7 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c // inplace! template -inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm()) { +inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { Iterator o = begin; size_t n = std::distance(begin, end); mxx::comm nonzero_comm = comm.split(n > 0); @@ -766,7 +766,7 @@ inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const m // mxx::exscan typedef typename std::iterator_traits::value_type T; T sum = *(begin + (n-1)); - T presum = exscan(sum, func, nonzero_comm); + T presum = exscan(sum, func, nonzero_comm, commutative); // accumulate previous sum on all local elements for (size_t i = 0; i < n; ++i) { @@ -788,8 +788,8 @@ inline void global_scan_inplace(Iterator begin, Iterator end, const mxx::comm& c // std::vector overloads template -inline void global_scan_inplace(std::vector& in, Func func, const mxx::comm& comm = mxx::comm()) { - global_scan_inplace(in.begin(), in.end(), func, comm); +inline void global_scan_inplace(std::vector& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { + global_scan_inplace(in.begin(), in.end(), func, comm, commutative); } template @@ -798,9 +798,9 @@ inline void global_scan_inplace(std::vector& in, const mxx::comm& comm = mxx: } template -inline std::vector global_scan(const std::vector& in, Func func, const mxx::comm& comm = mxx::comm()) { +inline std::vector global_scan(const std::vector& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { std::vector result(in.size()); - global_scan(in.begin(), in.end(), result.begin(), func, comm); + global_scan(in.begin(), in.end(), result.begin(), func, comm, commutative); return result; } @@ -838,9 +838,9 @@ inline std::vector exscan_vec(const std::vector& x, Func func, const mxx:: // single element template -T exscan(const T& x, Func func, const mxx::comm& comm = mxx::comm()) { +T exscan(const T& x, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { // get op - mxx::custom_op op(std::forward(func)); + mxx::custom_op op(std::forward(func), commutative); // perform reduction T result; MPI_Exscan(const_cast(&x), &result, 1, op.get_type(), op.get_op(), comm);