This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add sum for boolean type when not built with TVM #16436

Merged — 1 commit, Oct 20, 2019
10 changes: 10 additions & 0 deletions 3rdparty/mshadow/mshadow/base.h
@@ -650,6 +650,11 @@ template<>
MSHADOW_XINLINE int64_t MinValue<int64_t>(void) {
return LLONG_MIN;
}
/*! \brief minimum value of bool */
template<>
MSHADOW_XINLINE bool MinValue<bool>(void) {
return false;
}

/*!
* \brief negative infinity of certain types
@@ -711,6 +716,11 @@ template<>
MSHADOW_XINLINE int64_t MaxValue<int64_t>(void) {
return LLONG_MAX;
}
/*! \brief maximum value of bool */
template<>
MSHADOW_XINLINE bool MaxValue<bool>(void) {
return true;
}

/*!
* \brief positive infinity of certain types
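
The two specializations above give bool the same MinValue/MaxValue treatment the integer types already have, which is what reduce kernels rely on to seed their accumulators. A minimal standalone sketch of that pattern (hypothetical names, not the mshadow sources):

```cpp
// Standalone sketch (not the mshadow sources): shows how MinValue/MaxValue
// specializations let a generic reducer seed its accumulator, which is what
// the new bool specializations make possible for boolean tensors.
#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

template <typename T> inline T MinValue();
template <> inline int64_t MinValue<int64_t>() { return LLONG_MIN; }
template <> inline bool MinValue<bool>() { return false; }  // the new bool case

// A max-reduction initializes its accumulator with MinValue<T>(); for bool
// inputs this degenerates to a logical OR over the elements.
template <typename T>
T MaxReduce(const std::vector<T>& data) {
  T acc = MinValue<T>();
  for (T v : data) acc = (v > acc) ? v : acc;
  return acc;
}

int main() {
  std::vector<bool> flags{false, true, false};
  std::cout << std::boolalpha << MaxReduce(flags) << std::endl;  // prints: true
  return 0;
}
```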
22 changes: 16 additions & 6 deletions src/operator/mxnet_op.h
@@ -314,6 +314,14 @@ struct AccType<mshadow::half::half_t> {
"floating point types, not int64"; \
} \
break; \
case mshadow::kBool: \
{ \
typedef bool DType; \
typedef int64_t AType; \
LOG(FATAL) << "This operation only support " \
"floating point types, not bool"; \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
@@ -369,6 +377,13 @@ struct AccType<mshadow::half::half_t> {
{__VA_ARGS__} \
} \
break; \
case mshadow::kBool: \
{ \
typedef bool DType; \
typedef int64_t AType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}
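
Both macros expand their body once per runtime dtype, binding a storage type (DType) and an accumulation type (AType); the new kBool arm maps boolean storage to an int64_t accumulator. A hedged sketch of the dispatch pattern with hypothetical names (not the actual MXNet macros):

```cpp
// Hedged sketch of the dispatch pattern behind these macros (hypothetical
// names, not the MXNet macros): a runtime dtype enum expands the same body
// with compile-time storage/accumulation types, and bool now accumulates
// into int64_t instead of hitting the "unknown type" branch.
#include <cstdint>
#include <iostream>

enum TypeFlag { kFloat32 = 0, kInt64 = 1, kBool = 2 };

#define DEMO_ACC_TYPE_SWITCH(type, DType, AType, ...)               \
  switch (type) {                                                    \
    case kFloat32: { typedef float   DType; typedef double  AType;   \
                     {__VA_ARGS__} } break;                          \
    case kInt64:   { typedef int64_t DType; typedef int64_t AType;   \
                     {__VA_ARGS__} } break;                          \
    case kBool:    { typedef bool    DType; typedef int64_t AType;   \
                     {__VA_ARGS__} } break;                          \
    default: std::cerr << "Unknown type enum " << type << "\n";      \
  }

int main() {
  int type = kBool;
  DEMO_ACC_TYPE_SWITCH(type, DType, AType, {
    std::cout << "storage bytes: " << sizeof(DType)
              << ", accumulator bytes: " << sizeof(AType) << std::endl;
  });
  return 0;
}
```

Accumulating into int64_t rather than bool mirrors the `'bool': 'int64'` entry in the test's acc_type map and avoids overflow when summing large boolean tensors.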
@@ -608,16 +623,11 @@ template <typename xpu>
MSHADOW_CINLINE void copy(mshadow::Stream<xpu> *s, const TBlob& to, const TBlob& from) {
CHECK_EQ(from.Size(), to.Size());
CHECK_EQ(from.dev_mask(), to.dev_mask());
if (from.type_flag_ == mshadow::kBool || to.type_flag_ == mshadow::kBool) {
CHECK_EQ(from.type_flag_, to.type_flag_) << "Only supports copying between boolean ndarrays.";
mshadow::Copy(to.FlatTo1D<xpu, bool>(s), from.FlatTo1D<xpu, bool>(s), s);
return;
}
MSHADOW_TYPE_SWITCH(to.type_flag_, DType, {
if (to.type_flag_ == from.type_flag_) {
mshadow::Copy(to.FlatTo1D<xpu, DType>(s), from.FlatTo1D<xpu, DType>(s), s);
} else {
MSHADOW_TYPE_SWITCH(from.type_flag_, SrcDType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(from.type_flag_, SrcDType, {
to.FlatTo1D<xpu, DType>(s) = mshadow::expr::tcast<DType>(from.FlatTo1D<xpu, SrcDType>(s));
})
}
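
With the bool-aware inner switch, copy() no longer needs the dedicated bool-to-bool branch: same-dtype copies stay plain copies, while mixed-dtype copies (including a boolean source) go through an element-wise cast. A rough standalone sketch of that behaviour, using plain loops instead of the mshadow expression templates:

```cpp
// Minimal sketch of the cast-on-copy behaviour (plain loops, not the mshadow
// expression templates): same-type copies are memcpy-like, while mixed types
// (including bool sources) go through an element-wise cast.
#include <cstdint>
#include <iostream>
#include <vector>

template <typename DstT, typename SrcT>
void CopyCast(std::vector<DstT>* dst, const std::vector<SrcT>& src) {
  dst->resize(src.size());
  for (size_t i = 0; i < src.size(); ++i)
    (*dst)[i] = static_cast<DstT>(src[i]);  // analogue of tcast<DType>(...)
}

int main() {
  std::vector<bool> flags{true, false, true};
  std::vector<int64_t> out;
  CopyCast(&out, flags);                         // bool -> int64_t copy
  for (int64_t v : out) std::cout << v << ' ';   // prints: 1 0 1
  std::cout << '\n';
  return 0;
}
```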
4 changes: 3 additions & 1 deletion src/operator/numpy/np_broadcast_reduce_op.h
@@ -226,7 +226,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
if (param.initial.has_value()) {
LOG(FATAL) << "initial is not supported yet";
}
if (inputs[0].shape_.Size() == 0) {
if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
using namespace mxnet_op;
using namespace mshadow;
Stream<xpu>* s = ctx.get_stream<xpu>();
@@ -236,6 +236,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
return;
}
CHECK_NE(req[0], kWriteInplace) << "Reduce does not support write in-place";
#if MXNET_USE_TVM_OP
// If boolean ndarray, use the kernel generated by TVM
if (inputs[0].type_flag_ == mshadow::kBool) {
std::string reducer_name;
@@ -247,6 +248,7 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
return;
}
#endif
if (param.axis.has_value() && param.axis.value().ndim() == 0) {
UnaryOp::IdentityCompute<xpu>(attrs, ctx, inputs, req, outputs);
}
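
The #if MXNET_USE_TVM_OP guard means the TVM-generated boolean kernel is only compiled when MXNet is built with TVM; in a non-TVM build, boolean inputs simply fall through to the regular reduce path that the other changes in this PR make bool-capable. A hedged sketch of this gating pattern with a hypothetical flag:

```cpp
// Hedged sketch of the compile-time gating pattern (hypothetical flag and
// function names, not the MXNet sources): the specialised kernel is only
// compiled in when the build flag is set, and everything else falls through
// to a generic path that also handles boolean input.
#include <iostream>

#ifndef DEMO_USE_TVM_OP
#define DEMO_USE_TVM_OP 0   // e.g. pass -DDEMO_USE_TVM_OP=1 at build time
#endif

void GenericReduce() { std::cout << "generic reduce path (bool-capable)\n"; }

#if DEMO_USE_TVM_OP
void TvmBoolReduce() { std::cout << "TVM-generated bool kernel\n"; }
#endif

void Reduce(bool is_bool_input) {
#if DEMO_USE_TVM_OP
  if (is_bool_input) {   // prefer the generated kernel when it was built in
    TvmBoolReduce();
    return;
  }
#endif
  (void)is_bool_input;   // fallback used for all dtypes in a non-TVM build
  GenericReduce();
}

int main() {
  Reduce(true);
  return 0;
}
```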
2 changes: 1 addition & 1 deletion src/operator/tensor/broadcast_reduce_op.h
@@ -616,7 +616,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
mxnet::TShape src_shape, dst_shape;
BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
const TBlob in_data = inputs[0].reshape(src_shape);
const TBlob out_data = outputs[0].reshape(dst_shape);
13 changes: 8 additions & 5 deletions tests/python/unittest/test_numpy_op.py
@@ -32,7 +32,8 @@
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry
from mxnet.runtime import Features
from mxnet.numpy_op_signature import _get_builtin_op
from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, has_tvm_ops
from mxnet.test_utils import current_context, verify_generator, gen_buckets_probs_with_ppf
from mxnet.test_utils import is_op_runnable, has_tvm_ops
import platform


@@ -434,13 +435,15 @@ def is_int(dtype):
shape = rand_shape_nd(in_data_dim, dim=3)
acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64',
'int8': 'int32', 'int32': 'int64', 'int64': 'int64', 'bool': 'int64'}
is_windows = sys.platform.startswith('win')
for hybridize in [False, True]:
for keepdims in [True, False]:
for axis in ([i for i in range(in_data_dim)] + [(), None]):
for itype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64', 'bool']:
for dtype in ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']:
if (is_int(dtype) and not is_int(itype))\
or (itype == 'bool' and dtype not in ('float32', 'float64', 'int32', 'int64')):
or (itype == 'bool' and\
(dtype not in ('float32', 'float64', 'int32', 'int64') or is_windows)):
continue
# test gluon
test_sum = TestSum(axis=axis, dtype=dtype, keepdims=keepdims)
@@ -456,8 +459,8 @@ def is_int(dtype):
x = np.random.uniform(-1.0, 1.0, size=shape, dtype=itype)
expected_ret = _np.sum(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims)
expected_ret = expected_ret.astype(dtype)
if itype == 'bool': # special handling of boolean ndarray
if has_tvm_ops():
if itype == 'bool':
if is_op_runnable() and (not is_windows): # special handling of boolean ndarray
y = test_sum(x)
assert y.dtype == expected_ret.dtype
assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-4, atol=1e-5,
@@ -479,7 +482,7 @@ def is_int(dtype):
x_sym = mx.sym.Variable("x").as_np_ndarray()
mx_sym = mx.sym.np.sum(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray()
check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32)
numeric_eps=1e-3, rtol=1e-2, atol=1e-3, dtype=_np.float32)

# test imperative
mx_out = np.sum(x, axis=axis, dtype=dtype, keepdims=keepdims)