
Integrate MKLDNN Conv1d and support 3d layout #13530

Merged
14 commits merged into apache:master on Jan 2, 2019

Conversation

xinyu-intel
Contributor

@xinyu-intel xinyu-intel commented Dec 4, 2018

Description

This PR aims to integrate MKLDNN Conv1d and enable 3d layout for Conv and Activation.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at https://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • MKLDNN Conv1d
  • MKLDNN 3d layout for Conv and Activation
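
For context, a minimal sketch of the kind of user code this enables on the CPU path: a gluon Conv1D plus Activation forward pass on a 3d (NCW) input. The layer sizes and shapes below are illustrative, not taken from this PR.

```python
import mxnet as mx
from mxnet import gluon, nd

net = gluon.nn.Sequential()
net.add(gluon.nn.Conv1D(channels=32, kernel_size=3, padding=1))
net.add(gluon.nn.Activation('relu'))
net.initialize()

# 3d NCW input: (batch, channels, width). With this PR, both Conv1D and the
# following Activation can take the MKL-DNN path instead of the native CPU one.
x = nd.random.uniform(shape=(8, 16, 100))
y = net(x)
print(y.shape)  # (8, 32, 100)
```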

@pengzhao-intel @TaoLv


@TaoLv
Member

TaoLv commented Dec 4, 2018

Related issues: #11906, #11161

@TaoLv
Member

TaoLv commented Dec 5, 2018

@mxnet-label-bot add [MKLDNN, Operator, pr-awaiting-review]

@marcoabreu marcoabreu added MKLDNN Operator pr-awaiting-review PR is waiting for code review labels Dec 5, 2018
Member

@TaoLv TaoLv left a comment


Please point out where the unit tests for this feature are. Are they covered by the existing CI? Note that activation is also changed in this PR.

Review comments on:
src/operator/nn/activation.cc
src/operator/nn/mkldnn/mkldnn_act.cc
src/operator/nn/mkldnn/mkldnn_base-inl.h
src/operator/nn/mkldnn/mkldnn_base.cc
src/operator/nn/mkldnn/mkldnn_convolution.cc
@xinyu-intel
Contributor Author

@TaoLv Address Tao's comments. These changes are already covered by test_operator:test_convolution_grouping and test_operator:test_activation. I checked that they run into the MKL-DNN kernel by checking the mkldnn_verbose output.
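
For reference, a minimal sketch of that kind of check, assuming the build honors the MKLDNN_VERBOSE environment variable; the shapes are illustrative:

```python
import os
# Must be set before MXNet loads the MKL-DNN library.
os.environ["MKLDNN_VERBOSE"] = "1"

import mxnet as mx

# A 3d (NCW) input exercises the Conv1d path; mkldnn_verbose lines printed to
# stdout indicate that an MKL-DNN kernel was actually executed.
data = mx.nd.random.uniform(shape=(1, 16, 128))
weight = mx.nd.random.uniform(shape=(32, 16, 3))
bias = mx.nd.zeros(32)
out = mx.nd.Convolution(data, weight, bias, kernel=(3,), num_filter=32)
out.wait_to_read()
```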

} else if (arr.shape().ndim() == 3) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups,
static_cast<int>(arr.shape()[0] /
Member

Use O, I, H, W instead of 0, 1, 2 ...

dilates[0] = param.conv_param.dilate[0] - 1;
dilates[1] = param.conv_param.dilate[1] - 1;
} else {
LOG(FATAL) << "MKL-DNN currently supports 1d and 2d convolution";
Member

It would be better if we could mention the given dimension in the error message. The same applies to the other LOG(FATAL) messages.

@TaoLv
Member

TaoLv commented Dec 8, 2018

Do we have a unit test for 1D convolution without grouping? Please also make sure that this change works well for quantized convolution (which I think doesn't support 1D yet).

@xinyu-intel
Contributor Author

Address Tao's comments. I've skipped conv1d at both the C++ level and the Python level. @ZhennanQin Please help take a look. Maybe skipping all non-4d data layouts in quantization.py is not a good choice.

@@ -488,6 +488,9 @@ def quantize_model(sym, arg_params, aux_params,
A tuple of quantized symbol, quantized arg_params, and aux_params.
-------
"""
if ctx == cpu(0) and len(calib_data.provide_data[0].shape) != 3:
raise ValueError('MKL-DNN quantized OPs temporary support 4d layout.')
Contributor

Please don't check calib_data, as the quantization flow supports non-calib mode.

Contributor Author

done

@TaoLv
Member

TaoLv commented Dec 15, 2018

@zheng-da @ZhennanQin Please help review and approve if there are no further concerns.
@xinyu-intel Please rebase the code. I notice the last CI run was 4 days ago.
Thank you all.

@ZhennanQin
Contributor

LGTM.

@TaoLv
Member

TaoLv commented Dec 20, 2018

Seems there was a problem with CI. Do you mind re-triggering it with an empty commit? @xinyu-intel
Ping @zheng-da for review. Thank you.

Contributor

@pengzhao-intel pengzhao-intel left a comment


Could you add tests for 1D activation and quantized conv (to see if the messages are printed as expected)?
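
As an illustration of the kind of 1D activation check being requested, a minimal sketch that compares against a NumPy reference; the helper name and tolerances are assumptions, not the test that was actually added:

```python
import numpy as np
import mxnet as mx

def check_activation_1d(act_type="sigmoid", shape=(3, 8, 16)):
    # 3d (NCW) input so the MKL-DNN 3d activation path is exercised.
    x = mx.nd.random.uniform(-1, 1, shape=shape)
    y = mx.nd.Activation(x, act_type=act_type).asnumpy()
    if act_type == "sigmoid":
        ref = 1.0 / (1.0 + np.exp(-x.asnumpy()))
    elif act_type == "relu":
        ref = np.maximum(x.asnumpy(), 0)
    else:
        raise ValueError("unsupported act_type in this sketch")
    np.testing.assert_allclose(y, ref, rtol=1e-5, atol=1e-6)

check_activation_1d("sigmoid")
check_activation_1d("relu")
```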

dims.resize(shape.ndim() + 1);
dims[0] = 1;
for (size_t i = 0; i < shape.ndim(); i++)
dims[i + 1] = shape[i];
Contributor

Is there a performance difference between the 3D and 4D implementations?
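
As context for this question: the snippet above pads a shape to one higher rank by prepending a dimension of size 1 (e.g. 3d to 4d), which does not change the underlying data. A rough Python analogue of the idea, not the MKL-DNN code itself:

```python
import numpy as np

# A 3d buffer (e.g. NCW) viewed as 4d with a leading 1: the memory contents are
# identical and no copy is made, so the padding itself is free. Any performance
# difference would come from the kernels chosen for 3d vs. 4d descriptors.
x3 = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
x4 = x3.reshape((1,) + x3.shape)
assert np.shares_memory(x3, x4)
print(x3.shape, "->", x4.shape)  # (2, 3, 4) -> (1, 2, 3, 4)
```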

num_groups, static_cast<int>(arr.shape()[N] / num_groups),
static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
static_cast<int>(arr.shape()[W])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
Contributor

This line is the common part of dim3 and dim4, right?

static_cast<int>(arr.shape()[3])};
return mkldnn::memory::desc{tz, get_mkldnn_type(arr.dtype()),
mkldnn::memory::format::any};
CHECK((arr.shape().ndim() == 3) || (arr.shape().ndim() == 4))
Contributor

Use a variable to save the value of arr.shape().ndim() to avoid calling it multiple times.

@xinyu-intel
Contributor Author

xinyu-intel commented Dec 25, 2018

@pengzhao-intel Conv1d fell back to the native CPU implementation before this optimization.

Total time (ms) for 100 iterations on one socket of a Xeon Skylake 8180:

| shape | before opt | after opt | speedup |
| --- | --- | --- | --- |
| (1,256,200) | 715.47 | 73.88 | 9.68x |
| (1,1024,512) | 1970.15 | 101.06 | 19.49x |
| (64,1024,512) | 131312.48 | 4196.24 | 31.29x |

I've added 1d, 3d, and 4d data shapes to the activation test. Regarding quantized conv, it will now return an error when the data shape is 3d, and users should exclude this layer:

  CHECK_EQ(param.full_conv_param.conv_param.kernel.ndim(), 2U)
      << "MKL-DNN only supports quantized conv2d.";

@pengzhao-intel pengzhao-intel mentioned this pull request Dec 25, 2018
Contributor

@pengzhao-intel pengzhao-intel left a comment


LGTM.
Thanks for the contributions :)

@xinyu-intel
Contributor Author

@zheng-da Please take a look. Thanks!

Member

@TaoLv TaoLv left a comment


Please rebase the code; then I think it's good to merge.

@TaoLv
Member

TaoLv commented Jan 2, 2019

@xinyu-intel Thank you for the contribution. Now merging.

@TaoLv TaoLv merged commit d7f9a07 into apache:master Jan 2, 2019
@bputrycz

bputrycz commented Jan 3, 2019

I noticed a very nice improvement with this change.
So, thank you.

Still, for my use case (conv1d, small batch size, small channel dimension, long sequence length),
I don't see much improvement when more cores are added to the computation.

The simplified snippet to reproduce (conv1d.py):

import mxnet as mx
from mxnet import gluon, nd

from mxnet import profiler
profiler.set_config(profile_all=True, aggregate_stats=True, filename='profile_output.json')
 
channels = 64

net = gluon.nn.Sequential()

conv = gluon.nn.Conv1D(channels, 4, padding=1)
act = gluon.nn.Activation('sigmoid')

for i in range(3):
    net.add(conv)
    net.add(act)

net.initialize()

data = nd.random.uniform(shape=(1, channels, 2**16))

# Warm-up
y = net(data)
nd.waitall()

profiler.set_state('run')
for i in range(10):
    y = net(data)
    nd.waitall()
profiler.set_state('stop')

print(profiler.dumps())

When run on a host with a lot of cores (AWS c4.8xlarge), this results in:

$ OMP_NUM_THREADS=1 python conv1d.py | grep "Convolution\|Activation"
Activation                             60         244.8700           3.3960           9.7510           4.0812
Convolution                            60        1648.7010          24.9280          40.6520          27.4783
$ OMP_NUM_THREADS=2 python conv1d.py | grep "Convolution\|Activation"
Activation                             60         127.4460           1.6600           5.4070           2.1241
Convolution                            60         866.8680          12.6670          22.9810          14.4478
$ OMP_NUM_THREADS=4 python conv1d.py | grep "Convolution\|Activation"
Activation                             60          65.3190           0.8940           2.9280           1.0886
Convolution                            60         854.2900          12.6230          20.3230          14.2382

There is no improvement when the number of threads is increased to 4 or more.

Playing more with this example, for higher 'channels' values it starts to parallelize a little better.
So it seems the parallelization is done only within a single sequence "point". Is that the case?
It seems quite natural to also parallelize along the sequence, especially when it is long, with different threads handling different parts of the sequence.
Then parallelization should scale roughly linearly with the number of threads.
Isn't it done like that?

Bartosz

@pengzhao-intel
Contributor

pengzhao-intel commented Jan 4, 2019

@bputrycz Thanks a lot for trying the new MKLDNN path and for the very useful feedback.
Nice analysis and a reproducible example. I will contact the Intel MKL-DNN team to see how to fix it.

In the short term, if this is an inference use case, you can try launching multiple instances (processes), with each one bound to 2 or 4 cores, to get the maximum throughput.
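
For illustration, a rough sketch of that multi-instance idea: launching several copies of the reproduction script, each pinned to its own small set of cores. The instance count, cores per instance, and use of taskset are assumptions for the sketch, not a tested recommendation from this thread.

```python
import os
import subprocess

NUM_INSTANCES = 4
CORES_PER_INSTANCE = 4  # e.g. 2 or 4 cores per process, as suggested above

procs = []
for i in range(NUM_INSTANCES):
    first = i * CORES_PER_INSTANCE
    last = first + CORES_PER_INSTANCE - 1
    env = dict(os.environ,
               OMP_NUM_THREADS=str(CORES_PER_INSTANCE),
               KMP_AFFINITY="granularity=fine,compact")
    # taskset pins each process to its own core range; conv1d.py is the
    # reproduction script from the comment above.
    cmd = ["taskset", "-c", "{}-{}".format(first, last), "python", "conv1d.py"]
    procs.append(subprocess.Popen(cmd, env=env))

for p in procs:
    p.wait()
```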

@pengzhao-intel
Contributor

@TaoLv is following this issue and will get back to you with more details soon.

@TaoLv
Member

TaoLv commented Jan 4, 2019

@bputrycz Cannot reproduce the problem. The code snippet you provided scales well on my machine:

$ OMP_NUM_THREADS=1 python conv1d.py | grep "Convolution\|Activation"
Activation                             60         183.5660           2.6600           6.1490           3.0594
Convolution                            60        1047.5291          14.2060          26.2510          17.4588
$
$ OMP_NUM_THREADS=2 python conv1d.py | grep "Convolution\|Activation"
Activation                             60         103.8990           1.4330           3.5280           1.7316
Convolution                            60         550.4830           7.5440          13.9890           9.1747
$
$ OMP_NUM_THREADS=4 python conv1d.py | grep "Convolution\|Activation"
Activation                             60          57.5180           0.7640           2.4380           0.9586
Convolution                            60         290.6110           3.8830           7.8330           4.8435
$
$ OMP_NUM_THREADS=8 python conv1d.py | grep "Convolution\|Activation"
Activation                             60          39.8030           0.4470           1.8200           0.6634
Convolution                            60         179.9880           2.2860           5.4140           2.9998

Have you tried setting the CPU affinity before running the multi-threaded case?

export KMP_AFFINITY=granularity=fine,compact

@TaoLv
Member

TaoLv commented Jan 4, 2019

Hmm, I just noticed that you're using c4.8xlarge, which I think has no AVX-512. I will have another try and come back to you later.

@TaoLv
Member

TaoLv commented Jan 4, 2019

Confirmed: this problem exists on machines without AVX-512, as the threading optimization for 1D convolution is implemented only for AVX-512. Would you mind giving it a try on AWS c5? @bputrycz
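
For anyone checking whether a machine exposes AVX-512 before trying this, a small Linux-only sketch (not part of MXNet) that reads /proc/cpuinfo; a c5 instance should list avx512 flags while c4 should not:

```python
# Linux-only: report whether the CPU advertises any AVX-512 feature flags.
with open("/proc/cpuinfo") as f:
    flags = set()
    for line in f:
        if line.startswith("flags"):
            flags.update(line.split(":", 1)[1].split())

avx512 = sorted(fl for fl in flags if fl.startswith("avx512"))
print("AVX-512 flags:", avx512 if avx512 else "none")
```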

@bputrycz

bputrycz commented Jan 4, 2019

Yes. I confirm.
It scales much better on AWS c5.

Thank you very much.

@pengzhao-intel
Contributor

> Yes. I confirm.
> It scales much better on AWS c5.
>
> Thank you very much.

It's good to see the problem can be resolved on AWS c5.
Feel free to ping us if you have any questions or issues :)
We will document this behavior in the MKLDNN README. @xinyu-intel

@TaoLv
Member

TaoLv commented Jan 4, 2019

@bputrycz Also feel free to let me know if the performance on c4 is critical for you.

@bputrycz

bputrycz commented Jan 4, 2019

@TaoLv Performance on c4 is not critical for me, as for now.

rondogency pushed a commit to rondogency/incubator-mxnet that referenced this pull request Jan 9, 2019
* add 3d layout support for MKLDNN Conv and Activation

* fix lint

* code refactor

* add testcase for group1 conv and skip quantization for conv1d

* fix lint

* avoid conv1d quantization

* code refactor and add activation ut

* del todo
haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019
* add 3d layout support for MKLDNN Conv and Activation

* fix lint

* code refactor

* add testcase for group1 conv and skip quantization for conv1d

* fix lint

* avoid conv1d quantization

* code refactor and add activation ut

* del todo