
Fix performance regression in normalize operator #14055

Merged
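To make the change easier to follow, here is a minimal, self-contained C++ sketch (not part of this patch, and independent of mshadow and `mxnet_op::Kernel`) of the layout change the diff below makes: the old code looped over channels on the host and launched one kernel work item per pixel, looking up the per-channel mean/std from the parameter vectors, while the new code launches one work item per channel, sweeps that channel's plane with an OpenMP loop, and receives the mean/std values as six plain scalars so no vector has to be copied to the device. The function names `normalize_per_element` and `normalize_per_channel`, and the ImageNet-style mean/std constants, are illustrative placeholders only.

```cpp
// Standalone sketch contrasting the two normalization layouts; plain C++,
// no MXNet dependencies.  Build: g++ -O2 -fopenmp sketch.cc && ./a.out
#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <vector>

// Old layout: iterate channels on the host; the innermost statement is what the
// old normalize_forward::Map did for a single output element.
static void normalize_per_element(float* out, const float* in,
                                  int channels, int length,
                                  const float* mean, const float* std_dev) {
  for (int c = 0; c < channels; ++c) {          // host-side loop, one launch per channel
    for (int i = 0; i < length; ++i) {          // one work item per pixel
      out[c * length + i] = (in[c * length + i] - mean[c]) / std_dev[c];
    }
  }
}

// New layout: one work item per channel; each work item sweeps its whole plane,
// and mean/std arrive as six scalars instead of a pointer/vector.
static void normalize_per_channel(float* out, const float* in,
                                  int channels, int length,
                                  float m0, float m1, float m2,
                                  float s0, float s1, float s2) {
  for (int c = 0; c < channels; ++c) {          // stands in for Kernel<...>::Launch(s, channel, ...)
    const float mean    = (c == 0) ? m0 : (c == 1) ? m1 : m2;
    const float std_dev = (c == 0) ? s0 : (c == 1) ? s1 : s2;
    #pragma omp parallel for
    for (int i = 0; i < length; ++i) {
      out[c * length + i] = (in[c * length + i] - mean) / std_dev;
    }
  }
}

int main() {
  const int channels = 3, length = 4 * 4;       // a tiny 3 x 4 x 4 "image"
  std::vector<float> in(channels * length), a(in.size()), b(in.size());
  for (std::size_t k = 0; k < in.size(); ++k) in[k] = static_cast<float>(k % 7);

  // Illustrative per-channel constants (placeholders, not taken from the PR).
  const float mean[3] = {0.485f, 0.456f, 0.406f};
  const float std_dev[3] = {0.229f, 0.224f, 0.225f};

  normalize_per_element(a.data(), in.data(), channels, length, mean, std_dev);
  normalize_per_channel(b.data(), in.data(), channels, length,
                        mean[0], mean[1], mean[2],
                        std_dev[0], std_dev[1], std_dev[2]);

  float max_diff = 0.0f;
  for (std::size_t k = 0; k < a.size(); ++k)
    max_diff = std::fmax(max_diff, std::fabs(a[k] - b[k]));
  std::printf("max difference between layouts: %g\n", max_diff);
  return max_diff < 1e-6f ? EXIT_SUCCESS : EXIT_FAILURE;
}
```

Both layouts produce identical results; the difference is purely in how the work is partitioned across launches and how the per-channel constants reach the kernel.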
136 changes: 98 additions & 38 deletions src/operator/image/image_random-inl.h
@@ -217,37 +217,50 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
template<int req>
struct normalize_forward {
template<typename DType>
- MSHADOW_XINLINE static void Map(int j, DType* out_data, const DType* in_data,
- const int i, const int length, const int step,
- const DType mean, const DType std_dev) {
- KERNEL_ASSIGN(out_data[step + i*length + j], req,
- (in_data[step + i*length + j] - mean) / std_dev);
+ MSHADOW_XINLINE static void Map(uint32_t c, DType* out_data, const DType* in_data,
+ const float mean_d0, const float mean_d1, const float mean_d2,
+ const float std_d0, const float std_d1, const float std_d2,
+ const int length, const int step) {
+ float mean, std;
+ switch (c) {
+ case 0 : mean = mean_d0;
+ std = std_d0;
+ break;
+ case 1 : mean = mean_d1;
+ std = std_d1;
+ break;
+ case 2 : mean = mean_d2;
+ std = std_d2;
+ break;
+ }
+ #pragma omp parallel for
+ for (int i = 0; i < length; ++i) {
+ KERNEL_ASSIGN(out_data[step + c*length + i], req,
+ (in_data[step + c*length + i] - mean) / std);
+ }
}
};

template<typename xpu>
void NormalizeImpl(const OpContext &ctx,
- const std::vector<TBlob> &inputs,
- const std::vector<TBlob> &outputs,
- const std::vector<OpReqType> &req,
- const NormalizeParam &param,
- const int length,
- const uint32_t channel,
- const int step = 0) {
+ const std::vector<TBlob> &inputs,
+ const std::vector<TBlob> &outputs,
+ const std::vector<OpReqType> &req,
+ const float mean_d0, const float mean_d1,
+ const float mean_d2, const float std_d0,
+ const float std_d1, const float std_d2,
+ const int length,
+ const uint32_t channel,
+ const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
DType* input = inputs[0].dptr<DType>();
DType* output = outputs[0].dptr<DType>();

- for (uint32_t i = 0; i < channel; ++i) {
- DType mean = param.mean[param.mean.ndim() > i ? i : 0];
- DType std_dev = param.std[param.std.ndim() > i ? i : 0];
- mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
- s, length, output, input,
- i, length, step, mean, std_dev);
- }
+ mxnet_op::Kernel<normalize_forward<req_type>, xpu>::Launch(
+ s, channel, output, input, mean_d0, mean_d1, mean_d2,
+ std_d0, std_d1, std_d2, length, step);
});
});
}
@@ -264,11 +277,35 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

+ // Note: We need mean and std_dev in the kernel.
+ // It is costly (device copy) to pass it as vector, for gpu kernel.
+ // Hence, passing it as below for performance.
+ float mean_d0, mean_d1, mean_d2;
+ float std_d0, std_d1, std_d2;
+
+ // Mean and Std can be 1 or 3 D only.
+ if (param.mean.ndim() == 1) {
+ mean_d0 = mean_d1 = mean_d2 = param.mean[0];
+ } else {
+ mean_d0 = param.mean[0];
+ mean_d1 = param.mean[1];
+ mean_d2 = param.mean[2];
+ }
+
+ if (param.std.ndim() == 1) {
+ std_d0 = std_d1 = std_d2 = param.std[0];
+ } else {
+ std_d0 = param.std[0];
+ std_d1 = param.std[1];
+ std_d2 = param.std[2];
+ }

// 3D input (c, h, w)
if (inputs[0].ndim() == 3) {
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const uint32_t channel = inputs[0].shape_[0];
- NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+ NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
+ std_d0, std_d1, std_d2, length, channel);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = inputs[0].shape_[0];
@@ -278,7 +315,8 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
- NormalizeImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
+ NormalizeImpl<xpu>(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2,
+ std_d0, std_d1, std_d2, length, channel, n*step);
}
}
}
@@ -287,12 +325,25 @@ void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
template<int req>
struct normalize_backward {
template<typename DType>
- MSHADOW_XINLINE static void Map(int j, DType* in_grad, const DType* out_grad,
- const int i, const int length,
- const int step, const DType std_dev) {
+ MSHADOW_XINLINE static void Map(uint32_t c, DType* in_grad, const DType* out_grad,
+ const float std_d0, const float std_d1, const float std_d2,
+ const int length, const int step) {
// d/dx{(x - mean) / std_dev} => (1 / std_dev)
- KERNEL_ASSIGN(in_grad[step + i*length + j], req,
- out_grad[step + i*length + j] * (1.0 / std_dev));
+ float std_dev;
+ switch (c) {
+ case 0 : std_dev = std_d0;
+ break;
+ case 1 : std_dev = std_d1;
+ break;
+ case 2 : std_dev = std_d2;
+ break;
+ }
+
+ #pragma omp parallel for
+ for (int i = 0; i < length; ++i) {
+ KERNEL_ASSIGN(in_grad[step + c*length + i], req,
+ out_grad[step + c*length + i] * (1.0 / std_dev));
+ }
}
};

@@ -301,21 +352,18 @@ void NormalizeBackwardImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
- const NormalizeParam &param,
+ const float std_d0, const float std_d1, const float std_d2,
const int length,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
- const TBlob& out_grad = inputs[0];
- const TBlob& in_grad = outputs[0];

MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
- for (uint32_t i = 0; i < channel; ++i) {
- DType std_dev = param.std[param.std.ndim() > i ? i : 0];
- mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
- s, length, in_grad.dptr<DType>(), out_grad.dptr<DType>(),
- i, length, step, std_dev);
- }
+ DType* out_grad = inputs[0].dptr<DType>();
+ DType* in_grad = outputs[0].dptr<DType>();
+ mxnet_op::Kernel<normalize_backward<req_type>, xpu>::Launch(
+ s, channel, in_grad, out_grad, std_d0, std_d1, std_d2, length, step);
});
});
}
@@ -331,6 +379,16 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
CHECK_EQ(req.size(), 1U);

const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
+ float std_d0, std_d1, std_d2;
+
+ // Std can be 1 or 3 D only
+ if (param.std.ndim() == 1) {
+ std_d0 = std_d1 = std_d2 = param.std[0];
+ } else {
+ std_d0 = param.std[0];
+ std_d1 = param.std[1];
+ std_d2 = param.std[2];
+ }

// Note: inputs[0] is out_grad
const TBlob& in_data = inputs[1];
@@ -339,7 +397,7 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
if (in_data.ndim() == 3) {
const int length = in_data.shape_[1] * in_data.shape_[2];
const uint32_t channel = in_data.shape_[0];
- NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel);
+ NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, std_d0, std_d1, std_d2, length, channel);
} else if (in_data.ndim() == 4) {
// 4D input (n, c, h, w)
const int batch_size = in_data.shape_[0];
@@ -349,7 +407,9 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
- NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req, param, length, channel, n*step);
+ NormalizeBackwardImpl<xpu>(ctx, inputs, outputs, req,
+ std_d0, std_d1, std_d2, length,
+ channel, n*step);
}
}
}
33 changes: 8 additions & 25 deletions tests/python/gpu/test_gluon_transforms.py
@@ -80,32 +80,15 @@ def test_to_tensor():
data_in.astype(dtype=np.float32) / 255.0, (2, 0, 1)))

# 4D Input
- data_in_4d = nd.random.uniform(0, 1, (2, 3, 300, 300))
- out_nd_4d = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))(data_in_4d)
- data_expected_4d = data_in_4d.asnumpy()
- data_expected_4d[0][:][:][0] = data_expected_4d[0][:][:][0] / 3.0
- data_expected_4d[0][:][:][1] = (data_expected_4d[0][:][:][1] - 1.0) / 2.0
- data_expected_4d[0][:][:][2] = data_expected_4d[0][:][:][2] - 2.0
- data_expected_4d[1][:][:][0] = data_expected_4d[1][:][:][0] / 3.0
- data_expected_4d[1][:][:][1] = (data_expected_4d[1][:][:][1] - 1.0) / 2.0
- data_expected_4d[1][:][:][2] = data_expected_4d[1][:][:][2] - 2.0
- assert_almost_equal(data_expected_4d, out_nd_4d.asnumpy())
-
- # Default normalize values i.e., mean=0, std=1
- data_in_3d_def = nd.random.uniform(0, 1, (3, 300, 300))
- out_nd_3d_def = transforms.Normalize()(data_in_3d_def)
- data_expected_3d_def = data_in_3d_def.asnumpy()
- assert_almost_equal(data_expected_3d_def, out_nd_3d_def.asnumpy())
-
- # Invalid Input - Neither 3D or 4D input
- invalid_data_in = nd.random.uniform(0, 1, (5, 5, 3, 300, 300))
- normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
- assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+ data_in = np.random.uniform(0, 255, (5, 300, 300, 3)).astype(dtype=np.uint8)
+ out_nd = transforms.ToTensor()(nd.array(data_in, dtype='uint8'))
+ assert_almost_equal(out_nd.asnumpy(), np.transpose(
+ data_in.astype(dtype=np.float32) / 255.0, (0, 3, 1, 2)))

- # Invalid Input - Channel neither 1 or 3
- invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300))
- normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1))
- assertRaises(MXNetError, normalize_transformer, invalid_data_in)
+ # Invalid Input
+ invalid_data_in = nd.random.uniform(0, 255, (5, 5, 300, 300, 3)).astype(dtype=np.uint8)
+ transformer = transforms.ToTensor()
+ assertRaises(MXNetError, transformer, invalid_data_in)

@with_seed()
def test_resize():