Fix fused resnet low accuracy (#21122)
* Change flag to postop

* Add attributes to batch norm relu node

* Refactor code with batch norm relu op

* Delete fuse_norm_relu flag

* Delete fuse_norm_relu flag

* Refactor BN operator

* Review suggestions

* Review suggestions once again

* Fix formatting

* Fix lint errors
hankaj committed Aug 22, 2022
1 parent 7748ae7 commit daac02c
Showing 12 changed files with 495 additions and 564 deletions.
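
Per the commit message, the fuse_relu flag and the separate BatchNormReLU layer go away, and fused BatchNorm + ReLU is instead represented by the oneDNN subgraph operator _sg_onednn_batch_norm with ReLU attached as a post-op. As a rough illustration only, the Python sketch below shows how a model might express the same pattern after this change; the optimize_for call and the 'ONEDNN' backend name are assumptions based on general MXNet 2.x usage and are not part of this diff.

# Hedged sketch (not part of this commit): after the change a network uses the
# plain BatchNorm layer followed by a ReLU Activation; when it is optimized for
# the oneDNN backend, the subgraph pass is expected to fuse the pair into
# _sg_onednn_batch_norm with ReLU attached as a post-op.
import mxnet as mx
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Conv2D(channels=64, kernel_size=3, padding=1),
        nn.BatchNorm(),            # no BatchNormReLU layer any more
        nn.Activation('relu'))
net.initialize()
net.hybridize()

x = mx.np.random.uniform(size=(1, 3, 32, 32))
y = net(x)                         # regular hybrid execution, no fusion yet

# Assumed API and backend name from MXNet 2.x; verify against your version.
net.optimize_for(x, backend='ONEDNN')
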
python/mxnet/amp/lists/symbol_bf16.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,6 @@
# are dtype neutral (can work in both bf16 and fp32)
BF16_FP32_FUNCS = [
'_contrib_AdaptiveAvgPooling2D',
- '_contrib_BatchNormWithReLU',
'Activation',
'BatchNorm',
'LayerNorm',
@@ -102,6 +101,7 @@
if Features.instance.is_enabled('ONEDNN'):
WIDEST_TYPE_CASTS.extend([
'_sg_onednn_batch_dot',
+ '_sg_onednn_batch_norm',
])

# Functions that when running with Bfloat16, the params that still need float32.
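
The two lists touched above feed MXNet's AMP pass: BF16_FP32_FUNCS marks operators that are dtype neutral, while WIDEST_TYPE_CASTS entries (now including _sg_onednn_batch_norm) receive widest-type casting. Below is a hedged sketch of a bfloat16 conversion that would consult these lists; amp.convert_hybrid_block and its argument order are assumptions about the MXNet 2.x AMP API and may differ between versions.

# Hedged sketch of how these operator lists are consumed: converting a network
# to bfloat16 leaves BF16_FP32_FUNCS entries dtype neutral and applies
# widest-type casting to WIDEST_TYPE_CASTS entries such as
# _sg_onednn_batch_norm. The convert_hybrid_block call is an assumption and
# its exact signature may differ in other MXNet versions.
import mxnet as mx
from mxnet import amp
from mxnet.gluon import nn

net = nn.HybridSequential()
net.add(nn.Dense(128, activation='relu'),
        nn.BatchNorm(),
        nn.Dense(10))
net.initialize()
net.hybridize()

x = mx.np.random.uniform(size=(8, 64))
net(x)  # run once so shapes and parameters are known before conversion

lp_net = amp.convert_hybrid_block(net, x, target_dtype='bfloat16')
out = lp_net(x)
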
python/mxnet/amp/lists/symbol_fp16.py (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,6 @@
'BlockGrad',
'Cast',
'cast_storage',
- '_contrib_BatchNormWithReLU',
'_contrib_allclose',
'_contrib_arange_like',
'_contrib_dynamic_reshape',
@@ -637,6 +636,7 @@
'_sg_onednn_selfatt_qk',
'_sg_onednn_selfatt_valatt',
'_sg_onednn_batch_dot',
+ '_sg_onednn_batch_norm',
'_sg_pow_mul_scalar'
])

python/mxnet/gluon/nn/basic_layers.py (81 changes: 4 additions & 77 deletions)
@@ -19,7 +19,7 @@
# pylint: disable= arguments-differ
"""Basic neural network layers."""
__all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
- 'BatchNorm', 'SyncBatchNorm', 'BatchNormReLU', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+ 'BatchNorm', 'SyncBatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
'Flatten', 'Lambda', 'HybridLambda', 'Concatenate', 'HybridConcatenate', 'Identity']
import warnings
import uuid
@@ -322,8 +322,6 @@ class _BatchNorm(HybridBlock):
If True, use global moving statistics instead of local batch-norm. This will force
change batch-norm into a scale shift operator.
If False, use local batch-norm.
- fuse_relu: bool, default False
- If True, this operator is equal to `BN+ReLU`.
beta_initializer: str or `Initializer`, default 'zeros'
Initializer for the beta weight.
gamma_initializer: str or `Initializer`, default 'ones'
@@ -345,14 +343,13 @@ class _BatchNorm(HybridBlock):
- **out**: output tensor with the same shape as `data`.
"""
def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
- use_global_stats=False, fuse_relu=False,
+ use_global_stats=False,
beta_initializer='zeros', gamma_initializer='ones',
running_mean_initializer='zeros', running_variance_initializer='ones',
in_channels=0, **kwargs):
super(_BatchNorm, self).__init__(**kwargs)
self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
'fix_gamma': not scale, 'use_global_stats': use_global_stats}
- self.fuse_relu = fuse_relu
self._axis = axis
if in_channels != 0:
self.in_channels = in_channels
@@ -383,13 +380,7 @@ def cast(self, dtype):

def forward(self, x):
device = x.device
- if self.fuse_relu:
- return npx.batch_norm_with_relu(x, self.gamma.data(device), self.beta.data(device),
- self.running_mean.data(device),
- self.running_var.data(device),
- name='fwd', **self._kwargs)
- else:
- return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
+ return npx.batch_norm(x, self.gamma.data(device), self.beta.data(device),
self.running_mean.data(device),
self.running_var.data(device),
name='fwd', **self._kwargs)
@@ -467,71 +458,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
super(BatchNorm, self).__init__(
axis=axis, momentum=momentum, epsilon=epsilon, center=center,
scale=scale,
- use_global_stats=use_global_stats, fuse_relu=False,
- beta_initializer=beta_initializer,
- gamma_initializer=gamma_initializer,
- running_mean_initializer=running_mean_initializer,
- running_variance_initializer=running_variance_initializer,
- in_channels=in_channels, **kwargs)
-
-
- class BatchNormReLU(_BatchNorm):
- """Batch normalization layer (Ioffe and Szegedy, 2014).
- Normalizes the input at each batch, i.e. applies a transformation
- that maintains the mean activation close to 0 and the activation
- standard deviation close to 1.
- Parameters
- ----------
- axis : int, default 1
- The axis that should be normalized. This is typically the channels
- (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
- set `axis=1` in `BatchNorm`. If `layout='NHWC'`, then set `axis=3`.
- momentum: float, default 0.9
- Momentum for the moving average.
- epsilon: float, default 1e-5
- Small float added to variance to avoid dividing by zero.
- center: bool, default True
- If True, add offset of `beta` to normalized tensor.
- If False, `beta` is ignored.
- scale: bool, default True
- If True, multiply by `gamma`. If False, `gamma` is not used.
- When the next layer is linear (also e.g. `nn.relu`),
- this can be disabled since the scaling
- will be done by the next layer.
- use_global_stats: bool, default False
- If True, use global moving statistics instead of local batch-norm. This will force
- change batch-norm into a scale shift operator.
- If False, use local batch-norm.
- beta_initializer: str or `Initializer`, default 'zeros'
- Initializer for the beta weight.
- gamma_initializer: str or `Initializer`, default 'ones'
- Initializer for the gamma weight.
- running_mean_initializer: str or `Initializer`, default 'zeros'
- Initializer for the running mean.
- running_variance_initializer: str or `Initializer`, default 'ones'
- Initializer for the running variance.
- in_channels : int, default 0
- Number of channels (feature maps) in input data. If not specified,
- initialization will be deferred to the first time `forward` is called
- and `in_channels` will be inferred from the shape of input data.
- Inputs:
- - **data**: input tensor with arbitrary shape.
- Outputs:
- - **out**: output tensor with the same shape as `data`.
- """
- def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
- use_global_stats=False,
- beta_initializer='zeros', gamma_initializer='ones',
- running_mean_initializer='zeros', running_variance_initializer='ones',
- in_channels=0, **kwargs):
- super(BatchNormReLU, self).__init__(
- axis=axis, momentum=momentum, epsilon=epsilon,
- center=center, scale=scale,
- use_global_stats=use_global_stats, fuse_relu=True,
+ use_global_stats=use_global_stats,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
running_mean_initializer=running_mean_initializer,
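
For reference, a hypothetical usage sketch of the Gluon BatchNorm layer that remains after this change, using the parameters documented in the removed docstring above; the concrete values are illustrative only, and ReLU is now applied by a separate Activation layer.

# Illustrative use of the remaining Gluon BatchNorm layer with the parameters
# documented above; the concrete values are example choices, not from the diff.
import mxnet as mx
from mxnet.gluon import nn

bn = nn.BatchNorm(axis=1,              # channel axis for an NCHW layout
                  momentum=0.9,        # moving-average momentum
                  epsilon=1e-5,        # added to variance for stability
                  center=True,
                  scale=True,
                  use_global_stats=False,
                  in_channels=16)      # optional: inferred on first forward if 0
relu = nn.Activation('relu')           # ReLU is a separate layer now
bn.initialize()

x = mx.np.random.uniform(size=(2, 16, 8, 8))
y = relu(bn(x))                        # same shape as the input: (2, 16, 8, 8)
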
src/operator/nn/batch_norm-inl.h (2 changes: 1 addition & 1 deletion)
@@ -46,7 +46,7 @@
#endif

/*! \brief inverse standard deviation <-> variance */
- #define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0 / std::sqrt((__var$) + DType(__eps$)))
+ #define VARIANCE_TO_INVSTD(__var$, __eps$) (1.0 / std::sqrt((__var$) + (__eps$)))
#define INVSTD_TO_VARIANCE(__invstd$, __eps$) ((1.0 / ((__invstd$) * (__invstd$))) - (__eps$))

namespace mxnet {
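
The two macros above encode the variance to inverse-standard-deviation conversion and its inverse; the changed line only alters how eps enters the expression. A small numeric sketch of the round trip, written in Python purely for illustration (the macros themselves are C++):

# Numeric sketch of the identities behind the two macros (illustration only):
#   invstd = 1 / sqrt(var + eps)        -> VARIANCE_TO_INVSTD
#   var    = 1 / invstd**2 - eps        -> INVSTD_TO_VARIANCE
import math

eps = 1e-5
var = 0.25

invstd = 1.0 / math.sqrt(var + eps)
recovered = (1.0 / (invstd * invstd)) - eps

print(round(invstd, 5))                 # 1.99996
print(abs(recovered - var) < 1e-9)      # True: the two conversions are inverses
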
src/operator/nn/batch_norm.cc (6 changes: 2 additions & 4 deletions)
@@ -475,9 +475,7 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 5U);
if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
- DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
- DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
- });
+ DNNLRun(DNNLBatchNormForward</*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
@@ -491,7 +489,7 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
if (SupportDNNLBN(inputs[0])) {
DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
- DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
+ DNNLRun(DNNLBatchNormBackward, attrs, ctx, inputs, req, outputs);
DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}