[MXNET-381] Enhancement of take operator (#11326)
* take forward for any axis with enhanced test

* general take backward on gpu

* backward of enhanced take op
haojin2 authored and eric-haibin-lin committed Jul 17, 2018
1 parent e994a35 commit 3051c49
Showing 5 changed files with 452 additions and 112 deletions.
9 changes: 9 additions & 0 deletions src/operator/mshadow_op.h
@@ -601,6 +601,15 @@ struct clip : public mxnet_op::tunable {
return x;
}
}
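/*! \brief clamp x into the closed interval [lower_bound, upper_bound] */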
template<typename DType>
MSHADOW_XINLINE static DType Map(DType x, DType lower_bound, DType upper_bound) {
if (x > upper_bound) {
return upper_bound;
} else if (x < lower_bound) {
return lower_bound;
}
return x;
}
};

/***** gamma ******/
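The new three-argument Map above simply clamps x into [lower_bound, upper_bound]; in the context of this PR it is presumably the kernel used to keep out-of-range indices valid when take runs in 'clip' mode. A minimal standalone sketch of the same logic (illustrative only; ClipToRange and the example values are not part of the diff):

#include <iostream>

// Standalone mirror of the new clip kernel: clamp x into [lower_bound, upper_bound].
template <typename DType>
DType ClipToRange(DType x, DType lower_bound, DType upper_bound) {
  if (x > upper_bound) {
    return upper_bound;
  } else if (x < lower_bound) {
    return lower_bound;
  }
  return x;
}

int main() {
  // For an axis of length 3, valid indices are [0, 2]:
  // index 5 clips to 2, index -1 clips to 0.
  std::cout << ClipToRange(5, 0, 2) << " " << ClipToRange(-1, 0, 2) << std::endl;
  return 0;
}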
4 changes: 3 additions & 1 deletion src/operator/tensor/indexing_op-inl.cuh
@@ -28,6 +28,8 @@
#include <cub/device/device_run_length_encode.cuh>
#include <cub/device/device_scan.cuh>
#include "../mxnet_op.h"
#include "../mshadow_op.h"
#include "./util/tensor_util-inl.cuh"

#if CUDA_VERSION >= 9000
#define FULLMASK 0xFFFFFFFF
@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
- mshadow::Tensor<gpu, 1, char>* workspace) {
+ mshadow::Tensor<gpu, 1, char>* workspace = NULL) {
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(sorted.CheckContiguous(), true);
CHECK_EQ(index.CheckContiguous(), true);
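The only change in this hunk gives the workspace argument a NULL default, so callers that do not pre-allocate scratch space can omit it. For orientation, the operation this family of kernels implements is a scatter-add of incoming gradient rows into the rows of the data gradient: every row taken in the forward pass accumulates its upstream gradient at its source index. A minimal CPU sketch of that accumulation (the sorting, workspace, and GPU details are omitted; all names here are illustrative):

#include <cstdio>
#include <vector>

// CPU model of the take backward accumulation: for each output row i,
// add the upstream gradient row grad_out[i] into grad_data[indices[i]].
// Duplicate indices accumulate, which is why the GPU path works on sorted indices.
void AddTakeGradCpu(std::vector<std::vector<float>>* grad_data,
                    const std::vector<int>& indices,
                    const std::vector<std::vector<float>>& grad_out) {
  for (size_t i = 0; i < indices.size(); ++i) {
    for (size_t j = 0; j < grad_out[i].size(); ++j) {
      (*grad_data)[indices[i]][j] += grad_out[i][j];
    }
  }
}

int main() {
  std::vector<std::vector<float>> grad_data(3, std::vector<float>(2, 0.f));
  std::vector<int> indices = {1, 1, 2};  // rows taken in the forward pass
  std::vector<std::vector<float>> grad_out = {{1, 1}, {2, 2}, {3, 3}};
  AddTakeGradCpu(&grad_data, indices, grad_out);
  // grad_data is now {{0,0},{3,3},{3,3}}: the duplicated index 1 accumulates.
  printf("%g %g\n", grad_data[1][0], grad_data[2][0]);
  return 0;
}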
29 changes: 20 additions & 9 deletions src/operator/tensor/indexing_op.cc
@@ -367,36 +367,46 @@ NNVM_REGISTER_OP(take)
This function slices the input array along a particular axis with the provided indices.
Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output
will have shape ``(i0, i1, d1, d2)``, computed by::

  output[i,j,:,:] = input[indices[i,j],:,:]
.. note::
- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
Given a data tensor of rank r >= 1 and an indices tensor of rank q, this operator gathers entries of
the data's axis dimension (by default the outer-most one, axis=0) indexed by indices, and concatenates
them into an output tensor of rank q + (r - 1).
Examples::
x = [4. 5. 6.]
// Trivial case, take the second element along the first axis.
take(x, [1]) = [ 5. ]
// The other trivial case: axis=-1 addresses the last (here the only) axis, and index 3 is clipped to the last element.
take(x, [3], axis=-1, mode='clip') = [ 6. ]
x = [[ 1.,  2.],
     [ 3.,  4.],
     [ 5.,  6.]]
// In this case we will get rows 0 and 1, then rows 1 and 2, along axis 0.
take(x, [[0,1],[1,2]]) = [[[ 1.,  2.],
                           [ 3.,  4.]],
                          [[ 3.,  4.],
                           [ 5.,  6.]]]
// In this case we will get columns 0 and 1, then 1 and 0 (the out-of-range
// indices are wrapped around). Along axis 1.
take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap') = [[[ 1.,  2.],
                                                     [ 2.,  1.]],
                                                    [[ 3.,  4.],
                                                     [ 4.,  3.]],
                                                    [[ 5.,  6.],
                                                     [ 6.,  5.]]]
)code" ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
- .set_attr_parser(TakeParamParser<TakeParam>)
+ .set_attr_parser(ParamParser<TakeParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "indices"};
@@ -420,6 +430,7 @@ Examples::
NNVM_REGISTER_OP(_backward_take)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr_parser(ParamParser<TakeParam>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
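The docstring above states the general rule for the enhanced operator: with data of rank r and indices of rank q, take along a given axis produces an output of rank q + (r - 1), and out-of-range indices are either clipped or wrapped depending on mode. A small CPU sketch of that rule for 2-D data in 'wrap' mode, reproducing the axis=1 example from the docstring (illustrative only; Take2D is not part of the codebase):

#include <cstdio>
#include <vector>

// Illustrative CPU model of take on 2-D data with mode='wrap'.
// Output rank = indices rank + data rank - 1 (here 2 + 2 - 1 = 3).
std::vector<float> Take2D(const std::vector<std::vector<float>>& x,
                          const std::vector<std::vector<int>>& idx,
                          int axis) {
  const int d0 = static_cast<int>(x.size());
  const int d1 = static_cast<int>(x[0].size());
  const int i0 = static_cast<int>(idx.size());
  const int i1 = static_cast<int>(idx[0].size());
  const int dim = (axis == 0) ? d0 : d1;
  auto wrap = [dim](int i) { return ((i % dim) + dim) % dim; };  // 'wrap' mode
  std::vector<float> out;
  if (axis == 0) {
    // output shape (i0, i1, d1): out[a, b, :] = x[idx[a, b], :]
    for (int a = 0; a < i0; ++a)
      for (int b = 0; b < i1; ++b)
        for (int c = 0; c < d1; ++c)
          out.push_back(x[wrap(idx[a][b])][c]);
  } else {
    // output shape (d0, i0, i1): out[a, b, c] = x[a, idx[b, c]]
    for (int a = 0; a < d0; ++a)
      for (int b = 0; b < i0; ++b)
        for (int c = 0; c < i1; ++c)
          out.push_back(x[a][wrap(idx[b][c])]);
  }
  return out;
}

int main() {
  std::vector<std::vector<float>> x = {{1, 2}, {3, 4}, {5, 6}};
  // Matches the docstring example: take(x, [[0, 3], [-1, -2]], axis=1, mode='wrap')
  std::vector<float> out = Take2D(x, {{0, 3}, {-1, -2}}, 1);
  for (float v : out) printf("%g ", v);  // 1 2 2 1 3 4 4 3 5 6 6 5
  printf("\n");
  return 0;
}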