
[MXNET-381] Enhancement of take operator #11326

Merged
merged 3 commits into apache:master on Jul 17, 2018

Conversation

haojin2
Contributor

@haojin2 haojin2 commented Jun 18, 2018

Description

Previously, our take operator only supported the axis=0 and mode='clip' case. This PR adds support for any axis in the range [-r, r-1] (where r is the rank of the input array) and an additional mode, 'wrap'.
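For reference, the two index-handling modes behave as follows; this is an illustrative sketch in plain C++ (the helper name is made up and not part of the PR):

```cpp
#include <algorithm>

// Map a possibly out-of-range index onto [0, n) according to the mode.
// 'clip' pins out-of-range indices to the nearest edge, while 'wrap' treats
// the axis as circular, so -1 refers to the last element along that axis.
inline int NormalizeIndex(int i, int n, bool wrap) {
  if (wrap) {
    int j = i % n;
    return j < 0 ? j + n : j;              // e.g. i = -1, n = 4  ->  3
  }
  return std::min(std::max(i, 0), n - 1);  // e.g. i = 7,  n = 4  ->  3
}
```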

Checklist

Essentials

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Support for take on any dimension
  • New mode 'wrap' for indices
  • Unit tests for enhanced take operator

Comments

The legacy implementation for axis=0 and mode='clip' is still preserved to ensure there's no performance or accuracy regression after the enhancement.

@haojin2 haojin2 changed the title [MXNET-381] Enhancement of take operator [MXNET-381] [WIP] [DO NOT MERGE] [DO NOT REVIEW] Enhancement of take operator Jun 18, 2018
@haojin2 haojin2 changed the title [MXNET-381] [WIP] [DO NOT MERGE] [DO NOT REVIEW] Enhancement of take operator [MXNET-381] Enhancement of take operator Jun 18, 2018
@haojin2
Contributor Author

haojin2 commented Jun 18, 2018

@reminisce @piiswrong @anirudh2290 @rahul003 @eric-haibin-lin Please give a review when you have time, thanks!

@haojin2 haojin2 force-pushed the take_op_enhance branch 2 times, most recently from 28b8007 to a6ed57c Compare June 18, 2018 23:30
@zheng-da
Contributor

@junrushao1994 you might want to keep an eye on this PR.

@haojin2
Contributor Author

haojin2 commented Jun 25, 2018

@piiswrong @reminisce @anirudh2290 @rahul003 ping for review

@haojin2 haojin2 force-pushed the take_op_enhance branch 2 times, most recently from 229897a to b4b5af3 Compare June 27, 2018 17:36
.set_default(0)
.describe("The axis of input array to be taken.");
.describe("The axis of input array to be taken."
Member

@anirudh2290 anirudh2290 Jun 27, 2018

numpy's take currently has 'raise' as the default mode, but it doesn't seem to be supported in MXNet at the moment. We could also make 'raise' the default, but that would be a breaking change. We should add an issue to support it for the 2.0 release.

Contributor Author

I totally agree, but it's worth noting that adding a 'raise' mode may impact performance a bit, since you need another kernel to check whether all indices are within the legal range.

Member

You don't need another kernel; you can do it inside the same one. You already have the bounds check for the indices inside the Take kernel, so you can just maintain state recording whether the bounds check passed or failed.
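A minimal sketch of that approach, for illustration only (the struct name and the invalid_flag parameter are hypothetical, not the PR's actual kernel):

```cpp
// Reuse the per-element bounds check that the Take kernel already performs to
// record whether any index fell outside [0, axis_dim).
template <typename DType, typename IType>
struct TakeWithRaiseCheck {
  MSHADOW_XINLINE static void Map(int i, DType* out, const DType* in,
                                  const IType* idx, const int axis_dim,
                                  int* invalid_flag) {
    int j = static_cast<int>(idx[i]);
    if (j < 0 || j >= axis_dim) {
      *invalid_flag = 1;             // remember that a bounds check failed
      j = j < 0 ? 0 : axis_dim - 1;  // clip so the copy below stays in range
    }
    out[i] = in[j];
  }
};
// After the launch, the host reads back invalid_flag and raises if it is set.
```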

oshape[i + idxshape.ndim()] = arrshape[i + 1];
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
for (int i = 0; i < static_cast<int>(idxshape.ndim()); ++i) {
Member

we can use index_t here and avoid static_cast

Contributor Author

good catch, will do.

for (int i = 0; i < static_cast<int>(idxshape.ndim()); ++i) {
oshape[i + actual_axis] = idxshape[i];
}
for (int i = 0; i < static_cast<int>(arrshape.ndim()); i++) {
Member

we can use index_t here and avoid static_cast

Contributor Author

good catch, will do.

}
for (size_t i = 0; i < arrshape.ndim() - 1; i++) {
oshape[i + idxshape.ndim()] = arrshape[i + 1];
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
Member

index_t here?

Contributor Author

Will do
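Putting the shape logic from the snippets above together, the general-axis output shape computation looks roughly like this; a sketch using index_t loop variables as suggested (variable names follow the diff context, and the enclosing function is omitted):

```cpp
// Output shape of take(arr, idx, axis) is
// arrshape[0:axis] ++ idxshape ++ arrshape[axis+1:].
const index_t actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);
TShape oshape(idxshape.ndim() + arrshape.ndim() - 1);
for (index_t i = 0; i < actual_axis; ++i) {
  oshape[i] = arrshape[i];                        // dims before the taken axis
}
for (index_t i = 0; i < idxshape.ndim(); ++i) {
  oshape[i + actual_axis] = idxshape[i];          // dims contributed by the indices
}
for (index_t i = actual_axis + 1; i < arrshape.ndim(); ++i) {
  oshape[i + idxshape.ndim() - 1] = arrshape[i];  // dims after the taken axis
}
```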

const int in_ndims, const int out_ndims, const int idx_ndims,
const int axis_dim, const int axis) {
// i is the global flattened index in the output
const int out_head_index = (axis == 0) ? 0 : (i / out_stride[axis - 1]);
Member

IType can be used for all indexes here.

Contributor Author

There's a possibility that IType is a floating-point type, in which case the compiler would complain. That's also the reason the legacy Map function above uses a cast.

Member

okay makes sense.

@@ -389,7 +389,7 @@ Examples::
)code" ADD_FILELINE)
Member

"Given an input array with shape ``(d0, d1, d2)`` and indices with shape ``(i0, i1)``, the output ..." — this only holds true for axis=0, right?

Contributor Author

Will update that doc.

MSHADOW_TYPE_SWITCH(idx.type_flag_, IType, {
// get size of temporary storage for sort
char* temp_storage_ptr = nullptr;
int* src_indptr_ptr = nullptr;
Member

Can we use dim_t instead of int here?

Contributor Author

Since we're using cub's DeviceHistogram for histogramming the indices here, we need to stick to int32; int32 should suffice for now. Alternatively, we could switch to our own histogram kernel, which supports all types, but that would be slower than cub's implementation.
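For context, cub's histogram API follows the usual two-phase CUB pattern and uses int32-style counters; a generic sketch of how it is typically invoked is below (stream, num_bins, and the exact buffer layout are assumptions, not necessarily what this PR does):

```cpp
// Phase 1: query the temporary-storage size with a null pointer.
// Phase 2: run the histogram of the (sorted) indices into src_indptr_ptr.
const int num_bins = static_cast<int>(arrshape[axis]);
size_t temp_storage_bytes = 0;
cub::DeviceHistogram::HistogramEven(nullptr, temp_storage_bytes,
                                    sorted_idx_ptr,   // int32 samples
                                    src_indptr_ptr,   // int32 bin counters
                                    num_bins + 1,     // num_levels = bins + 1
                                    0, num_bins,      // value range [0, num_bins)
                                    static_cast<int>(idxshape.Size()), stream);
// ... carve temp_storage_bytes out of the op's workspace ...
cub::DeviceHistogram::HistogramEven(temp_storage_ptr, temp_storage_bytes,
                                    sorted_idx_ptr, src_indptr_ptr,
                                    num_bins + 1, 0, num_bins,
                                    static_cast<int>(idxshape.Size()), stream);
```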

s, idxshape.Size(), sorted_idx_ptr, sorted_idx_ptr, static_cast<int>(arrshape[axis]));
}
Tensor<cpu, 1, int> original_idx(original_idx_ptr, Shape1(idxshape.Size()), s);
Tensor<cpu, 1, char> temp_storage(temp_storage_ptr, Shape1(temp_storage_bytes), s);
Member

We can move this to the start, use temp_storage.dptr_ wherever the pointer is needed, and remove temp_storage_ptr.

Contributor Author

Sorry that I did not notice this comment earlier. The tensor is purely for the SortByKey function call, so keeping its declaration closer to the call makes more sense.

Member

You can also keep it in the same place. I am essentially suggesting that temp_storage_ptr does not seem to be required and can be removed.

Contributor Author

It's used as a shorthand for the calculated pointer within the whole workspace pool: https://github.com/apache/incubator-mxnet/pull/11326/files#diff-ed06b8d9798aca630313f2a9dd3fcd68R950

Member

You can do the following:

Tensor<cpu, 1, char> temp_storage(workspace.dptr_ + 2 * original_idx_bytes + src_indptr_bytes, Shape1(temp_storage_bytes), s);

and use temp_storage or temp_storage.dptr_ for the pointer.

Contributor Author

okay
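For readers following the thread: the underlying pattern is to request a single workspace blob and carve each buffer out of it by byte offset, constructing tensors directly over the slices. A simplified sketch consistent with the names in the diff (the exact sizes and ordering in the PR may differ):

```cpp
// One contiguous char workspace holds the index buffers, the row-pointer
// buffer, and cub's temporary storage, laid out back to back.
char* base = workspace.dptr_;
Tensor<cpu, 1, int> original_idx(reinterpret_cast<int*>(base),
                                 Shape1(idxshape.Size()), s);
Tensor<cpu, 1, int> sorted_idx(reinterpret_cast<int*>(base + original_idx_bytes),
                               Shape1(idxshape.Size()), s);
int* src_indptr_ptr = reinterpret_cast<int*>(base + 2 * original_idx_bytes);
Tensor<cpu, 1, char> temp_storage(base + 2 * original_idx_bytes + src_indptr_bytes,
                                  Shape1(temp_storage_bytes), s);
// temp_storage.dptr_ now points into the workspace, so a separate
// temp_storage_ptr variable is unnecessary.
```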

});
});
}

#ifdef __CUDACC__
Member

Can this GPU-specific code be moved to a cuh file?

Contributor Author

I would like to reuse the kernel here; if I move it to a cuh file, the CPU compiler will not be able to see that kernel.

- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
- `axis`- Could be from -r to r-1 where r is the rank of input tensor
- `mode`- Could be either `clip` or `wrap`.
Member

You can move this explanation to the respective arguments and delete the note.

Contributor Author

Done

- `axis`- Only slicing along axis 0 is supported for now.
- `mode`- Only `clip` mode is supported for now.
- `axis`- Could be from -r to r-1 where r is the rank of input tensor
- `mode`- Could be either `clip` or `wrap`.

Examples::
Member

Contributor Author

That was due to a lack of extra blank lines; fixed.

@haojin2 haojin2 force-pushed the take_op_enhance branch 2 times, most recently from 27feafd to e8e51d7 Compare July 2, 2018 21:50
@haojin2
Contributor Author

haojin2 commented Jul 3, 2018

@piiswrong @reminisce please give a review when you have time, thanks!

@@ -272,7 +274,7 @@ inline void AddTakeGradLargeBatch(mshadow::Tensor<gpu, 2, DType> dst,
const mshadow::Tensor<gpu, 1, IndexType>& sorted,
const mshadow::Tensor<gpu, 1, IndexType>& index,
const mshadow::Tensor<gpu, 2, DType> &src,
mshadow::Tensor<gpu, 1, char>* workspace) {
mshadow::Tensor<gpu, 1, char>* workspace = NULL) {
Contributor

nit: NULL -> nullptr. NULL has more semantic meanings than nullptr and should be considered deprecated as of C++11.

Contributor Author

will change.

@haojin2
Contributor Author

haojin2 commented Jul 11, 2018

@piiswrong Please give a review once you have a minute.

@eric-haibin-lin eric-haibin-lin merged commit 3051c49 into apache:master Jul 17, 2018
KellenSunderland pushed a commit to KellenSunderland/incubator-mxnet that referenced this pull request Jul 19, 2018
* take forward for any axis with enhanced test

* general take backward on gpu

* backward of enhanced take op
@haojin2 haojin2 deleted the take_op_enhance branch July 19, 2018 20:12
KellenSunderland pushed a commit to KellenSunderland/incubator-mxnet that referenced this pull request Jul 21, 2018
* take forward for any axis with enhanced test

* general take backward on gpu

* backward of enhanced take op
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
* take forward for any axis with enhanced test

* general take backward on gpu

* backward of enhanced take op