
[MXNET-1253] fix control_flow_op #13555

Merged · 3 commits · Dec 11, 2018
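
The `where` kernels indexed elements with `int`, which cannot address a tensor holding more than 2^31 - 1 entries; this change widens the index to `index_t`. A minimal sketch of the failure mode, assuming `index_t` is a 64-bit type (as in MXNet's large-tensor builds) and borrowing the (50000000, 100) shape from the existing nightly test:

```cpp
#include <climits>
#include <cstdint>
#include <iostream>

int main() {
  // A (50000000, 100) tensor holds 5,000,000,000 elements, more than
  // INT_MAX (2,147,483,647), so a 32-bit `int` cannot index all of them.
  const int64_t num_elements = 50000000LL * 100LL;
  std::cout << (num_elements > INT_MAX) << "\n";  // prints 1

  // A 64-bit index reaches past the 2^31 boundary without wrapping.
  using index_t = int64_t;  // assumption: matches the large-tensor index type
  index_t i = num_elements - 1;
  std::cout << i << "\n";  // 4999999999, a valid element index
  return 0;
}
```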
16 changes: 8 additions & 8 deletions src/operator/tensor/control_flow_op.h
@@ -46,7 +46,7 @@ struct where {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
-MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
+MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
const DType* x, const DType* y) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i]? x[i] : y[i]));
}
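
For context on why the type of the first parameter matters: these kernels are driven by an elementwise launcher that invokes `Map` once per output element (MXNet's real driver is `mxnet_op::Kernel<OP, xpu>::Launch`). A simplified, hypothetical CPU-only stand-in, with `req`/`KERNEL_ASSIGN` omitted:

```cpp
#include <cstdint>

using index_t = int64_t;  // assumption: 64-bit, as in large-tensor builds

// Calls OP::Map once per output element. If Map took `int i`, indices at or
// beyond 2^31 would arrive through a narrowing conversion and the kernel
// would read and write the wrong memory.
template <typename OP, typename... Args>
void LaunchCPU(index_t N, Args... args) {
  for (index_t i = 0; i < N; ++i) {
    OP::Map(i, args...);
  }
}

// Stand-in for the `where` kernel above, minus req/KERNEL_ASSIGN.
struct where_sketch {
  template <typename DType, typename CType>
  static void Map(index_t i, DType* out, const CType* cond,
                  const DType* x, const DType* y) {
    out[i] = (cond[i] != 0) ? x[i] : y[i];
  }
};
```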
@@ -64,7 +64,7 @@ struct where_csr {
// CType is condition data type
// i is for i-th row in the output
template<typename DType, typename CType, typename IType>
-MSHADOW_XINLINE static void Map(int i, DType* out, const IType* cond_idx,
+MSHADOW_XINLINE static void Map(index_t i, DType* out, const IType* cond_idx,
const IType* cond_indptr, const CType* cond_data,
const nnvm::dim_t num_cols, const DType* x) {
using nnvm::dim_t;
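
The hunk cuts off before the row walk, but the signature of `where_csr` implies the standard CSR pattern: `cond_indptr[i]` and `cond_indptr[i+1]` delimit row `i`'s stored entries, `cond_idx[j]` gives entry `j`'s column, and `cond_data[j]` its value. A hedged sketch of one row's work, assuming `out` is pre-filled with `y`'s values so only non-zero condition entries pull from `x`:

```cpp
#include <cstdint>

using index_t = int64_t;  // assumption: 64-bit index
using dim_t = int64_t;    // stand-in for nnvm::dim_t

// Hypothetical per-row body; `out` is assumed already filled with y's values.
template <typename DType, typename CType, typename IType>
void where_csr_row_sketch(index_t i, DType* out, const IType* cond_idx,
                          const IType* cond_indptr, const CType* cond_data,
                          const dim_t num_cols, const DType* x) {
  for (IType j = cond_indptr[i]; j < cond_indptr[i + 1]; ++j) {
    const dim_t col = static_cast<dim_t>(cond_idx[j]);
    if (cond_data[j] != 0) {                      // condition holds at (i, col)
      out[i * num_cols + col] = x[i * num_cols + col];
    }
  }
}
```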
@@ -92,8 +92,8 @@ struct where_batch {
// DType is the output data type
// CType is the condition data type
template<typename DType, typename CType>
-MSHADOW_XINLINE static void Map(int i, DType* out, const CType* cond,
-                                const DType* x, const DType* y, int M) {
+MSHADOW_XINLINE static void Map(index_t i, DType* out, const CType* cond,
+                                const DType* x, const DType* y, index_t M) {
KERNEL_ASSIGN(out[i], req, (0 != cond[i/M]? x[i] : y[i]));
}
};
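
`where_batch` broadcasts one condition entry per batch item across `M` elements, so output element `i` reads `cond[i / M]`. Widening both `i` and `M` keeps that quotient correct past the 32-bit limit; a small worked check with hypothetical sizes:

```cpp
#include <cassert>
#include <cstdint>

using index_t = int64_t;  // assumption: 64-bit index

int main() {
  const index_t M = 100;           // elements per batch item (hypothetical)
  const index_t i = 3000000000LL;  // element index beyond INT_MAX
  assert(i / M == 30000000);       // batch item whose condition entry is used
  // With 32-bit types, i itself could not hold 3,000,000,000, so the
  // kernel would never see the correct batch index here.
  return 0;
}
```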
@@ -109,7 +109,7 @@ struct where_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
-MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond) {
KERNEL_ASSIGN(grad_out[i], req,
@@ -130,7 +130,7 @@ struct where_backward_csr {
// CType is condition data type
// IType is condition aux data type
template<typename DType, typename CType, typename IType>
-MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
const CType* cond_data,
const IType* cond_idx,
@@ -161,9 +161,9 @@ struct where_batch_backward {
// DType is the output data type
// CType is condition data type
template<typename DType, typename CType>
-MSHADOW_XINLINE static void Map(int i, DType* grad_out,
+MSHADOW_XINLINE static void Map(index_t i, DType* grad_out,
const DType* grad_in,
-                                const CType* cond, int M) {
+                                const CType* cond, index_t M) {
KERNEL_ASSIGN(grad_out[i], req,
((0 == cond[i/M])^negate)? grad_in[i] : static_cast<DType>(0));
}
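
The backward kernels fold both input gradients into one body via the `negate` template flag. Since the forward pass emits `x[i]` where the condition is non-zero, the gradient w.r.t. `x` flows exactly there, and the gradient w.r.t. `y` flows where it is zero; `(0 == cond[i/M]) ^ negate` selects between the two cases. A scalar sketch of that truth table (assuming `negate == false` corresponds to the `y` gradient, as the expression suggests):

```cpp
#include <cassert>

int main() {
  const int conds[] = {0, 1};
  for (int cond : conds) {
    const bool pass_to_y = (0 == cond) ^ false;  // negate = false
    const bool pass_to_x = (0 == cond) ^ true;   // negate = true
    assert(pass_to_y == (cond == 0));  // grad_y flows where cond is zero
    assert(pass_to_x == (cond != 0));  // grad_x flows where cond is non-zero
  }
  return 0;
}
```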
11 changes: 11 additions & 0 deletions tests/nightly/test_large_array.py
@@ -134,6 +134,17 @@ def test_Dense(ctx=mx.cpu(0)):
res.wait_to_read()
assert res.shape == (50000000, 100)

+def test_where():
+    a = nd.ones(shape=(LARGE_X, SMALL_Y))
+    b = nd.arange(0, LARGE_X).reshape(LARGE_X, 1)
+    b = nd.broadcast_to(b, shape=(b.shape[0], SMALL_Y))
+    res = nd.where(b > 100, a, b)
+    assert np.sum(res[-1].asnumpy() == 1) == b.shape[1]
+
+    csr_cond = nd.sparse.cast_storage(b < 10, 'csr')
+    res = nd.sparse.where(csr_cond, a, b)
+    assert np.sum(res[0].asnumpy() == 1) == b.shape[1]
+

if __name__ == '__main__':
import nose