This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Use RTC for elementwise and broadcast ops #18622

Merged
merged 76 commits into from
Aug 20, 2020

Conversation

ptrendx
Member

@ptrendx ptrendx commented Jun 25, 2020

Description

As described in #18280 (comment), MXNet currently contains too many CUDA kernels, which negatively affects compile time, the size of the resulting binary (leading to issues like #17045 and #18205), and GPU memory consumption (as all of those kernels need to be loaded into GPU memory during the first GPU context creation).

The reason for those problems is the number of templates that need to be instantiated, especially in the case of NumPy operators, which need to accept different input/output types. This results in multiple nested MSHADOW_TYPE_SWITCH macros and a large increase in the number of generated kernels, most of which are practically never used. For example, executing this command:

cuobjdump -symbols -arch sm_70 /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so | grep GLOBAL | wc -l

on the nightly build of mxnet-cu102 from 6/25 shows 69169 kernels (the same command executed on the library built with this PR at the time of writing gives 51511 kernels).

The proposed approach is to use RTC (runtime compilation) to generate the needed kernels at runtime. This reduces the ahead-of-time compilation time and binary size as well as GPU memory utilization (since only the kernels that are actually needed are generated, not all combinations).
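
For context, the general NVRTC flow that runtime compilation relies on looks roughly like the sketch below. This is illustrative only and not the code added by this PR (the actual implementation lives in src/common/cuda/rtc.cc); the kernel name add_one and all variable names are made up, error checking is omitted, and the program would be built with -lnvrtc -lcuda.

// Minimal NVRTC sketch: compile a kernel from a source string at runtime and launch it.
#include <cuda.h>
#include <nvrtc.h>
#include <string>

int main() {
  const char* src =
      "extern \"C\" __global__ void add_one(float* x, int n) {\n"
      "  int i = blockIdx.x * blockDim.x + threadIdx.x;\n"
      "  if (i < n) x[i] += 1.0f;\n"
      "}\n";

  // Compile the source string to PTX.
  nvrtcProgram prog;
  nvrtcCreateProgram(&prog, src, "add_one.cu", 0, nullptr, nullptr);
  const char* opts[] = {"--gpu-architecture=compute_70"};
  nvrtcCompileProgram(prog, 1, opts);
  size_t ptx_size;
  nvrtcGetPTXSize(prog, &ptx_size);
  std::string ptx(ptx_size, '\0');
  nvrtcGetPTX(prog, &ptx[0]);
  nvrtcDestroyProgram(&prog);

  // Load the PTX with the driver API and launch the compiled kernel.
  cuInit(0);
  CUdevice dev;  CUcontext ctx;  CUmodule mod;  CUfunction func;
  cuDeviceGet(&dev, 0);
  cuCtxCreate(&ctx, 0, dev);
  cuModuleLoadData(&mod, ptx.c_str());
  cuModuleGetFunction(&func, mod, "add_one");

  int n = 1024;
  CUdeviceptr d_x;
  cuMemAlloc(&d_x, n * sizeof(float));
  void* args[] = {&d_x, &n};
  cuLaunchKernel(func, (n + 255) / 256, 1, 1, 256, 1, 1, 0, nullptr, args, nullptr);
  cuCtxSynchronize();

  cuMemFree(d_x);
  cuModuleUnload(mod);
  cuCtxDestroy(ctx);
  return 0;
}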

To test the impact on binary size and memory consumption, I compiled MXNet for 5 GPU architectures (5.2, 6.0, 6.1, 7.0, 7.5) using CUDA 11, both from the head of this PR and from the latest master commit it includes (f872b43).
Binary size reduction: 292 MB (from 2 GB to 1.7 GB)
Idle GPU memory consumption reduction: 96 MB (from 1442 MB to 1346 MB)
Idle GPU memory consumption was checked by launching the Python interpreter and measuring GPU memory usage after calling:

import mxnet as mx
a = mx.nd.zeros((1,), ctx=mx.gpu())

This PR uses that approach to handle elementwise and broadcast kernels (as well as their backward passes), which constitute a large portion of the total number of kernels in MXNet.

FYI @leezu @sxjscience @eric-haibin-lin

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented:
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • RTC is now required for using CUDA in MXNet
  • Unary, binary, binary with scalar, binary broadcast ops and their backward counterparts were changed to use RTC

Comments

  • Things left to do:

    • Test the performance impact of cache lookup for kernel code
    • Convert MixedUnaryBackward functions
    • Update PR description with the change in GPU memory utilization and binary size resulting from this PR
  • After this PR, the next step would be to use the same approach for reduce kernels. This PR already contains the groundwork for that, as reduction was needed for the backward of broadcast ops, but it does not apply that path to standalone reduction ops. Grepping for reduce_kernel in the symbols visible in libmxnet.so after applying this PR:

cuobjdump -symbols -arch sm_70 /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so | grep GLOBAL | grep reduce_kernel | wc -l

gives 12057 entries. This would also help reduce the amount of code duplication that this PR introduces (in order to maintain both RTC and non-RTC paths).

@mxnet-bot

Hey @ptrendx, thanks for submitting the PR.
All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:

  • To trigger all jobs: @mxnet-bot run ci [all]
  • To trigger specific jobs: @mxnet-bot run ci [job1, job2]

CI supported jobs: [unix-cpu, windows-gpu, centos-cpu, unix-gpu, windows-cpu, edge, miscellaneous, website, centos-gpu, sanity, clang]


Note:
Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
All CI tests must pass before the PR can be merged.

@leezu
Contributor

leezu commented Jun 25, 2020

Thank you @ptrendx! As this makes the nvrtc feature mandatory, it may be necessary to also fix #17858 prior to the next release? It seems that a number of users try to use GPU builds on machines without libcuda.so, and this has been broken since 1.6 due to #17858 (though currently there is the workaround of disabling nvrtc).

@ptrendx
Member Author

ptrendx commented Jun 26, 2020

Yeah, we will need to sort it out before the release. I actually thought that if we managed to get everything via RTC (which seems daunting, though), we could dynamically load both libcuda and libnvrtc and have a single build that supports everything instead of the per-version mxnet-cu* packages. That said, RTC for everything is a big task and I would definitely need help from the community if we were to pull it off.
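
For reference, a minimal sketch of what such dynamic loading could look like (illustrative only; this is not part of this PR, the fallback message is made up, and a real implementation would need to resolve many more symbols and handle errors properly):

#include <dlfcn.h>
#include <cstdio>

// Resolve the CUDA driver at runtime instead of linking against libcuda.so
// at build time, so the same binary could still start on machines without
// a GPU driver installed.
typedef int (*cuInit_t)(unsigned int);  // CUresult cuInit(unsigned int Flags)

int main() {
  void* libcuda = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_LOCAL);
  if (libcuda == nullptr) {
    std::printf("No CUDA driver found; fall back to CPU-only mode\n");
    return 0;
  }
  auto cu_init = reinterpret_cast<cuInit_t>(dlsym(libcuda, "cuInit"));
  if (cu_init != nullptr && cu_init(0) == 0 /* CUDA_SUCCESS */) {
    std::printf("CUDA driver loaded and initialized\n");
  }
  dlclose(libcuda);
  return 0;
}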

@ptrendx
Member Author

ptrendx commented Jun 26, 2020

Also, I thought that you made the compilation use C++17 (so I used if constexpr), but now I see that the CUDA part is compiled with C++14 (and fails). What is the reason for that?

@leezu
Contributor

leezu commented Jun 26, 2020

Also, I thought that you made the compilation use C++17 (so I used if constexpr), but now I see that the CUDA part is compiled with C++14 (and fails). What is the reason for that?

The reason is that CUDA does not support C++17 prior to CUDA 11, so the CUDA files are compiled with C++14. We can consider requiring CUDA 11 for MXNet 2.
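
As a small illustration of what breaks (not code from this PR; the function names below are made up): if constexpr discards the untaken branch at compile time, which C++14 has to emulate with tag dispatch or specialization.

#include <type_traits>

// C++17: the branch not taken is discarded at compile time, so it does not
// have to be valid for the given T, and the return type is deduced only
// from the taken branch.
template <typename T>
auto to_float_cxx17(T x) {
  if constexpr (std::is_integral<T>::value) {
    return static_cast<float>(x);
  } else {
    return x;
  }
}

// C++14 emulation with tag dispatch: both overloads must compile on their
// own, and the right one is selected by the std::is_integral tag.
template <typename T>
float to_float_impl(T x, std::true_type) { return static_cast<float>(x); }

template <typename T>
T to_float_impl(T x, std::false_type) { return x; }

template <typename T>
auto to_float_cxx14(T x) {
  return to_float_impl(x, std::is_integral<T>{});
}

int main() {
  float a = to_float_cxx17(3);     // int -> float
  float b = to_float_cxx14(3);     // same result, C++14-compatible pattern
  double c = to_float_cxx17(2.5);  // non-integral types pass through
  (void)a; (void)b; (void)c;
  return 0;
}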

Fixes for mixed type gradient functions
Set the launch bounds on RTC kernels
@ptrendx ptrendx changed the title [WIP] Use RTC for elementwise and broadcast ops Use RTC for elementwise and broadcast ops Aug 6, 2020
@ptrendx ptrendx added pr-awaiting-review PR is waiting for code review and removed pr-work-in-progress PR is still work in progress labels Aug 6, 2020
Member

@eric-haibin-lin eric-haibin-lin left a comment


I looked at the performance impact of RTC and it adds ~2us of CPU time to the launch, mostly due to string manipulation.

If we use cuda graph to cache all the kernels to launch, would this overhead be mitigated?

@ptrendx
Member Author

ptrendx commented Aug 7, 2020

@eric-haibin-lin Yes. The overhead comes from preparing a string with kernel options (like the datatypes) and searching for the kernel function in cache. CUDA graph caches the resulting function so the lookup does not occur anymore.

That said, this overhead is lower than the overhead of cudaLaunchKernel itself and is barely noticeable. I tried it with a worst-case scenario of a fully hybridized model that was adding single-element tensors (to be 100% CPU limited) and got a ~10% slowdown. A more realistic workload, with kernels taking longer than a few us, would not show any difference. The same CPU-limited test with a non-hybridized model did not show a noticeable slowdown (the overheads of imperative mode are much higher than this).
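
For reference, the lookup described above boils down to building a string key from the kernel's compile-time parameters and consulting a map of already-compiled functions. A rough sketch of that pattern (illustrative only; the real logic lives in src/common/cuda/rtc.cc and differs in the details):

#include <cuda.h>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>

// Illustrative kernel cache: the key encodes the kernel name plus the
// compile-time parameters (dtypes, OpReqType, vector length, ...) that are
// stringified for NVRTC. On a hit, no compilation happens and the only CPU
// cost is building the key and the hash lookup.
class KernelCache {
 public:
  CUfunction GetOrCompile(const std::string& kernel_name,
                          const std::string& parameters,
                          const std::function<CUfunction()>& compile) {
    const std::string key = kernel_name + "|" + parameters;
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return it->second;  // cache hit: skip NVRTC entirely
    }
    CUfunction func = compile();  // cache miss: compile via NVRTC once
    cache_.emplace(key, func);
    return func;
  }

 private:
  std::unordered_map<std::string, CUfunction> cache_;
  std::mutex mutex_;
};

With CUDA graphs, the captured launches store the resulting CUfunction directly, so even this lookup is skipped on replay.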

@ptrendx
Member Author

ptrendx commented Aug 10, 2020

@mxnet-bot run ci [unix-cpu]

@mxnet-bot

Jenkins CI successfully triggered : [unix-cpu]

@@ -47,6 +47,12 @@ The following tutorials will help you learn how to customize MXNet.
How to create new MXNet operators in MXNet's backend using C++.
An example custom quadratic function op.

.. card::
   :title: Using runtime compilation (RTC) to write CUDA kernels in MXNet
   :link: https://mxnet.apache.org/api/faq/using_rtc
Member


use :link: /api/faq/using_rtc instead as the documentation is versioned.

Member Author


Ok - I was just copy-pasting from the other sections there (like add_op_in_backend). Will update those as well.

Member Author


Hmm, the toctree also has the full links - will /api/faq/... approach work there too?

Member Author


Yeah, I tried putting those relative links in the toctree as well and the Python docs build complains with

/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/new_op'                                                              
/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/add_op_in_backend'                           
/work/mxnet/docs/python_docs/python/build/tutorials/extend/index.rst:57: WARNING: toctree contains reference to nonexisting document 'api/faq/using_rtc'

and those entries are not shown in the final website.

Member


I'm currently fixing a number of issues in #18839. You may get a conflict from this.

Member Author


Nice :-). I will push the changes to address your other comments then and will get back to the website part once your PR is merged.

Member Author


#18839 is merged, but it did not tackle the issue of toctree - should I leave the links as they are right now @szha?

@szha szha merged commit 29d6f27 into apache:master Aug 20, 2020
@ZiyueHuang
Member

After this PR, training the Electra model in gluon-nlp raises the error below:

mxnet.base.MXNetError: Traceback (most recent call last):
  File "/home/ubuntu/mxnet/src/common/cuda/rtc.cc", line 163
MXNetError: Check failed: compileResult == NVRTC_SUCCESS (6 vs. 0) : NVRTC Compilation failed.
The generated code was stored in mxnet_rtc_debug_code.log
binary_scalar_kernel_kernel.cu(1118): error: more than one instance of overloaded function "isnan" matches the argument list:
            function "isnan(float)"
            function "isnan(double)"
            function "isnan(long double)"
            argument types are: (const InputType0)
          detected during instantiation of "type_util::mixed_type<DType, DType2, void>::type op::min(DType, DType2) [with DType=InputType0, DType2=InputType0]"
(2285): here

Setting MXNET_RTC_VERBOSE=1 gives the generated code for binary_scalar_kernel:

using InputType0 = float32;
using OutputType0 = float32;
const bool aligned = true;
const int nvec = 4;
const OpReqType req = OpReqType::kWriteTo;
#define OP op::div


struct binary_scalar_kernel_params {
  const void *inputs[2];
  void *outputs[1];
  double scalar;
};

__launch_bounds__(kRTCMaxThreadsPerBlock)
__global__ void binary_scalar_kernel(const binary_scalar_kernel_params params,
                                     const index_t lead_dim,
                                     const index_t other_dim,
                                     const index_t N,
                                     const index_t num_aligned_elements) {
  using namespace vector;
  VectorizedLoader<InputType0, nvec, aligned> loader(
    reinterpret_cast<const InputType0*>(params.inputs[0]), N);
  VectorizedStorer<OutputType0, nvec, aligned> storer(
    reinterpret_cast<OutputType0*>(params.outputs[0]), N);

  using IType = AccType<InputType0>;
  using OType = AccType<OutputType0>;

  const index_t M = num_aligned_elements;

  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    loader.load(tid, N);
    if (req == OpReqType::kAddTo) {
      storer.load(tid, N);
    }
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const auto input = IType::from(loader.separate()[i]);
      // enables returning different type
      const auto temp = OP(input,
                           static_cast<typename type_util::mixed_type<typename IType::type,
                                                                      typename OType::type>::type>
                             (params.scalar));

      if (req == OpReqType::kAddTo) {
        // temp2 may have a wider type than either temp
        // or OType
        const auto temp2 = op::add(temp, OType::from(storer.separate()[i]));
        storer.separate()[i] = OType::to(temp2);
      } else {
        storer.separate()[i] = OType::to(temp);
      }
    }
    storer.store(tid, N);
  }
}

@ptrendx Could you please take a look?

@ptrendx
Member Author

ptrendx commented Aug 22, 2020

Yes, I will look into it.

@ptrendx ptrendx mentioned this pull request Aug 22, 2020
@ptrendx
Member Author

ptrendx commented Aug 22, 2020

@ZiyueHuang Please try with PR #18984.

@ZiyueHuang
Member

@ptrendx It works. Thanks for the fix.
