Skip to content

Commit

Permalink
Merge pull request #24 from asrivast28/FixScan
Browse files Browse the repository at this point in the history
Fix in global_scan and global_scan_inplace.
  • Loading branch information
patflick committed Mar 3, 2020
2 parents 84c9628 + 9e184c5 commit 1c0ab9f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
24 changes: 13 additions & 11 deletions include/mxx/reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -741,14 +741,14 @@ void global_scan(InIterator begin, InIterator end, OutIterator out, Func func, c
local_scan(begin, end, out, func);
// mxx::scan
typedef typename std::iterator_traits<OutIterator>::value_type T;
T sum = T();
if (n > 0)
sum = *(out+(n-1));
T sum = *(out+(n-1));
T presum = exscan(sum, func, nonzero_comm, commutative);
// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
*o = func(presum, *o);
++o;
if (nonzero_comm.rank() != 0) {
// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
*o = func(presum, *o);
++o;
}
}
}
}
Expand All @@ -773,10 +773,12 @@ inline void global_scan_inplace(Iterator begin, Iterator end, Func func, const b
T sum = *(begin + (n-1));
T presum = exscan(sum, func, nonzero_comm, commutative);

// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
*o = func(presum, *o);
++o;
if (nonzero_comm.rank() != 0) {
// accumulate previous sum on all local elements
for (size_t i = 0; i < n; ++i) {
*o = func(presum, *o);
++o;
}
}
}
}
Expand Down
27 changes: 26 additions & 1 deletion test/test_reductions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ TEST(MxxReduce, GlobalReduce) {

TEST(MxxReduce, GlobalScan) {
mxx::comm c;
// test reduce with zero elements for some processes
// test scan with zero elements for some processes
size_t n = 0;
int presize = 0;
if (c.rank() % 2 == 0) {
Expand All @@ -273,6 +273,31 @@ TEST(MxxReduce, GlobalScan) {
}
}

TEST(MxxReduce, GlobalScanMin) {
mxx::comm c;
// test scan with min and an array of positive
// elements sorted in ascending order
size_t n = c.size();
std::vector<int> local(n);
for (size_t i = 0; i < n; ++i) {
local[i] = (c.rank()*n)+i+1;
}
// test inplace scan
std::vector<int> local_cpy(local);
mxx::global_scan_inplace(local_cpy.begin(), local_cpy.end(), mxx::min<int>(), c);
for (size_t i = 0; i < n; ++i) {
// 1 is both the first as well as the minimum element
ASSERT_EQ(1, local_cpy[i]);
}
// test scan
std::vector<int> result = mxx::global_scan(local, mxx::min<int>(), c);
ASSERT_EQ(local.size(), result.size());
for (size_t i = 0; i < n; ++i) {
// 1 is both the first as well as the minimum element
ASSERT_EQ(1, result[i]);
}
}

TEST(MxxReduce, GlobalExScan) {
mxx::comm c;
// test reduce with zero elements for some processes
Expand Down

0 comments on commit 1c0ab9f

Please sign in to comment.