Skip to content

Commit

Permalink
FIX issue #1: casting to 0 to size_t in std::accumulate to stop
Browse files Browse the repository at this point in the history
overflows
  • Loading branch information
patflick committed Mar 1, 2016
1 parent aac9ad9 commit 8fc5a91
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions include/mxx/collective.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ template <typename T>
void scatterv(const T* msgs, const std::vector<size_t>& sizes, T* out, size_t recv_size, int root, const mxx::comm& comm = mxx::comm()) {
MXX_ASSERT(root != comm.rank() || sizes.size() == static_cast<size_t>(comm.size()));
// get total send size
size_t send_size = std::accumulate(sizes.begin(), sizes.end(), 0);
size_t send_size = std::accumulate(sizes.begin(), sizes.end(), static_cast<size_t>(0));
mxx::datatype sizedt = mxx::get_datatype<size_t>();
MPI_Bcast(&send_size, 1, sizedt.type(), root, comm);
// check if we need to use the custom BIG scatterv
Expand Down Expand Up @@ -701,7 +701,7 @@ std::vector<T> gather(const T& x, int root, const mxx::comm& comm = mxx::comm())
*/
template <typename T>
void gatherv(const T* data, size_t size, T* out, const std::vector<size_t>& recv_sizes, int root, const mxx::comm& comm = mxx::comm()) {
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
mxx::datatype mpi_sizet = mxx::get_datatype<size_t>();
// tell everybody about the total size
MPI_Bcast(&total_size, 1, mpi_sizet.type(), root, comm);
Expand Down Expand Up @@ -756,7 +756,7 @@ void gatherv(const T* data, size_t size, T* out, const std::vector<size_t>& recv
template <typename T>
std::vector<T> gatherv(const T* data, size_t size, const std::vector<size_t>& recv_sizes, int root, const mxx::comm& comm = mxx::comm()) {
std::vector<T> result;
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
if (comm.rank() == root)
result.resize(total_size);
gatherv(data, size, &result[0], recv_sizes, root, comm);
Expand Down Expand Up @@ -995,7 +995,7 @@ std::vector<T> allgather(const T& x, const mxx::comm& comm = mxx::comm()) {
*/
template <typename T>
void allgatherv(const T* data, size_t size, T* out, const std::vector<size_t>& recv_sizes, const mxx::comm& comm = mxx::comm()) {
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
MXX_ASSERT(recv_sizes.size() == static_cast<size_t>(comm.size()));
mxx::datatype mpi_sizet = mxx::get_datatype<size_t>();
if (total_size >= mxx::max_int) {
Expand Down Expand Up @@ -1040,7 +1040,7 @@ void allgatherv(const T* data, size_t size, T* out, const std::vector<size_t>& r
*/
template <typename T>
std::vector<T> allgatherv(const T* data, size_t size, const std::vector<size_t>& recv_sizes, const mxx::comm& comm = mxx::comm()) {
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
std::vector<T> result(total_size);
allgatherv(data, size, &result[0], recv_sizes, comm);
return result;
Expand Down Expand Up @@ -1241,8 +1241,8 @@ std::vector<T> all2all(const std::vector<T>& msgs, const mxx::comm& comm = mxx::
*/
template <typename T>
void all2allv(const T* msgs, const std::vector<size_t>& send_sizes, T* out, const std::vector<size_t>& recv_sizes, const mxx::comm& comm = mxx::comm()) {
size_t total_send_size = std::accumulate(send_sizes.begin(), send_sizes.end(), 0);
size_t total_recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_send_size = std::accumulate(send_sizes.begin(), send_sizes.end(), static_cast<size_t>(0));
size_t total_recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
size_t local_max_size = std::max(total_send_size, total_recv_size);
mxx::datatype mpi_sizet = mxx::get_datatype<size_t>();
size_t max;
Expand Down Expand Up @@ -1296,8 +1296,8 @@ void all2allv(const T* msgs, const std::vector<size_t>& send_sizes, T* out, cons

template <typename T>
void all2allv(const T* msgs, const std::vector<size_t>& send_sizes, const std::vector<size_t>& send_displs, T* out, const std::vector<size_t>& recv_sizes, const std::vector<size_t>& recv_displs, const mxx::comm& comm = mxx::comm()) {
size_t total_send_size = std::accumulate(send_sizes.begin(), send_sizes.end(), 0);
size_t total_recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t total_send_size = std::accumulate(send_sizes.begin(), send_sizes.end(), static_cast<size_t>(0));
size_t total_recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
size_t local_max_size = std::max(total_send_size, total_recv_size);
mxx::datatype mpi_sizet = mxx::get_datatype<size_t>();
size_t max;
Expand Down Expand Up @@ -1338,7 +1338,7 @@ void all2allv(const T* msgs, const std::vector<size_t>& send_sizes, const std::v
*/
template <typename T>
std::vector<T> all2allv(const T* msgs, const std::vector<size_t>& send_sizes, const std::vector<size_t>& recv_sizes, const mxx::comm& comm = mxx::comm()) {
size_t recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), 0);
size_t recv_size = std::accumulate(recv_sizes.begin(), recv_sizes.end(), static_cast<size_t>(0));
std::vector<T> result(recv_size);
all2allv(msgs, send_sizes, &result[0], recv_sizes, comm);
return result;
Expand All @@ -1365,7 +1365,7 @@ std::vector<T> all2allv(const T* msgs, const std::vector<size_t>& send_sizes, co
*/
template <typename T>
std::vector<T> all2allv(const std::vector<T>& msgs, const std::vector<size_t>& send_sizes, const std::vector<size_t>& recv_sizes, const mxx::comm& comm = mxx::comm()) {
MXX_ASSERT(msgs.size() == std::accumulate(send_sizes.begin(), send_sizes.end(), (size_t)0));
MXX_ASSERT(msgs.size() == std::accumulate(send_sizes.begin(), send_sizes.end(), static_cast<size_t>(0)));
return all2allv(&msgs[0], send_sizes, recv_sizes, comm);
}

Expand Down Expand Up @@ -1447,7 +1447,7 @@ void all2all_func(std::vector<T>& msgs, Func target_func, const mxx::comm& comm
std::vector<size_t> recv_counts = all2all(send_counts, comm);

// resize messages to fit recv
std::size_t recv_size = std::accumulate(recv_counts.begin(), recv_counts.end(), 0);
std::size_t recv_size = std::accumulate(recv_counts.begin(), recv_counts.end(), static_cast<size_t>(0));
msgs.clear();
msgs.shrink_to_fit();
msgs.resize(recv_size);
Expand Down

0 comments on commit 8fc5a91

Please sign in to comment.