Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support SyncBatchNorm5D #14542

Merged
merged 27 commits into from
Apr 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/operator/contrib/sync_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ struct SyncBatchNormParam : public dmlc::Parameter<SyncBatchNormParam> {
DMLC_DECLARE_FIELD(ndev).set_default(1)
.describe("The count of GPU devices");
DMLC_DECLARE_FIELD(key)
.set_default("")
.describe("Hash key for synchronization, please set the same hash key for same layer, "
"Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`.");
}
Expand Down Expand Up @@ -275,14 +274,18 @@ class SyncBatchNorm : public Operator {
static_cast<real_t>(in_data[syncbatchnorm::kData].shape_.Size());
Tensor<xpu, 4> data;
Tensor<xpu, 4> out;
if (in_data[syncbatchnorm::kData].ndim() == 2) {
if (in_data[syncbatchnorm::kData].ndim() == 4) {
zhreshold marked this conversation as resolved.
Show resolved Hide resolved
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
} else {
index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ?
in_data[syncbatchnorm::kData].shape_[1] : 1;
index_t spatial_size = in_data[syncbatchnorm::kData].shape_.ProdShape(2,
in_data[syncbatchnorm::kData].ndim());
Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0],
in_data[syncbatchnorm::kData].shape_[1], 1, 1);
num_channels, 1, spatial_size);
data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
out = out_data[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
} else {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
out = out_data[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
}
Tensor<xpu, 1> slope = in_data[syncbatchnorm::kGamma].get<xpu, 1, real_t>(s);
Tensor<xpu, 1> bias = in_data[syncbatchnorm::kBeta].get<xpu, 1, real_t>(s);
Expand Down Expand Up @@ -354,16 +357,20 @@ class SyncBatchNorm : public Operator {
Tensor<xpu, 4> data, grad, grad_in;
const real_t scale = static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_[1]) /
static_cast<real_t>(out_grad[syncbatchnorm::kOut].shape_.Size());
if (in_data[syncbatchnorm::kData].ndim() == 2) {
if (in_data[syncbatchnorm::kData].ndim() == 4) {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
} else {
index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ?
out_grad[syncbatchnorm::kOut].shape_[1] : 1;
index_t spatial_size = out_grad[syncbatchnorm::kOut].shape_.ProdShape(2,
out_grad[syncbatchnorm::kOut].ndim());
Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0],
out_grad[syncbatchnorm::kOut].shape_[1], 1, 1);
num_channels, 1, spatial_size);
data = in_data[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
grad = out_grad[syncbatchnorm::kOut].get_with_shape<xpu, 4, real_t>(dshape, s);
grad_in = in_grad[syncbatchnorm::kData].get_with_shape<xpu, 4, real_t>(dshape, s);
} else {
data = in_data[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
grad = out_grad[syncbatchnorm::kOut].get<xpu, 4, real_t>(s);
grad_in = in_grad[syncbatchnorm::kData].get<xpu, 4, real_t>(s);
}

Tensor<xpu, 1> mean = out_data[syncbatchnorm::kMean].get<xpu, 1, real_t>(s);
Expand Down
4 changes: 2 additions & 2 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ void BatchNormCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Forward(ctx, in_data, req, outputs, aux_states);
Expand Down Expand Up @@ -697,7 +697,7 @@ void BatchNormGradCompute<gpu>(const nnvm::NodeAttrs& attrs,

param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis);
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4
if (!param.use_global_stats && !param.cudnn_off
&& param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) {
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
GetCuDNNOp<DType>(param).Backward(ctx, inputs, req, outputs);
Expand Down
14 changes: 8 additions & 6 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ class CuDNNBatchNormOp {
}
CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo);
CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2);
CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4);

Init(in_data[cudnnbatchnorm::kData]);
Stream<gpu> *s = ctx.get_stream<gpu>();
Expand Down Expand Up @@ -273,12 +272,15 @@ class CuDNNBatchNormOp {

private:
void Init(const TBlob &in_data) {
for (int i = 0; i < 4; ++i) {
if (i < in_data.ndim()) {
if (in_data.ndim() == 4) {
for (int i = 0; i < 4; ++i)
shape_[i] = in_data.shape_[i];
} else {
shape_[i] = 1;
}
} else {
// when in_data.ndim() != 4
shape_[0] = in_data.shape_[0];
shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1;
shape_[2] = 1;
shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim());
}

CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_,
Expand Down
Loading