diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index 78988f13510e..f854e3b78cdc 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -50,6 +50,8 @@ inline std::string mshadowTypeToString(int type) {
       return "int";
     case mshadow::kInt64:
       return "long long";
+    case mshadow::kBool:
+      return "bool";
     default:
       LOG(FATAL) << "Unknown type enum " << type;
   }
@@ -72,6 +74,8 @@ inline int mshadowTypeToVectorLength(int type) {
       return 1;
     case mshadow::kInt64:
       return 1;
+    case mshadow::kBool:
+      return 4 / sizeof(bool);
     default:
       LOG(FATAL) << "Unknown type enum " << type;
   }
@@ -156,7 +160,7 @@ void AddPointerAndShape(const TBlob& data,
                         std::vector<std::vector<int>>* shapes,
                         mshadow::Stream<gpu> * s) {
   using namespace mshadow;
-  MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
     Tensor<gpu, 1, DType> tensor = data.FlatTo1D<gpu, DType>(s);
     ptrs->push_back(tensor.dptr_);
     AddShape(data.shape_, shapes);
@@ -647,7 +651,9 @@ void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
     in_ndims->push_back(blob.ndim());
     in_shapes.push_back(blob.shape_);
     initialized_ = initialized_ && blob.type_flag_ == inputs_[counter].dtype;
+    initialized_ = initialized_ && blob.ndim() == inputs_[counter].ndim;
     inputs_[counter].dtype = blob.type_flag_;
+    inputs_[counter].ndim = blob.ndim();
     *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
   }

@@ -657,7 +663,9 @@ void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
     out_ndims->push_back(blob.ndim());
     out_shapes.push_back(blob.shape_);
     initialized_ = initialized_ && blob.type_flag_ == outputs_[counter].dtype;
+    initialized_ = initialized_ && blob.ndim() == outputs_[counter].ndim;
     outputs_[counter].dtype = blob.type_flag_;
+    outputs_[counter].ndim = blob.ndim();
     *nvec = max(*nvec, mshadowTypeToVectorLength(blob.type_flag_));
   }

diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h
index 24603ac1932f..43491f7af47a 100644
--- a/src/operator/fusion/fused_op.h
+++ b/src/operator/fusion/fused_op.h
@@ -52,8 +52,9 @@ struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> {
 };

 struct FusedOpEntry {
-  FusedOpEntry() : dtype(-1) {}
+  FusedOpEntry() : dtype(-1), ndim(-1) {}
   int dtype;
+  int ndim;
 };

 class FusedOp {
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 5606eb19a9c5..beffb353ef35 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -238,6 +238,49 @@ def test_fusion_compiler_cache():
     if num_gpus > 1:
         check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2)

+@with_seed()
+@use_np
+def test_fusion_boolean_inputs():
+    from mxnet.gluon import HybridBlock
+
+    class Foo(HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(Foo, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, valid_length):
+            mask = valid_length.astype(np.float32)
+            mask2 = valid_length.astype(np.float32)
+            mask = mask * F.np.expand_dims(mask2, axis=-1)
+            return mask
+
+    foo = Foo()
+    foo.hybridize(static_alloc=True)
+    out = foo(mx.np.ones((10,), ctx=mx.gpu(), dtype=np.bool))
+    mx.npx.waitall()
+
+@with_seed()
+def test_fusion_different_dimensions():
+    from mxnet.gluon import HybridBlock
+
+    class Foo(HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(Foo, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, x):
+            mask2 = x.astype(np.float32)
+            mask = F.expand_dims(mask2, axis=-1)
+            return mask
+
+    foo = Foo()
+    foo.hybridize(static_alloc=True)
+    # Pass 1-D data
+    out = foo(mx.nd.ones((10,), ctx=mx.gpu()))
+    assert np.all(out.asnumpy() == np.ones((10,1)))
+    assert out.shape == (10,1)
+    # Pass 2-D data
+    out = foo(mx.nd.ones((10,10), ctx=mx.gpu()))
+    assert np.all(out.asnumpy() == np.ones((10,10)))
+    assert out.shape == (10,10,1)

 if __name__ == '__main__':
     import nose