Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Workaround problem with fusion in CUDA 9 #17028

Merged
merged 1 commit into from
Dec 10, 2019
Merged

Conversation

ptrendx
Copy link
Member

@ptrendx ptrendx commented Dec 9, 2019

Description

Fixes #17020

The problem comes from the bug in how NVRTC in CUDA 9 handles the default-device flag. That flag is supposed to mark all the functions in the file as __device__ functions, but it should leave the functions decorated differently (like kernels decorated with __global__) alone. This is the behavior in CUDA 10+. In CUDA 9, however, this __device__ attribute is applied to every function (including kernels), which is incompatible with __launch_bounds__() attribute that we use for kernels.

This PR removes the usage of default-device flag for NVRTC compilation and instead manually decorates all the required functions as __device__

Copy link
Contributor

@DickJC123 DickJC123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a straightforward fix to the nvrtc problem. I'm sure it was tedious to add __device__ in the many places required. What might help the file size and code duplication is the macro:

// Create a fast-math function by appending an 'f' to the canonical name.
#define DEFINE_FASTMATH_FUNC(func)  \
  template <typename DType> \
  __device__ inline DType func(const DType val) { \
    return func ## f(val); \
  }

DEFINE_FASTMATH_FUNC(sin)
... etc

This only applies to some of the many functions defined in fused_op-inl.h, so I approve this PR independent of this suggestion.

@ptrendx ptrendx merged commit 9f5b8bc into apache:master Dec 10, 2019
@perdasilva
Copy link
Contributor

Awesome stuff! Thank you @ptrendx and @DickJC123 - this puts CD back on track ^^

ptrendx added a commit to ptrendx/mxnet that referenced this pull request Dec 10, 2019
@ptrendx ptrendx mentioned this pull request Dec 10, 2019
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[CUDA 9.0] NVRTC Compilation failed
3 participants