[MKL-DNN] Integrate Conv3d and Pool3d/1d (apache#17884)
* Integrate MKL-DNN conv3d and pool3d/1d

* fix UT & address comments

* clean code

* rebase against latest master
wuxun-zhang authored and ChaiBapchya committed Apr 22, 2020
1 parent 10b6aef commit 20632af
Showing 14 changed files with 533 additions and 294 deletions.
21 changes: 18 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -49,14 +49,29 @@ bool SupportMKLDNNAct(const ActivationParam& param) {
 }
 
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
-  // MKL-DNN Activation supports 1d, 2d, 3d, 4d data layout
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
   if ((input.shape().ndim() < 1) ||
-      (input.shape().ndim() > 4) ||
-      (input.dtype() != mshadow::kFloat32))
+      (input.shape().ndim() > 5) ||
+      !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
     return false;
   return SupportMKLDNNAct(param);
 }
 
+bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param) {
+  return param.act_type == leakyrelu::kLeakyReLU
+      || param.act_type == leakyrelu::kELU
+      || param.act_type == leakyrelu::kGELU;
+}
+
+bool SupportMKLDNNLeakyRelu(const LeakyReLUParam& param, const NDArray &input) {
+  // MKL-DNN Activation supports 1d, 2d, 3d, 4d and 5d data layout
+  if ((input.shape().ndim() < 1) ||
+      (input.shape().ndim() > 5) ||
+      !(input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16))
+    return false;
+  return SupportMKLDNNLeakyRelu(param);
+}
+
 bool SupportQuantizedMKLDNNAct(const ActivationParam &param) {
   // TODO(zhennan): Add more activation type when mkldnn supports.
   // Remove this when it's identity to SupportMKLDNNAct.
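
The guard above is the entire MKL-DNN dispatch test for these element-wise activations: a tensor rank between 1 and 5 and a float32 or bfloat16 element type. A minimal standalone C++ sketch of that predicate (the enum below is a hypothetical stand-in for mshadow's dtype flags, not MXNet code):

    #include <cstdio>

    // Hypothetical stand-ins for mshadow::kFloat32 / mshadow::kBfloat16.
    enum DType { kFloat32 = 0, kBfloat16 = 12 };

    // Mirrors the guard in SupportMKLDNNAct / SupportMKLDNNLeakyRelu:
    // rank must be 1..5 and the element type float32 or bfloat16.
    bool supports_mkldnn_eltwise(int ndim, DType dtype) {
      if (ndim < 1 || ndim > 5) return false;
      return dtype == kFloat32 || dtype == kBfloat16;
    }

    int main() {
      std::printf("%d\n", supports_mkldnn_eltwise(5, kBfloat16));  // 1: 5-d (NCDHW) bf16 is now accepted
      std::printf("%d\n", supports_mkldnn_eltwise(6, kFloat32));   // 0: rank > 5 is still rejected
      return 0;
    }
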
45 changes: 31 additions & 14 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -129,7 +129,12 @@ static inline bool SupportStorageMKLDNN(int stype) {
 
 static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
   int ndim = shape.ndim();
-  return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
+  if (ndim == 0 || shape.Size() == 0) {
+    // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
+    return false;
+  }
+  return (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16) &&
+         (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
 static inline bool SupportMKLDNNQuantize(int dtype) {
@@ -263,20 +268,32 @@ inline static mkldnn::memory::desc GetWeightDesc(const NDArray &arr,
   if (num_groups == 1) {
     return GetMemDesc(arr, dtype);
   } else {
-    auto ndim = arr.shape().ndim();
-    CHECK((ndim == 3) || (ndim == 4))
-        << "MKL-DNN weight currectly supports 3d and 4d layout";
+    const auto ndim = arr.shape().ndim();
+    CHECK((ndim == 3) || (ndim == 4) || (ndim == 5))
+        << "MKL-DNN weight currently supports 3d or 4d or 5d layout";
     auto tz = mkldnn::memory::dims{0};
-    const int N = 0, H = 2, W = 3, C = 1;
-    if (ndim == 3) {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H])};
-    } else {
-      tz = mkldnn::memory::dims{
-          num_groups, static_cast<int>(arr.shape()[N] / num_groups),
-          static_cast<int>(arr.shape()[C]), static_cast<int>(arr.shape()[H]),
-          static_cast<int>(arr.shape()[W])};
+    int N = 0, C = 1, H = 2, W = 3;
+    int D = -1;
+    if (ndim == 5) {
+      D = 2;
+      H = 3;
+      W = 4;
+    }
+    switch (ndim) {
+      case 3:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[H]};
+        break;
+      case 4:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[H], arr.shape()[W]};
+        break;
+      case 5:
+        tz = mkldnn::memory::dims{
+            num_groups, arr.shape()[N] / num_groups,
+            arr.shape()[C], arr.shape()[D], arr.shape()[H], arr.shape()[W]};
    }
    return mkldnn::memory::desc{tz, get_mkldnn_type(dtype), mkldnn::memory::format::any};
  }
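
For grouped weights, the new 5-d branch above splits the output-channel axis into (groups, out_channels / groups) and keeps the remaining axes in channel-depth-height-width order, producing the 6-d logical shape MKL-DNN expects for grouped Conv3D weights. A standalone C++ sketch of that reshape, assuming an OIDHW weight and using std::vector in place of mkldnn::memory::dims:

    #include <cstdio>
    #include <vector>

    // Sketch of the grouped-weight reshape done for a 5-d (OIDHW) weight:
    // the output-channel axis O is split into (groups, O / groups), giving
    // the 6-d shape used with the "goidhw"-style layouts.
    std::vector<int> grouped_weight_dims_5d(const std::vector<int> &oidhw, int num_groups) {
      const int O = oidhw[0], I = oidhw[1], D = oidhw[2], H = oidhw[3], W = oidhw[4];
      return {num_groups, O / num_groups, I, D, H, W};
    }

    int main() {
      // e.g. a Conv3D weight of shape (64, 16, 3, 3, 3) with 4 groups
      for (int d : grouped_weight_dims_5d({64, 16, 3, 3, 3}, 4)) std::printf("%d ", d);
      std::printf("\n");  // prints: 4 16 16 3 3 3
      return 0;
    }
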
56 changes: 36 additions & 20 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -234,29 +234,45 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
   auto tz = mkldnn::memory::dims{0};
   auto format = mkldnn::memory::format::format_undef;
   auto engine = CpuEngine::Get()->get_engine();
-  const int O = 0, I = 1, H = 2, W = 3;
-  if (arr.shape().ndim() == 2) {
-    tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I])};
-    format = mkldnn::memory::format::oi;
-  } else if (arr.shape().ndim() == 3) {
-    tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])}
-             : mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H])};
-    format = num_groups > 1 ? mkldnn::memory::format::goiw : mkldnn::memory::format::oiw;
-  } else if (arr.shape().ndim() == 4) {
-    tz = num_groups > 1
-             ? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
-                                    static_cast<int>(arr.shape()[I]),
-                                    static_cast<int>(arr.shape()[H]),
-                                    static_cast<int>(arr.shape()[W])}
-             : mkldnn::memory::dims{
-                   static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
-                   static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
-    format = num_groups > 1 ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
+
+  const int ndim = arr.shape().ndim();
+  int O = 0, I = 1, H = 2, W = 3;
+  int D = -1;
+  if (ndim == 5) {
+    D = 2;
+    H = 3;
+    W = 4;
+  }
+  if (ndim == 2) {
+    tz = mkldnn::memory::dims{arr.shape()[O], arr.shape()[I]};
+    format_tag = mkldnn::memory::format_tag::oi;
+  } else if (ndim == 3) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H]}
+             : mkldnn::memory::dims{arr.shape()[O],
+                                    arr.shape()[I], arr.shape()[H]};
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
+                                : mkldnn::memory::format_tag::oiw;
+  } else if (ndim == 4) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[H],
+                                    arr.shape()[W]}
+             : mkldnn::memory::dims{
+                   arr.shape()[O], arr.shape()[I], arr.shape()[H], arr.shape()[W]};
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
+                                : mkldnn::memory::format_tag::oihw;
+  } else if (ndim == 5) {
+    tz = num_groups > 1
+             ? mkldnn::memory::dims{num_groups, arr.shape()[O] / num_groups,
+                                    arr.shape()[I], arr.shape()[D],
+                                    arr.shape()[H], arr.shape()[W]}
+             : mkldnn::memory::dims{
+                   arr.shape()[O], arr.shape()[I], arr.shape()[D],
+                   arr.shape()[H], arr.shape()[W]};
+    format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goidhw
+                                : mkldnn::memory::format_tag::oidhw;
   } else {
     LOG(FATAL) << "The weight array has an unsupported number of dimensions";
   }
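
GetWeights pairs each weight rank with a fixed MKL-DNN layout, and this hunk adds the oidhw/goidhw pair used by Conv3D. A standalone C++ sketch of that mapping, returning the tag names as strings rather than the real mkldnn::memory::format_tag values:

    #include <cstdio>

    // The layout is a pure function of the weight rank and of whether the
    // convolution is grouped; tag names follow the identifiers in the hunk above.
    const char *weight_format_tag(int ndim, int num_groups) {
      const bool grouped = num_groups > 1;
      switch (ndim) {
        case 2: return "oi";                         // fully-connected style weights
        case 3: return grouped ? "goiw" : "oiw";     // Conv1D
        case 4: return grouped ? "goihw" : "oihw";   // Conv2D
        case 5: return grouped ? "goidhw" : "oidhw"; // Conv3D (added by this commit)
        default: return "unsupported";
      }
    }

    int main() {
      std::printf("%s\n", weight_format_tag(5, 2));  // goidhw
      return 0;
    }
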
60 changes: 52 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -38,11 +38,13 @@ DMLC_REGISTER_PARAMETER(MKLDNNConvParam);
 
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input) {
   if ((params.kernel.ndim() != 1) &&
-      (params.kernel.ndim() != 2))
+      (params.kernel.ndim() != 2) &&
+      (params.kernel.ndim() != 3))
     return false;
   return SupportMKLDNNQuantize(input.dtype()) &&
          ((input.shape().ndim() == 3) ||
-          (input.shape().ndim() == 4));
+          (input.shape().ndim() == 4) ||
+          (input.shape().ndim() == 5));
 }
 
 mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
@@ -76,9 +78,19 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
     strides[1] = param.conv_param.stride[1];
     padding[0] = param.conv_param.pad[0];
     padding[1] = param.conv_param.pad[1];
+  } else if (param.conv_param.kernel.ndim() == 3) {
+    CHECK_GE(param.conv_param.stride.ndim(), 3);
+    CHECK_GE(param.conv_param.pad.ndim(), 3);
+    CHECK_GE(param.conv_param.dilate.ndim(), 3);
+    strides[0] = param.conv_param.stride[0];
+    strides[1] = param.conv_param.stride[1];
+    strides[2] = param.conv_param.stride[2];
+    padding[0] = param.conv_param.pad[0];
+    padding[1] = param.conv_param.pad[1];
+    padding[2] = param.conv_param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size "
-               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2.";
+               << param.conv_param.kernel.ndim() << ", supporting only 1 or 2 or 3.";
   }
   mkldnn::primitive_attr attr;
   mkldnn::post_ops ops;
@@ -141,9 +153,13 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
   } else if (param.conv_param.dilate.ndim() == 2) {
     dilates[0] = param.conv_param.dilate[0] - 1;
     dilates[1] = param.conv_param.dilate[1] - 1;
+  } else if (param.conv_param.dilate.ndim() == 3) {
+    dilates[0] = param.conv_param.dilate[0] - 1;
+    dilates[1] = param.conv_param.dilate[1] - 1;
+    dilates[2] = param.conv_param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size " << param.conv_param.dilate.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
   if (bias_md_ptr == nullptr) {
     mkldnn::convolution_forward::desc desc(prop, mkldnn::algorithm::convolution_direct, data_md,
@@ -182,9 +198,19 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   // MKL-DNN introduced padded formats since 0.15 which require more memory
@@ -209,9 +235,13 @@ static mkldnn::convolution_backward_data::primitive_desc GetConvBwdData(
   } else if (param.dilate.ndim() == 2) {
     dilates[0] = param.dilate[0] - 1;
     dilates[1] = param.dilate[1] - 1;
+  } else if (param.dilate.ndim() == 3) {
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
+    dilates[2] = param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-               << param.dilate.ndim() << ", supporting only 1 or 2.";
+               << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
   }
   mkldnn::convolution_backward_data::desc desc(mkldnn::algorithm::convolution_direct,
       data_md, weight_md, out_md, strides, dilates, padding, padding,
@@ -250,9 +280,19 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
     strides[1] = param.stride[1];
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.kernel.ndim() == 3) {
+    CHECK_GE(param.stride.ndim(), 3);
+    CHECK_GE(param.pad.ndim(), 3);
+    CHECK_GE(param.dilate.ndim(), 3);
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[1];
+    strides[2] = param.stride[2];
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[1];
+    padding[2] = param.pad[2];
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv kernel size " << param.kernel.ndim()
-               << ", supporting only 1 or 2.";
+               << ", supporting only 1 or 2 or 3.";
   }
 
   // MKL-DNN introduced padded formats since 0.15 which require more memory
@@ -289,9 +329,13 @@ static mkldnn::convolution_backward_weights::primitive_desc GetConvBwdWeights(
   } else if (param.dilate.ndim() == 2) {
     dilates[0] = param.dilate[0] - 1;
     dilates[1] = param.dilate[1] - 1;
+  } else if (param.dilate.ndim() == 3) {
+    dilates[0] = param.dilate[0] - 1;
+    dilates[1] = param.dilate[1] - 1;
+    dilates[2] = param.dilate[2] - 1;
   } else {
     LOG(FATAL) << "Unexpected MKL-DNN Conv dilate size "
-               << param.dilate.ndim() << ", supporting only 1 or 2.";
+               << param.dilate.ndim() << ", supporting only 1 or 2 or 3.";
   }
   if (bias == nullptr) {
     mkldnn::convolution_backward_weights::desc desc(mkldnn::algorithm::convolution_direct,
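
Every hunk in this file repeats the same pattern for the new 3-d case: stride and pad are copied per spatial axis, while dilation is passed as (dilate - 1), because MKL-DNN counts the gaps inserted between kernel elements (0 means no dilation) whereas MXNet's dilate parameter uses 1 for no dilation. A standalone C++ sketch of that translation, with std::vector standing in for mkldnn::memory::dims:

    #include <cstdio>
    #include <vector>

    struct ConvGeom {
      std::vector<int> strides, padding, dilates;
    };

    // Copies MXNet-style stride/pad/dilate vectors (1, 2 or 3 spatial dims)
    // into the form handed to the MKL-DNN convolution descriptors.
    ConvGeom fill_conv_geom(const std::vector<int> &stride,
                            const std::vector<int> &pad,
                            const std::vector<int> &dilate) {
      ConvGeom g;
      for (size_t i = 0; i < stride.size(); ++i) {
        g.strides.push_back(stride[i]);
        g.padding.push_back(pad[i]);
        g.dilates.push_back(dilate[i] - 1);  // MKL-DNN convention: 0 == no dilation
      }
      return g;
    }

    int main() {
      // A 3-d convolution with stride 2, pad 1 and dilation 1 on every axis.
      ConvGeom g = fill_conv_geom({2, 2, 2}, {1, 1, 1}, {1, 1, 1});
      std::printf("dilates: %d %d %d\n", g.dilates[0], g.dilates[1], g.dilates[2]);  // 0 0 0
      return 0;
    }
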
62 changes: 43 additions & 19 deletions src/operator/nn/mkldnn/mkldnn_pooling-inl.h
@@ -38,19 +38,17 @@ class MKLDNNPoolingFwd {
  public:
   MKLDNNPoolingFwd(const mxnet::NDArray &input,
                    const mxnet::NDArray &output,
-                   const int kernel_h, const int kernel_w,
-                   const int stride_h, const int stride_w,
-                   const int padding_t, const int padding_b,
-                   const int padding_l, const int padding_r,
+                   const mkldnn::memory::dims &kernel,
+                   const mkldnn::memory::dims &strides,
+                   const mkldnn::memory::dims &pad_l,
+                   const mkldnn::memory::dims &pad_r,
                    const mkldnn::algorithm alg_kind,
                    const bool with_workspace, const bool is_train) :
       is_train_(is_train),
       with_workspace_(with_workspace),
       alg_kind_(alg_kind),
-      fwd_(nullptr), data_(nullptr), out_(nullptr), workspace_(nullptr) {
-    Init(input, output,
-         kernel_h, kernel_w, stride_h, stride_w,
-         padding_t, padding_b, padding_l, padding_r);
+      fwd_(nullptr) {
+    Init(input, output, kernel, strides, pad_l, pad_r,
+         is_train, alg_kind);
   }
 
   ~MKLDNNPoolingFwd() {}
@@ -74,10 +72,11 @@ class MKLDNNPoolingFwd {
  private:
   void Init(const mxnet::NDArray &input,
             const mxnet::NDArray &output,
-            const int kernel_h, const int kernel_w,
-            const int stride_h, const int stride_w,
-            const int padding_t, const int padding_b,
-            const int padding_l, const int padding_r);
+            const mkldnn::memory::dims &kernel,
+            const mkldnn::memory::dims &strides,
+            const mkldnn::memory::dims &pad_l,
+            const mkldnn::memory::dims &pad_r,
+            const bool is_train, const mkldnn::algorithm alg_kind);
 };
 
 class MKLDNNPoolingBwd {
@@ -102,22 +101,47 @@ class MKLDNNPoolingBwd {
 };
 
 inline bool SupportMKLDNNPooling(const PoolingParam &param) {
-  return param.kernel.ndim() == 2 &&
+  return (param.kernel.ndim() == 1 || param.kernel.ndim() == 2 ||
+          param.kernel.ndim() == 3) &&
         (param.pool_type == pool_enum::kMaxPooling ||
          param.pool_type == pool_enum::kAvgPooling) &&
-         (!param.layout.has_value() || param.layout.value() == mshadow::kNCHW);
+         (!param.layout.has_value() ||
+          (param.layout.value() == mshadow::kNCW || param.layout.value() == mshadow::kNCHW ||
+           param.layout.value() == mshadow::kNCDHW));
 }
 
 inline bool SupportMKLDNNPooling(const PoolingParam &param,
-                                 const mxnet::TShape &dshape) {
-  bool ret = SupportMKLDNNPooling(param);
-  if (!ret)
+                                 const NDArray &input) {
+  const auto dshape = input.shape();
+  const auto ndim = dshape.ndim();
+  const auto dtype = input.dtype();
+
+  if (!(SupportStorageMKLDNN(input.storage_type()) && (ndim == 3 || ndim == 4 || ndim == 5) &&
+        (dtype == mshadow::kFloat32 || dtype == mshadow::kBfloat16)))
+    return false;
+
+  if (!SupportMKLDNNPooling(param))
     return false;
 
   if (param.pooling_convention == pool_enum::kValid) {
     return true;
   } else {
-    // currently, only max-pooling is supported for full convention
+    if (param.pool_type == pool_enum::kAvgPooling) {
+      // mkldnn works differently when padding is asymmetric, so let's skip this case.
+      bool is_symmetric = true;
+      switch (ndim) {
+        case 5:
+          is_symmetric = is_symmetric && (param.pad[2] == GetPaddingSizeFull(dshape[4],
+              param.pad[2], param.pad[2], param.kernel[2], param.stride[2]));
+        case 4:
+          is_symmetric = is_symmetric && (param.pad[1] == GetPaddingSizeFull(dshape[3],
+              param.pad[1], param.pad[1], param.kernel[1], param.stride[1]));
+        case 3:
+          is_symmetric = is_symmetric && (param.pad[0] == GetPaddingSizeFull(dshape[2],
+              param.pad[0], param.pad[0], param.kernel[0], param.stride[0]));
+      }
+      return is_symmetric;
+    }
     return param.pool_type == pool_enum::kMaxPooling;
   }
 }
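
The full-convention branch above admits average pooling only when the padding can stay symmetric, i.e. when the full convention would not force extra right/bottom padding on any spatial axis; the in-code comment notes that mkldnn behaves differently with asymmetric padding, so that case is skipped. A standalone C++ sketch of that test follows; the helper below is a hypothetical stand-in for the repository's GetPaddingSizeFull, assuming the usual ceil-based full-convention output size, and is not copied from the source.

    #include <cstdio>

    // Hypothetical stand-in for GetPaddingSizeFull: with the "full" pooling
    // convention the output length is ceil((x + pad_l + pad_r - k) / s) + 1,
    // so the last window may run past x + pad_l + pad_r and the right-side
    // padding must grow to cover it. Returns that enlarged right padding
    // (equal to pad_r when the windows already fit exactly).
    int padding_needed_full(int x, int pad_l, int pad_r, int k, int s) {
      const int rem = (x + pad_l + pad_r - k) % s;
      return rem == 0 ? pad_r : pad_r + (s - rem);
    }

    // Mirrors the is_symmetric test: average pooling with full convention is
    // only dispatched to MKL-DNN when no extra padding is needed on the axis.
    bool stays_symmetric(int x, int pad, int k, int s) {
      return pad == padding_needed_full(x, pad, pad, k, s);
    }

    int main() {
      std::printf("%d\n", stays_symmetric(8, 1, 3, 2));  // 0: (8 + 2 - 3) % 2 != 0, padding would grow
      std::printf("%d\n", stays_symmetric(9, 1, 3, 2));  // 1: windows fit exactly, padding stays symmetric
      return 0;
    }
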