Skip to content

Commit

Permalink
Adding overloads for noncommutaive global_scan*
Browse files Browse the repository at this point in the history
  • Loading branch information
asrivast28 committed Feb 12, 2020
1 parent 14408b4 commit 7bcef25
Showing 1 changed file with 32 additions and 10 deletions.
42 changes: 32 additions & 10 deletions include/mxx/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ inline std::vector<T> local_scan(const std::vector<T>& in) {
// global scans

template <typename InIterator, typename OutIterator, typename Func>
void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) {
OutIterator o = out;
size_t n = std::distance(begin, end);
// create subcommunicator for those processes which contain elements
Expand All @@ -753,10 +753,15 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c
}
}

template <typename InIterator, typename OutIterator, typename Func>
void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, const mxx::comm& comm = mxx::comm()) {
global_scan(begin, end, out, func, true, comm);
}


// inplace!
template <typename Iterator, typename Func>
inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) {
Iterator o = begin;
size_t n = std::distance(begin, end);
mxx::comm nonzero_comm = comm.split(n > 0);
Expand All @@ -776,38 +781,55 @@ inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const m
}
}

template <typename Iterator, typename Func>
inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const mxx::comm& comm = mxx::comm()) {
global_scan_inplace(begin, end, func, true, comm);
}

template <typename InIterator, typename OutIterator>
inline void global_scan(InIterator begin, InIterator end, OutIterator out, const mxx::comm& comm = mxx::comm()) {
return global_scan(begin, end, out, std::plus<typename std::iterator_traits<OutIterator>::value_type>(), comm);
return global_scan(begin, end, out, std::plus<typename std::iterator_traits<OutIterator>::value_type>(), true, comm);
}

template <typename Iterator>
inline void global_scan_inplace(Iterator begin, Iterator end, const mxx::comm& comm = mxx::comm()) {
return global_scan_inplace(begin, end, std::plus<typename std::iterator_traits<Iterator>::value_type>(), comm);
return global_scan_inplace(begin, end, std::plus<typename std::iterator_traits<Iterator>::value_type>(), true, comm);
}

// std::vector overloads
template <typename T, typename Func>
inline void global_scan_inplace(std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
global_scan_inplace(in.begin(), in.end(), func, comm, commutative);
inline void global_scan_inplace(std::vector<T>& in, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) {
global_scan_inplace(in.begin(), in.end(), func, commutative, comm);
}

template <typename T, typename Func>
inline void global_scan_inplace(std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm()) {
global_scan_inplace(in.begin(), in.end(), func, true, comm);
}

template <typename T>
inline void global_scan_inplace(std::vector<T>& in, const mxx::comm& comm = mxx::comm()) {
global_scan_inplace(in.begin(), in.end(), std::plus<T>(), comm);
global_scan_inplace(in.begin(), in.end(), std::plus<T>(), true, comm);
}

template <typename T, typename Func>
inline std::vector<T> global_scan(const std::vector<T>& in, Func func, const bool commutative, const mxx::comm& comm = mxx::comm()) {
std::vector<T> result(in.size());
global_scan(in.begin(), in.end(), result.begin(), func, commutative, comm);
return result;
}

template <typename T, typename Func>
inline std::vector<T> global_scan(const std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm(), const bool commutative = true) {
inline std::vector<T> global_scan(const std::vector<T>& in, Func func, const mxx::comm& comm = mxx::comm()) {
std::vector<T> result(in.size());
global_scan(in.begin(), in.end(), result.begin(), func, comm, commutative);
global_scan(in.begin(), in.end(), result.begin(), func, true, comm);
return result;
}

template <typename T>
inline std::vector<T> global_scan(const std::vector<T>& in, const mxx::comm& comm = mxx::comm()) {
std::vector<T> result(in.size());
global_scan(in.begin(), in.end(), result.begin(), std::plus<T>(), comm);
global_scan(in.begin(), in.end(), result.begin(), std::plus<T>(), true, comm);
return result;
}

Expand Down

0 comments on commit 7bcef25

Please sign in to comment.