This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy][Operator] 'where' Implementation in MXNet #16829

Merged
merged 8 commits into from Nov 18, 2019
Changes from 1 commit
fix according to reviews
hgt312 committed Nov 16, 2019
commit bb2896720cd7648f115b8928b9ed06ba4efd70af
14 changes: 11 additions & 3 deletions python/mxnet/ndarray/numpy/_op.py
@@ -5311,10 +5311,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):


@set_module('mxnet.ndarray.numpy')
def where(condition, x, y):
"""
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.

.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.

Parameters
----------
condition : ndarray
@@ -5371,4 +5376,7 @@ def where(condition, x, y):
[ 0., 2., -1.],
[ 0., 3., -1.]])
"""
return _npi.where(condition, x, y, out=None)
if x is None and y is None:
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)
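For reference, a minimal usage sketch of the new optional-argument behaviour (values and printed formatting are illustrative; it assumes an MXNet build that includes this commit):

```python
from mxnet import np as mxnp

cond = mxnp.array([[True, False], [False, True]])
x = mxnp.array([[1., 2.], [3., 4.]])
y = mxnp.array([[-1., -2.], [-3., -4.]])

# Three-argument form: take from x where cond is True, else from y.
out = mxnp.where(cond, x, y)   # [[ 1., -2.], [-3.,  4.]]

# Single-argument form added here: shorthand for nonzero(cond),
# i.e. the indices of the True entries, one index array per axis.
idx = mxnp.where(cond)         # (array([0, 1]), array([0, 1]))
```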
9 changes: 7 additions & 2 deletions python/mxnet/numpy/multiarray.py
@@ -7298,10 +7298,15 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None, **kwargs):


@set_module('mxnet.numpy')
def where(condition, x, y):
"""
def where(condition, x=None, y=None):
"""where(condition, [x, y])
Return elements chosen from `x` or `y` depending on `condition`.

.. note::
When only `condition` is provided, this function is a shorthand for
``np.asarray(condition).nonzero()``. The rest of this documentation
covers only the case where all three arguments are provided.

Parameters
----------
condition : ndarray
98 changes: 25 additions & 73 deletions src/operator/numpy/np_where_op-inl.h
@@ -40,8 +40,6 @@
namespace mxnet {
namespace op {

#define NUMPY_WHERE_MAX_DIM 5

using namespace mshadow;

template<int ndim>
@@ -75,52 +73,6 @@ struct numpy_where_backward_kernel {
}
};

inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& operand1 = (*in_attrs)[0];
mxnet::TShape& operand2 = (*in_attrs)[1];
mxnet::TShape& operand3 = (*in_attrs)[2];

if (operand1 == operand2 && operand2 == operand3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1);
return shape_is_known(out_attrs->at(0));
}
mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1);
const int b1 = out.ndim() - operand1.ndim();
const int b2 = out.ndim() - operand2.ndim();
const int b3 = out.ndim() - operand3.ndim();
for (int i = 0; i < out.ndim(); ++i) {
int s1 = 1, s2 = 1, s3 = 1;
if (i >= b1) s1 = operand1[i-b1];
if (i >= b2) s2 = operand2[i-b2];
if (i >= b3) s3 = operand3[i-b3];
if (!(s1 == s2 && s2 == s3)) {
CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) ||
(s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2))
<< "Operands could not be broadcast together.";
out[i] = std::max({s1, s2, s3});
} else {
out[i] = s1;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return shape_is_known(out);
}

inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
CHECK_EQ(in_attrs->at(1), in_attrs->at(2));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
return (out_attrs->at(0) != -1);
}

template<typename xpu>
inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
@@ -130,29 +82,29 @@ inline void NumpyWhereOpForward(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 3U);
CHECK_EQ(outputs.size(), 1U);
if (outputs[0].shape_.Size() == 0U) return; // zero-size tensor
CHECK_LE(outputs[0].shape_.ndim(), NUMPY_WHERE_MAX_DIM);
CHECK_LE(outputs[0].shape_.ndim(), broadcast::MAX_DIM);

Stream<xpu> *s = ctx.get_stream<xpu>();
std::vector<Shape<NUMPY_WHERE_MAX_DIM>> in_strides;
std::vector<Shape<broadcast::MAX_DIM>> in_strides;
in_strides.resize(3);
for (int i = 0; i < 3; ++i) {
TShape expanded_ishape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_ishape(broadcast::MAX_DIM, 1);
const TShape& ishape = inputs[i].shape_;
const int ndim_delta = expanded_ishape.ndim() - ishape.ndim();
for (int j = 0; j < ishape.ndim(); ++j) {
expanded_ishape[j + ndim_delta] = ishape[j];
}
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<NUMPY_WHERE_MAX_DIM>());
in_strides[i] = mxnet_op::calc_stride(expanded_ishape.get<broadcast::MAX_DIM>());
}
TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_oshape(broadcast::MAX_DIM, 1);
const int ndim_delta = expanded_oshape.ndim() - outputs[0].shape_.ndim();
for (int j = 0; j < outputs[0].shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (outputs[0].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> oshape = expanded_oshape.get<NUMPY_WHERE_MAX_DIM>();
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, CType, {
mxnet_op::Kernel<numpy_where_kernel<NUMPY_WHERE_MAX_DIM>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_kernel<broadcast::MAX_DIM>, xpu>::Launch(
s, outputs[0].Size(), req[0],
in_strides[0], in_strides[1], in_strides[2], oshape,
inputs[0].dptr<CType>(), inputs[1].dptr<DType>(),
@@ -173,28 +125,28 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
if (inputs[0].shape_.Size() == 0U) return; // zero-size tensor
Stream<xpu> *s = ctx.get_stream<xpu>();
// get expanded oshape
TShape expanded_oshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_oshape(broadcast::MAX_DIM, 1);
int ndim_delta = expanded_oshape.ndim() - inputs[0].shape_.ndim();
for (int j = 0; j < inputs[0].shape_.ndim(); ++j) {
expanded_oshape[j + ndim_delta] = (inputs[0].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> oshape = expanded_oshape.get<NUMPY_WHERE_MAX_DIM>();
Shape<broadcast::MAX_DIM> oshape = expanded_oshape.get<broadcast::MAX_DIM>();
// get cond stride
TShape expanded_cshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_cshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_cshape.ndim() - inputs[1].shape_.ndim();
for (int j = 0; j < inputs[1].shape_.ndim(); ++j) {
expanded_cshape[j + ndim_delta] = (inputs[1].shape_)[j];
}
Shape<NUMPY_WHERE_MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<NUMPY_WHERE_MAX_DIM>());
Shape<broadcast::MAX_DIM> cstride =
mxnet_op::calc_stride(expanded_cshape.get<broadcast::MAX_DIM>());
// get expanded lshape
TShape expanded_lshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_lshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_lshape.ndim() - outputs[0].shape_.ndim();
for (int j = 0; j < outputs[0].shape_.ndim(); ++j) {
expanded_lshape[j + ndim_delta] = (outputs[0].shape_)[j];
}
// get expanded rshape
TShape expanded_rshape(NUMPY_WHERE_MAX_DIM, 1);
TShape expanded_rshape(broadcast::MAX_DIM, 1);
ndim_delta = expanded_rshape.ndim() - outputs[1].shape_.ndim();
for (int j = 0; j < outputs[1].shape_.ndim(); ++j) {
expanded_rshape[j + ndim_delta] = (outputs[1].shape_)[j];
@@ -203,27 +155,27 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, CType, {
Tensor<xpu, 1, char> largespace;
Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType> workspace;
Tensor<xpu, broadcast::MAX_DIM, DType> workspace;
size_t ws_size = 0;
if (!(inputs[0].shape_ != outputs[0].shape_) || !(inputs[0].shape_ != outputs[1].shape_)) {
size_t ws_size1 = broadcast::ReduceWorkspaceSize<NUMPY_WHERE_MAX_DIM, DType>(
size_t ws_size1 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_lshape, req[0], expanded_oshape);
size_t ws_size2 = broadcast::ReduceWorkspaceSize<NUMPY_WHERE_MAX_DIM, DType>(
size_t ws_size2 = broadcast::ReduceWorkspaceSize<broadcast::MAX_DIM, DType>(
s, expanded_rshape, req[1], expanded_oshape);
ws_size = std::max(ws_size1, ws_size2);
}
// process left output
if (inputs[0].shape_ == outputs[0].shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, true>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, true>, xpu>::Launch(
s, inputs[0].Size(), req[0], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType>(
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<NUMPY_WHERE_MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, true>, xpu>::Launch(
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, true>, xpu>::Launch(
s, inputs[0].Size(), req[0], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(outputs[0].type_flag_, outputs[0].type_flag_)) {
@@ -236,16 +188,16 @@ inline void NumpyWhereOpBackward(const nnvm::NodeAttrs& attrs,
}
// process right output
if (inputs[0].shape_ == outputs[1].shape_) {
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, false>, xpu>::Launch(
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, false>, xpu>::Launch(
s, inputs[0].Size(), req[1], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), outputs[1].dptr<DType>());
} else {
largespace = ctx.requested[0].get_space_typed<xpu, 1, char>(
Shape1(inputs[0].shape_.Size() * sizeof(DType) + ws_size), s);
workspace = Tensor<xpu, NUMPY_WHERE_MAX_DIM, DType>(
workspace = Tensor<xpu, broadcast::MAX_DIM, DType>(
reinterpret_cast<DType*>(largespace.dptr_ + ws_size),
expanded_oshape.get<NUMPY_WHERE_MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<NUMPY_WHERE_MAX_DIM, false>, xpu>::Launch(
expanded_oshape.get<broadcast::MAX_DIM>(), s);
mxnet_op::Kernel<numpy_where_backward_kernel<broadcast::MAX_DIM, false>, xpu>::Launch(
s, inputs[0].Size(), req[1], cstride, oshape,
inputs[1].dptr<CType>(), inputs[0].dptr<DType>(), workspace.dptr_);
if (NeedSafeAcc<true>(outputs[1].type_flag_, outputs[1].type_flag_)) {
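The forward kernel broadcasts by computing, for every flat output index, a per-input offset from broadcast strides in which size-1 axes get stride 0 (the local NUMPY_WHERE_MAX_DIM constant is dropped in favour of the shared broadcast::MAX_DIM). The sketch below is a plain-NumPy illustration of that indexing scheme, not the MXNet kernel itself; MAX_DIM = 5 and the helper names are assumptions made for the example:

```python
import numpy as np

MAX_DIM = 5  # assumed value of broadcast::MAX_DIM at the time of this PR

def padded_strides(shape, ndim=MAX_DIM):
    """Left-pad `shape` with 1s to `ndim` dims and return broadcast strides.
    Size-1 axes get stride 0, so every output index maps to element 0 of
    that axis, which is how the kernel broadcasts its inputs."""
    shape = (1,) * (ndim - len(shape)) + tuple(shape)
    strides, running = [0] * ndim, 1
    for i in reversed(range(ndim)):
        strides[i] = running if shape[i] > 1 else 0
        running *= shape[i]
    return shape, strides

def where_forward(cond, x, y, oshape):
    """Reference forward pass: one flat loop over the output, mirroring the
    kernel (unravel the output index, re-ravel with each input's strides)."""
    oshape_p, _ = padded_strides(oshape)
    _, cs = padded_strides(cond.shape)
    _, xs = padded_strides(x.shape)
    _, ys = padded_strides(y.shape)
    out = np.empty(int(np.prod(oshape)), dtype=x.dtype)
    for flat in range(out.size):
        # unravel the flat index against the (padded) output shape
        idx, rem = [], flat
        for dim in reversed(oshape_p):
            idx.append(rem % dim)
            rem //= dim
        idx = idx[::-1]
        coff = sum(i * s for i, s in zip(idx, cs))
        xoff = sum(i * s for i, s in zip(idx, xs))
        yoff = sum(i * s for i, s in zip(idx, ys))
        out[flat] = x.ravel()[xoff] if cond.ravel()[coff] else y.ravel()[yoff]
    return out.reshape(oshape)
```

On the backward side, the gradient w.r.t. x is the incoming gradient where the condition holds (zero elsewhere) and the gradient w.r.t. y is the complement; when x or y was broadcast, that per-element gradient is additionally summed over the broadcast axes, which is the reduction the workspace above is allocated for.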
46 changes: 46 additions & 0 deletions src/operator/numpy/np_where_op.cc
@@ -28,6 +28,52 @@
namespace mxnet {
namespace op {

inline bool NumpyWhereOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& operand1 = (*in_attrs)[0];
mxnet::TShape& operand2 = (*in_attrs)[1];
mxnet::TShape& operand3 = (*in_attrs)[2];

if (operand1 == operand2 && operand2 == operand3) {
SHAPE_ASSIGN_CHECK(*out_attrs, 0, operand1);
return shape_is_known(out_attrs->at(0));
}
mxnet::TShape out(std::max({operand1.ndim(), operand2.ndim(), operand3.ndim()}), -1);
const int b1 = out.ndim() - operand1.ndim();
const int b2 = out.ndim() - operand2.ndim();
const int b3 = out.ndim() - operand3.ndim();
for (int i = 0; i < out.ndim(); ++i) {
int s1 = 1, s2 = 1, s3 = 1;
if (i >= b1) s1 = operand1[i-b1];
if (i >= b2) s2 = operand2[i-b2];
if (i >= b3) s3 = operand3[i-b3];
if (!(s1 == s2 && s2 == s3)) {
CHECK((s1 == 1 && s2 == 1) || (s1 == 1 && s3 == 1) || (s2 == 1 && s3 == 1) ||
(s1 == 1 && s2 == s3) || (s2 == 1 && s1 == s3) || (s3 == 1 && s1 == s2))
<< "Operands could not be broadcast together.";
out[i] = std::max({s1, s2, s3});
} else {
out[i] = s1;
}
}
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out);
return shape_is_known(out);
}

inline bool NumpyWhereOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 3U)
<< "where operator takes 3 arguments (" << in_attrs->size() << " given)";
CHECK_EQ(out_attrs->size(), 1U);
std::vector<int> sub_in_attrs(in_attrs->begin() + 1, in_attrs->end());
bool flag = ElemwiseType<2, 1>(attrs, &sub_in_attrs, out_attrs);
return flag && (in_attrs->at(0) != -1);
}

NNVM_REGISTER_OP(_npi_where)
.set_num_inputs(3)
.set_num_outputs(1)
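Shape and type inference move out of the header into the .cc file. Below is a minimal NumPy-style sketch of the broadcast rule that NumpyWhereOpShape enforces (right-align the three shapes; along each axis all non-1 sizes must agree and the output takes the maximum); the helper name is illustrative, not part of the PR:

```python
def numpy_where_out_shape(cond_shape, x_shape, y_shape):
    ndim = max(len(cond_shape), len(x_shape), len(y_shape))
    # left-pad with 1s so all three shapes are right-aligned
    shapes = [(1,) * (ndim - len(s)) + tuple(s)
              for s in (cond_shape, x_shape, y_shape)]
    out = []
    for sizes in zip(*shapes):
        non_one = {d for d in sizes if d != 1}
        if len(non_one) > 1:
            raise ValueError("Operands could not be broadcast together.")
        out.append(max(sizes))
    return tuple(out)

# e.g. numpy_where_out_shape((2, 3), (3,), (2, 1)) -> (2, 3)
```

The rewritten NumpyWhereOpType delegates to ElemwiseType<2, 1> on x and y, so the two value inputs must share a dtype, which becomes the output dtype, while the condition only needs a known dtype rather than matching them.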