diff --git a/include/mxx/reduction.hpp b/include/mxx/reduction.hpp index 7fb7bb5..a84b995 100644 --- a/include/mxx/reduction.hpp +++ b/include/mxx/reduction.hpp @@ -731,7 +731,7 @@ inline std::vector local_scan(const std::vector& in) { // global scans template -void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { +void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) { OutIterator o = out; size_t n = std::distance(begin, end); // create subcommunicator for those processes which contain elements @@ -753,10 +753,15 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c } } +template +void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm()) { + global_scan(begin, end, out, func, true, comm); +} + // inplace! template -inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { +inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) { Iterator o = begin; size_t n = std::distance(begin, end); mxx::comm nonzero_comm = comm.split(n > 0); @@ -776,38 +781,55 @@ inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const m } } +template +inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm()) { + global_scan_inplace(begin, end, func, true, comm); +} + template inline void global_scan(InIterator begin, InIterator end, OutIterator out, const mxx::comm& comm = mxx::comm()) { - return global_scan(begin, end, out, std::plus::value_type>(), comm); + return global_scan(begin, end, out, std::plus::value_type>(), true, comm); } template inline void global_scan_inplace(Iterator begin, Iterator end, const mxx::comm& comm = mxx::comm()) { - return global_scan_inplace(begin, end, std::plus::value_type>(), comm); + return global_scan_inplace(begin, end, std::plus::value_type>(), true, comm); } // std::vector overloads template -inline void global_scan_inplace(std::vector& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { - global_scan_inplace(in.begin(), in.end(), func, comm, commutative); +inline void global_scan_inplace(std::vector& in, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) { + global_scan_inplace(in.begin(), in.end(), func, commutative, comm); +} + +template +inline void global_scan_inplace(std::vector& in, Func func, const mxx::comm& comm = mxx::comm()) { + global_scan_inplace(in.begin(), in.end(), func, true, comm); } template inline void global_scan_inplace(std::vector& in, const mxx::comm& comm = mxx::comm()) { - global_scan_inplace(in.begin(), in.end(), std::plus(), comm); + global_scan_inplace(in.begin(), in.end(), std::plus(), true, comm); +} + +template +inline std::vector global_scan(const std::vector& in, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) { + std::vector result(in.size()); + global_scan(in.begin(), in.end(), result.begin(), func, commutative, comm); + return result; } template -inline std::vector global_scan(const std::vector& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) { +inline std::vector global_scan(const std::vector& in, Func func, const mxx::comm& comm = mxx::comm()) { std::vector result(in.size()); - global_scan(in.begin(), in.end(), result.begin(), func, comm, commutative); + global_scan(in.begin(), in.end(), result.begin(), func, true, comm); return result; } template inline std::vector global_scan(const std::vector& in, const mxx::comm& comm = mxx::comm()) { std::vector result(in.size()); - global_scan(in.begin(), in.end(), result.begin(), std::plus(), comm); + global_scan(in.begin(), in.end(), result.begin(), std::plus(), true, comm); return result; }