This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Optimize NMS part 2 #14352

Merged
merged 2 commits into apache:master on Mar 8, 2019

Conversation

@ptrendx (Member) commented Mar 6, 2019

Description

This PR changes the batch_start calculation in the BoxNMSForward op to use a custom kernel, which is much faster than the mshadow-generated one. In the MaskRCNN model it reduces the runtime of that part from 20 ms to 2 us, speeding up single-GPU training by 20% in fp16 mode.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Comments

  • I'm pretty sure that on the CPU path a simple for loop would be much better than the mshadow-generated kernel as well, but since I did not have experimental data, I did not change it. FYI @zhreshold
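For reference, a hypothetical CPU sketch of what such a for loop could look like. The batch_start semantics here are inferred from the kernel excerpt in the review thread (valid_batch_id is assumed sorted by batch, and batch_start[b] holds the index of the first element of batch b, with a trailing entry equal to the total count); the function name and layout are illustrative and not the PR's actual code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative CPU version of the batch_start computation (not the PR's code).
// valid_batch_id must be sorted; batch_start gets num_batch + 1 entries, where
// batch_start[b] is the index of the first element of batch b (empty batches
// point at the next occupied position) and batch_start[num_batch] == n.
std::vector<int32_t> batch_start_cpu(const std::vector<int32_t>& valid_batch_id,
                                     int num_batch) {
  const int32_t n = static_cast<int32_t>(valid_batch_id.size());
  std::vector<int32_t> batch_start(num_batch + 1, n);
  for (int32_t i = 0; i < n; ++i) {
    const int32_t previous = (i > 0) ? valid_batch_id[i - 1] : -1;
    const int32_t current = valid_batch_id[i];
    // Every batch id in (previous, current] starts at position i.
    for (int32_t b = previous + 1; b <= current; ++b) batch_start[b] = i;
  }
  return batch_start;
}
```

The inner loop mirrors the per-thread logic visible in the GPU kernel excerpt (compare previous and current batch ids, fill the gap), just executed sequentially.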

@vandanavk (Contributor) commented:

@mxnet-label-bot add [Operator, pr-awaiting-review]

@marcoabreu marcoabreu added Operator pr-awaiting-review PR is waiting for code review labels Mar 7, 2019
@zhreshold zhreshold mentioned this pull request Mar 7, 2019
    int num_batch) {
      size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid < N) {
        const int32_t previous = tid > 0 ? __ldg(valid_batch_id + tid - 1) : -1;
@arcadiaphy (Member) commented Mar 7, 2019:

Using the __ldg intrinsic will fail to compile on some early CUDA architectures.

@ptrendx (Member, Author) replied:

It will fail on sm 3.0 and earlier (so Fermi and the first Kepler). I can put an ifdef there, but do we care about those?

@ptrendx (Member, Author) replied:
Then we do ;-). I will introduce the guard, thanks!
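A minimal sketch of such a guard, assuming a macro-based approach (the PR's actual guard may differ). __CUDA_ARCH__ is defined only during nvcc device compilation, and __ldg requires compute capability 3.5 or higher, so older architectures and host code fall back to a plain load:

```cpp
#include <cassert>
#include <cstdint>

// Illustrative compile-time guard for __ldg (not the exact PR code).
// On sm_35+ device code, read through the read-only data cache; everywhere
// else (older archs, host compilation), do a normal dereference.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 350
#define LDG(ptr) __ldg(ptr)   // read-only data cache load
#else
#define LDG(ptr) (*(ptr))     // plain load fallback
#endif
```

With a hypothetical helper like this, the excerpted kernel line could read, e.g., `const int32_t previous = tid > 0 ? LDG(valid_batch_id + tid - 1) : -1;`.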

@zhreshold zhreshold merged commit 838e256 into apache:master Mar 8, 2019
vdantu pushed a commit to vdantu/incubator-mxnet that referenced this pull request Mar 31, 2019
* Optimize NMS part 2

* Guarding ldg intrinsics
nswamy pushed a commit that referenced this pull request Apr 5, 2019
* Optimize NMS part 2

* Guarding ldg intrinsics
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* Optimize NMS part 2

* Guarding ldg intrinsics
5 participants