From df2f2642288173735712b415591d6eacf55bcc28 Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 27 Mar 2019 20:47:35 +0800 Subject: [PATCH 01/25] support SyncBatchNorm5D --- src/operator/contrib/sync_batch_norm-inl.h | 36 ++++++++++++++-------- tests/python/gpu/test_gluon_gpu.py | 2 ++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index b94416640f55..82768d3d878a 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -275,14 +275,20 @@ class SyncBatchNorm : public Operator { static_cast(in_data[syncbatchnorm::kData].shape_.Size()); Tensor data; Tensor out; - if (in_data[syncbatchnorm::kData].ndim() == 2) { + if (in_data[syncbatchnorm::kData].ndim() == 4) { + data = in_data[syncbatchnorm::kData].get(s); + out = out_data[syncbatchnorm::kOut].get(s); + } else { + index_t num_channels = in_data[syncbatchnorm::kData].shape_[1]; + if (in_data[syncbatchnorm::kData].ndim() > 4) { + // ignore the last two axes + for (index_t i = 2; i < in_data[syncbatchnorm::kData].Size() - 2; ++i) + num_channels *= in_data[syncbatchnorm::kData].shape_[i]; + } Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], - in_data[syncbatchnorm::kData].shape_[1], 1, 1); + num_channels, 1, 1); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); - } else { - data = in_data[syncbatchnorm::kData].get(s); - out = out_data[syncbatchnorm::kOut].get(s); } Tensor slope = in_data[syncbatchnorm::kGamma].get(s); Tensor bias = in_data[syncbatchnorm::kBeta].get(s); @@ -354,16 +360,22 @@ class SyncBatchNorm : public Operator { Tensor data, grad, grad_in; const real_t scale = static_cast(out_grad[syncbatchnorm::kOut].shape_[1]) / static_cast(out_grad[syncbatchnorm::kOut].shape_.Size()); - if (in_data[syncbatchnorm::kData].ndim() == 2) { - Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], - out_grad[syncbatchnorm::kOut].shape_[1], 1, 1); - data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); - grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); - grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); - } else { + if (in_data[syncbatchnorm::kData].ndim() == 4) { data = in_data[syncbatchnorm::kData].get(s); grad = out_grad[syncbatchnorm::kOut].get(s); grad_in = in_grad[syncbatchnorm::kData].get(s); + } else { + index_t num_channels = out_grad[syncbatchnorm::kOut].shape_[1]; + if (out_grad[syncbatchnorm::kOut].ndim() > 4) { + // ignore the last two axes + for (index_t i = 2; i < out_grad[syncbatchnorm::kOut].Size() - 2; ++i) + num_channels *= out_grad[syncbatchnorm::kOut].shape_[i]; + } + Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], + num_channels, 1, 1); + data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); + grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); + grad_in = in_grad[syncbatchnorm::kData].ge4_with_shape(dshape, s); } Tensor mean = out_data[syncbatchnorm::kMean].get(s); diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 9eeeec749211..be4da7157b1e 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -340,6 +340,8 @@ def get_num_devices(): for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 3, 4, 4)), 
+ num_devices=ndev, cuda=True) @with_seed() From c3142414889e870dcb88d11ded593810617ab7a4 Mon Sep 17 00:00:00 2001 From: wkcn Date: Wed, 27 Mar 2019 21:23:19 +0800 Subject: [PATCH 02/25] fix --- src/operator/contrib/sync_batch_norm-inl.h | 24 ++++++++-------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 82768d3d878a..09ddfc7366a4 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -279,14 +279,11 @@ class SyncBatchNorm : public Operator { data = in_data[syncbatchnorm::kData].get(s); out = out_data[syncbatchnorm::kOut].get(s); } else { - index_t num_channels = in_data[syncbatchnorm::kData].shape_[1]; - if (in_data[syncbatchnorm::kData].ndim() > 4) { - // ignore the last two axes - for (index_t i = 2; i < in_data[syncbatchnorm::kData].Size() - 2; ++i) - num_channels *= in_data[syncbatchnorm::kData].shape_[i]; - } + index_t spatial_size = in_data[syncbatchnorm::kData].Size() / ( + in_data[syncbatchnorm::kData].shape_[0] * + in_data[syncbatchnorm::kData].shape_[1]); Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], - num_channels, 1, 1); + in_data[syncbatchnorm::kData].shape_[1], 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); } @@ -365,17 +362,14 @@ class SyncBatchNorm : public Operator { grad = out_grad[syncbatchnorm::kOut].get(s); grad_in = in_grad[syncbatchnorm::kData].get(s); } else { - index_t num_channels = out_grad[syncbatchnorm::kOut].shape_[1]; - if (out_grad[syncbatchnorm::kOut].ndim() > 4) { - // ignore the last two axes - for (index_t i = 2; i < out_grad[syncbatchnorm::kOut].Size() - 2; ++i) - num_channels *= out_grad[syncbatchnorm::kOut].shape_[i]; - } + index_t spatial_size = out_grad[syncbatchnorm::kOut].Size() / ( + out_grad[syncbatchnorm::kOut].shape_[0] * + out_grad[syncbatchnorm::kOut].shape_[1]); Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], - num_channels, 1, 1); + out_grad[syncbatchnorm::kOut].shape_[1], 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); - grad_in = in_grad[syncbatchnorm::kData].ge4_with_shape(dshape, s); + grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); } Tensor mean = out_data[syncbatchnorm::kMean].get(s); From 46bc4269823472c8d85ee07299a28abc48878690 Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 08:57:26 +0800 Subject: [PATCH 03/25] update testcase and reformat code --- src/operator/contrib/sync_batch_norm-inl.h | 14 +- tests/python/gpu/test_gluon_gpu.py | 185 ++++++++++++--------- 2 files changed, 119 insertions(+), 80 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 09ddfc7366a4..17339c996dcd 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -279,11 +279,12 @@ class SyncBatchNorm : public Operator { data = in_data[syncbatchnorm::kData].get(s); out = out_data[syncbatchnorm::kOut].get(s); } else { + index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ? 
+ in_data[syncbatchnorm::kData].shape_[1] : 1; index_t spatial_size = in_data[syncbatchnorm::kData].Size() / ( - in_data[syncbatchnorm::kData].shape_[0] * - in_data[syncbatchnorm::kData].shape_[1]); + in_data[syncbatchnorm::kData].shape_[0] * num_channels); Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], - in_data[syncbatchnorm::kData].shape_[1], 1, spatial_size); + num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); } @@ -362,11 +363,12 @@ class SyncBatchNorm : public Operator { grad = out_grad[syncbatchnorm::kOut].get(s); grad_in = in_grad[syncbatchnorm::kData].get(s); } else { + index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ? + out_grad[syncbatchnorm::kOut].shape_[1] : 1; index_t spatial_size = out_grad[syncbatchnorm::kOut].Size() / ( - out_grad[syncbatchnorm::kOut].shape_[0] * - out_grad[syncbatchnorm::kOut].shape_[1]); + out_grad[syncbatchnorm::kOut].shape_[0] * num_channels); Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], - out_grad[syncbatchnorm::kOut].shape_[1], 1, spatial_size); + num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index be4da7157b1e..4008db60f765 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -45,6 +45,7 @@ set_default_context(mx.gpu(0)) + def check_rnn_layer(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) with mx.gpu(0): @@ -62,6 +63,7 @@ def check_rnn_layer(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) + @with_seed() def check_rnn_layer_w_rand_inputs(layer): layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)]) @@ -89,11 +91,12 @@ def test_lstmp(): batch_size, seq_len = 7, 11 input_size = 5 ctx = mx.gpu(0) - lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx) - shapes = {'i2h_weight': (hidden_size*4, input_size), - 'h2h_weight': (hidden_size*4, projection_size), - 'i2h_bias': (hidden_size*4,), - 'h2h_bias': (hidden_size*4,), + lstm_input = mx.nd.uniform( + shape=(seq_len, batch_size, input_size), ctx=ctx) + shapes = {'i2h_weight': (hidden_size * 4, input_size), + 'h2h_weight': (hidden_size * 4, projection_size), + 'i2h_bias': (hidden_size * 4,), + 'h2h_bias': (hidden_size * 4,), 'h2r_weight': (projection_size, hidden_size)} weights = {k: rand_ndarray(v) for k, v in shapes.items()} lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, @@ -107,23 +110,26 @@ def test_lstmp(): layer_params = lstm_layer.collect_params() cell_params = lstm_cell.collect_params() for k, v in weights.items(): - layer_params['lstm0_l0_'+k].set_data(v.copy()) - cell_params['lstm0_l0_'+k].set_data(v.copy()) + layer_params['lstm0_l0_' + k].set_data(v.copy()) + cell_params['lstm0_l0_' + k].set_data(v.copy()) with autograd.record(): layer_output = lstm_layer(lstm_input.copy()) cell_output = lstm_cell.unroll(seq_len, lstm_input.copy(), layout='TNC', merge_outputs=True)[0] - assert_almost_equal(layer_output.asnumpy(), cell_output.asnumpy(), rtol=rtol, atol=atol) + assert_almost_equal(layer_output.asnumpy(), + cell_output.asnumpy(), rtol=rtol, atol=atol) layer_output.backward() cell_output.backward() for k, v 
in weights.items(): - layer_grad = layer_params['lstm0_l0_'+k].grad() - cell_grad = cell_params['lstm0_l0_'+k].grad() - print('checking gradient for {}'.format('lstm0_l0_'+k)) + layer_grad = layer_params['lstm0_l0_' + k].grad() + cell_grad = cell_params['lstm0_l0_' + k].grad() + print('checking gradient for {}'.format('lstm0_l0_' + k)) assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(), rtol=rtol, atol=atol) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx) - check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM( + 10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx) + check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones( + (8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx) check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)), run_only=True, ctx=ctx) @@ -139,7 +145,8 @@ def test_lstm_clip(): batch_size, seq_len = 32, 80 input_size = 50 clip_min, clip_max, clip_nan = -5, 5, True - lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0)) + lstm_input = mx.nd.uniform( + shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0)) lstm_states = [mx.nd.uniform(shape=(2, batch_size, projection_size), ctx=mx.gpu(0)), mx.nd.uniform(shape=(2, batch_size, hidden_size), ctx=mx.gpu(0))] lstm_layer = gluon.rnn.LSTM(hidden_size, projection_size=projection_size, @@ -165,7 +172,8 @@ def test_rnn_layer(): check_rnn_layer(gluon.rnn.GRU(100, num_layers=3)) check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) + check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM( + 100, num_layers=3, bidirectional=True)) def check_layer_bidirectional(size, in_size, proj_size): @@ -173,8 +181,10 @@ class RefBiLSTM(gluon.Block): def __init__(self, size, proj_size, **kwargs): super(RefBiLSTM, self).__init__(**kwargs) with self.name_scope(): - self._lstm_fwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='l0') - self._lstm_bwd = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=False, prefix='r0') + self._lstm_fwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False, prefix='l0') + self._lstm_bwd = gluon.rnn.LSTM( + size, projection_size=proj_size, bidirectional=False, prefix='r0') def forward(self, inpt): fwd = self._lstm_fwd(inpt) @@ -184,16 +194,23 @@ def forward(self, inpt): return nd.concat(fwd, bwd, dim=2) weights = {} for d in ['l', 'r']: - weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform( + shape=(size * 4, in_size)) if proj_size: - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, proj_size)) - weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform(shape=(proj_size, size)) + weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform( + shape=(size * 4, proj_size)) + weights['lstm_{}0_h2r_weight'.format(d)] = mx.random.uniform( + shape=(proj_size, size)) else: - weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) - weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) - weights['lstm_{}0_h2h_bias'.format(d)] = 
mx.random.uniform(shape=(size*4,)) - - net = gluon.rnn.LSTM(size, projection_size=proj_size, bidirectional=True, prefix='lstm_') + weights['lstm_{}0_h2h_weight'.format( + d)] = mx.random.uniform(shape=(size * 4, size)) + weights['lstm_{}0_i2h_bias'.format( + d)] = mx.random.uniform(shape=(size * 4,)) + weights['lstm_{}0_h2h_bias'.format( + d)] = mx.random.uniform(shape=(size * 4,)) + + net = gluon.rnn.LSTM(size, projection_size=proj_size, + bidirectional=True, prefix='lstm_') ref_net = RefBiLSTM(size, proj_size, prefix='lstm_') net.initialize() ref_net.initialize() @@ -201,16 +218,19 @@ def forward(self, inpt): ref_net_params = ref_net.collect_params() for k in weights: net_params[k].set_data(weights[k]) - ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k]) + ref_net_params[k.replace('l0', 'l0l0').replace( + 'r0', 'r0l0')].set_data(weights[k]) data = mx.random.uniform(shape=(11, 10, in_size)) assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy()) + @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_layer_bidirectional(): check_layer_bidirectional(7, 5, 0) + @with_seed() @assert_raises_cudnn_not_satisfied(min_version='7.2.1') def test_layer_bidirectional_proj(): @@ -221,7 +241,8 @@ def test_layer_bidirectional_proj(): @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnn_layer_begin_state_type(): fake_data = nd.random.uniform(shape=(3, 5, 7), dtype='float16') - modeling_layer = gluon.rnn.LSTM(hidden_size=11, num_layers=2, dropout=0.2, bidirectional=True) + modeling_layer = gluon.rnn.LSTM( + hidden_size=11, num_layers=2, dropout=0.2, bidirectional=True) modeling_layer.cast('float16') modeling_layer.initialize() modeling_layer(fake_data) @@ -229,9 +250,10 @@ def test_rnn_layer_begin_state_type(): def test_gluon_ctc_consistency(): loss = mx.gluon.loss.CTCLoss() - data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0) - cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0)) - gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0)) + data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0) + ).reshape((2, 20, 4)).flip(axis=0) + cpu_label = mx.nd.array([[2, 1, -1, -1], [3, 2, 2, -1]], ctx=mx.cpu(0)) + gpu_label = mx.nd.array([[2, 1, -1, -1], [3, 2, 2, -1]], ctx=mx.gpu(0)) cpu_data = data.copy().as_in_context(mx.cpu(0)) cpu_data.attach_grad() @@ -245,15 +267,17 @@ def test_gluon_ctc_consistency(): l_gpu = loss(gpu_data, gpu_label) l_gpu.backward() - assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(cpu_data.grad.asnumpy(), + gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) @with_seed() def test_global_norm_clip_multi_device(): for check_isfinite in [True, False]: - x1 = mx.nd.ones((3,3), ctx=mx.gpu(0)) - x2 = mx.nd.ones((4,4), ctx=mx.cpu(0)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0, check_isfinite=check_isfinite) + x1 = mx.nd.ones((3, 3), ctx=mx.gpu(0)) + x2 = mx.nd.ones((4, 4), ctx=mx.cpu(0)) + norm = gluon.utils.clip_global_norm( + [x1, x2], 1.0, check_isfinite=check_isfinite) if check_isfinite: assert norm == 5.0 else: @@ -264,6 +288,7 @@ def test_global_norm_clip_multi_device(): def _check_batchnorm_result(input, num_devices=1, cuda=False): from mxnet.gluon.utils import split_and_load + def _find_bn(module): if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): return module @@ -288,9 +313,10 @@ def _syncParameters(bn1, bn2, ctx): else: ctx_list = [mx.cpu(0) for _ in 
range(num_devices)] - nch = input.shape[1] + nch = input.shape[1] if input.ndim > 1 else 1 bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm( + in_channels=nch, num_devices=num_devices) bn1.initialize(ctx=ctx_list[0]) bn2.initialize(ctx=ctx_list) @@ -305,43 +331,46 @@ def _syncParameters(bn1, bn2, ctx): with mx.autograd.record(): output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] + output2 = [bn2(xi) for xi in inputs2] loss1 = (output1 ** 2).sum() loss2 = [(output ** 2).sum() for output in output2] mx.autograd.backward(loss1) mx.autograd.backward(loss2) - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + output2 = mx.nd.concat(*[output.as_in_context(input.context) + for output in output2], dim=0) # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(output1.asnumpy(), + output2.asnumpy(), atol=1e-3, rtol=1e-3) assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), atol=1e-3, rtol=1e-3) assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), atol=1e-3, rtol=1e-3) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + input2grad = mx.nd.concat( + *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), + input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + @with_seed() def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i # no need to use SyncBN with 1 gpu - if get_num_devices() < 2: - return - ndev = 2 + if mx.context.num_gpus() < 2: + ndev = 1 + cuda = False + else: + ndev = 2 + cuda = True # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 3, 4, 4)), - num_devices=ndev, cuda=True) + for shape in [(4, 2), (4, 2, 4), (4, 2, 4, 4), (4, 2, 3, 4, 4)]: + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=shape, + ctx=mx.cpu()), + num_devices=ndev, cuda=cuda) @with_seed() @@ -354,10 +383,11 @@ def test_symbol_block_fp16(): tmpfile = os.path.join(tmp, 'resnet34_fp16') ctx = mx.gpu(0) - net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2(pretrained=True, ctx=ctx, root=tmp) + net_fp32 = mx.gluon.model_zoo.vision.resnet34_v2( + pretrained=True, ctx=ctx, root=tmp) net_fp32.cast('float16') net_fp32.hybridize() - data = mx.nd.zeros((1,3,224,224), dtype='float16', ctx=ctx) + data = mx.nd.zeros((1, 3, 224, 224), dtype='float16', ctx=ctx) net_fp32.forward(data) net_fp32.export(tmpfile, 0) @@ -391,7 +421,8 @@ def test_large_models(): # Compute the height (=width) of the square tensor of the given size in bytes def tensor_size(big_tensor_bytes): bytes_per_float = 4 - sz = int(math.sqrt(big_tensor_bytes / largest_num_features / bytes_per_float)) + sz = int(math.sqrt(big_tensor_bytes / + largest_num_features / 
bytes_per_float)) return (sz // 100) * 100 # The idea is to create models with large tensors of (say) 20% of the total memory. @@ -400,12 +431,13 @@ def tensor_size(big_tensor_bytes): (free_mem_bytes, total_mem_bytes) = mx.context.gpu_memory_info(ctx.device_id) start_size = tensor_size(0.20 * total_mem_bytes) num_trials = 10 - sys.stderr.write(' testing global memory of size {} ... '.format(total_mem_bytes)) + sys.stderr.write( + ' testing global memory of size {} ... '.format(total_mem_bytes)) sys.stderr.flush() for i in range(num_trials): sz = start_size - 10 * i - (height, width) = (sz,sz) - sys.stderr.write(" {}x{} ".format(height,width)) + (height, width) = (sz, sz) + sys.stderr.write(" {}x{} ".format(height, width)) sys.stderr.flush() data_in = nd.random_uniform(low=0, high=255, shape=(1, 3, height, width), ctx=ctx, dtype="float32") @@ -413,6 +445,8 @@ def tensor_size(big_tensor_bytes): net(data_in).asnumpy() # isolated execution bulking test function to be invoked with different env var settings + + def _test_bulking_in_process(seed, time_per_iteration): # Use flip since it's a simple function with same-sized I/O unlikely to ever be fused. class Flip(gluon.HybridBlock): @@ -442,7 +476,7 @@ def get_net(num_ops): # time a number of forward() and backward() executions after some warm-up iterations warmups = 1 - for i in range(num_iterations+warmups): + for i in range(num_iterations + warmups): with autograd.record(): if i == warmups: start = time.time() @@ -452,20 +486,22 @@ def get_net(num_ops): time_per_iteration.value = (time.time() - start) / num_iterations + @with_seed() def test_bulking(): # test case format: (max_fwd_segment_size, max_bwd_segment_size, enable_bulking_in_training) - test_cases = [(0,0,True), (1,1,True), (15,15,False), (15,0,True), (0,15,True), (15,15,True)] + test_cases = [(0, 0, True), (1, 1, True), (15, 15, False), + (15, 0, True), (0, 15, True), (15, 15, True)] times = {} times_str = '' for seg_sizes in test_cases: # Create shared variable to return measured time from test process time_per_iteration = mp.Manager().Value('d', 0.0) if not run_in_spawned_process(_test_bulking_in_process, - {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD' : seg_sizes[0], - 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD' : seg_sizes[1], - 'MXNET_EXEC_BULK_EXEC_TRAIN' : seg_sizes[2]}, - time_per_iteration): + {'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_FWD': seg_sizes[0], + 'MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD': seg_sizes[1], + 'MXNET_EXEC_BULK_EXEC_TRAIN': seg_sizes[2]}, + time_per_iteration): # skip test since the python version can't run it properly. Warning msg was logged. 
return times[seg_sizes] = time_per_iteration.value @@ -473,21 +509,22 @@ def test_bulking(): '\n runtime of (fwd,bwd,enable) op seg setting ({},{},{}) =\t{:.1f} msec'.format( seg_sizes[0], seg_sizes[1], seg_sizes[2], 1000.0 * times[seg_sizes]) - fastest_non_bulked_time = min(times[(0,0,True)], times[(1,1,True)], times[(15,15,False)]) - slowest_half_bulked_time = max(times[(0,15,True)], times[(15,0,True)]) - fastest_half_bulked_time = min(times[(0,15,True)], times[(15,0,True)]) - fully_bulked_time = times[(15,15,True)] + fastest_non_bulked_time = min( + times[(0, 0, True)], times[(1, 1, True)], times[(15, 15, False)]) + slowest_half_bulked_time = max(times[(0, 15, True)], times[(15, 0, True)]) + fastest_half_bulked_time = min(times[(0, 15, True)], times[(15, 0, True)]) + fully_bulked_time = times[(15, 15, True)] print(times_str) # Non-bulked times[0,0,True], times[1,1,True] and times[15,15,False] should be about the same, # slower than both half-bulked times[0,15,True] and times[15,0,True] assert slowest_half_bulked_time < fastest_non_bulked_time, \ 'A half-bulked exec time is slower than the non-bulked time by {} secs! {}' \ - .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) + .format(slowest_half_bulked_time - fastest_non_bulked_time, times_str) # The fully bulked times[15,15,True] should be faster than both half-bulked runs assert fully_bulked_time < fastest_half_bulked_time, \ 'The fully-bulked exec time is slower than a half-bulked time by {} secs! {}' \ - .format(fully_bulked_time - fastest_half_bulked_time, times_str) + .format(fully_bulked_time - fastest_half_bulked_time, times_str) if __name__ == '__main__': From 4729860153a71211119294234f161751555a2186 Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 11:26:32 +0800 Subject: [PATCH 04/25] retrigger CI From e7bd3bb12e9db6f5b3b1ceb9a8357e20bcad478c Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 11:27:45 +0800 Subject: [PATCH 05/25] update test case --- tests/python/gpu/test_gluon_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 4008db60f765..5306c7f10e4f 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -366,7 +366,7 @@ def test_sync_batchnorm(): ndev = 2 cuda = True # check with unsync version - for shape in [(4, 2), (4, 2, 4), (4, 2, 4, 4), (4, 2, 3, 4, 4)]: + for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu()), From 02415f98f05a81c44c882d7ba0e664f9f87eb3aa Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 12:49:01 +0800 Subject: [PATCH 06/25] test --- tests/python/gpu/test_gluon_gpu.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 5306c7f10e4f..673dab9db822 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -366,7 +366,8 @@ def test_sync_batchnorm(): ndev = 2 cuda = True # check with unsync version - for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + # for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for shape in [(4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu()), From 380e750f3d1530b6a42e1fbe9413a731cc9d6c30 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Thu, 28 Mar 2019 15:45:30 
+0800 Subject: [PATCH 07/25] Retrigger CI From 46533903036eed21f34a24f3815bc43c8e2a94e3 Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 21:30:51 +0800 Subject: [PATCH 08/25] disable cudnn for batchnorm --- src/operator/nn/batch_norm.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 1199ec7fcce5..1f720dfce97b 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -668,7 +668,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 + if (false && !param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); From c8ad1a854bf4b697bbdcdd35f315c100648a7205 Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 21:38:28 +0800 Subject: [PATCH 09/25] fix BatchNorm(cudnn) --- src/operator/nn/batch_norm.cu | 2 +- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 13 +++++++------ tests/python/gpu/test_gluon_gpu.py | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 1f720dfce97b..1ff41226005a 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -668,7 +668,7 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - if (false && !param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 + if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index d4b9f84ed2f5..c664844eea4c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -273,12 +273,13 @@ class CuDNNBatchNormOp { private: void Init(const TBlob &in_data) { - for (int i = 0; i < 4; ++i) { - if (i < in_data.ndim()) { - shape_[i] = in_data.shape_[i]; - } else { - shape_[i] = 1; - } + if (in_data.ndim() == 4) shape_ = in_data.shape_; + else { + // when in_data.ndim() != 4 + shape_[0] = in_data.shape_[0]; + shape_[1] = in_data.ndim() > 1 ? 
in_data.shape_[1] : 1; + shape_[2] = 1; + shape_[3] = in_data.Size() / (shape_[0] * shape_[1]); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 673dab9db822..5306c7f10e4f 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -366,8 +366,7 @@ def test_sync_batchnorm(): ndev = 2 cuda = True # check with unsync version - # for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: - for shape in [(4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu()), From c22b400dd2b7e4da26efb3efab22f265c25a320e Mon Sep 17 00:00:00 2001 From: wkcn Date: Thu, 28 Mar 2019 21:43:09 +0800 Subject: [PATCH 10/25] fix build --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index c664844eea4c..3126988a12e7 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -273,8 +273,10 @@ class CuDNNBatchNormOp { private: void Init(const TBlob &in_data) { - if (in_data.ndim() == 4) shape_ = in_data.shape_; - else { + if (in_data.ndim() == 4) { + for (int i = 0; i < 4; ++i) + shape_[i] = in_data.shape_[i]; + } else { // when in_data.ndim() != 4 shape_[0] = in_data.shape_[0]; shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; From f316e91f3fec881d733a154182d24dc691dc9adf Mon Sep 17 00:00:00 2001 From: JackieWu Date: Fri, 29 Mar 2019 00:29:27 +0800 Subject: [PATCH 11/25] Remove a testcase --- tests/python/gpu/test_gluon_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 5306c7f10e4f..152ab66422e4 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -366,7 +366,7 @@ def test_sync_batchnorm(): ndev = 2 cuda = True # check with unsync version - for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for shape in [(4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu()), From ea470be6269d7b9b833f22fe9997dad373af5479 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Fri, 29 Mar 2019 07:29:29 +0800 Subject: [PATCH 12/25] Update sync_batch_norm-inl.h From d4a118da48876405ab2500fa72a113c1f23c2743 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 14:25:04 +0800 Subject: [PATCH 13/25] update unittest --- src/operator/contrib/sync_batch_norm-inl.h | 1 - tests/python/gpu/test_gluon_gpu.py | 87 --------------------- tests/python/unittest/test_gluon.py | 85 +++++++++++++++++++++ tests/python/unittest/test_operator.py | 88 ++++++++++++++++++++++ 4 files changed, 173 insertions(+), 88 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 17339c996dcd..d32804b396b8 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -69,7 +69,6 @@ struct SyncBatchNormParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(ndev).set_default(1) .describe("The count of GPU devices"); DMLC_DECLARE_FIELD(key) - .set_default("") .describe("Hash key for synchronization, please set the same hash key for same layer, " 
"Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`."); } diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index 5306c7f10e4f..1c5a5835e6f9 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -286,93 +286,6 @@ def test_global_norm_clip_multi_device(): assert_almost_equal(x2.asnumpy(), np.ones((4, 4)) / 5) -def _check_batchnorm_result(input, num_devices=1, cuda=False): - from mxnet.gluon.utils import split_and_load - - def _find_bn(module): - if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] if input.ndim > 1 else 1 - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm( - in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) - for output in output2], dim=0) - # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), - atol=1e-3, rtol=1e-3) - assert_almost_equal(output1.asnumpy(), - output2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - input2grad = mx.nd.concat( - *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), - input2grad.asnumpy(), atol=1e-3, rtol=1e-3) - - -@with_seed() -def test_sync_batchnorm(): - # no need to use SyncBN with 1 gpu - if mx.context.num_gpus() < 2: - ndev = 1 - cuda = False - else: - ndev = 2 - cuda = True - # check with unsync version - for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=shape, - ctx=mx.cpu()), - num_devices=ndev, cuda=cuda) - - @with_seed() def test_symbol_block_fp16(): # Test case to verify if initializing the SymbolBlock from a model with params diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 6af7a5f948e2..de52bc18e64f 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -583,6 
+583,91 @@ def test_batchnorm(): check_layer_forward(layer, (2, 10, 10, 10)) +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] if input.ndim > 1 else 1 + prefix = str(input.shape) + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch, prefix='bn_' + prefix) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm( + in_channels=nch, num_devices=num_devices, prefix='sync_bn_' + prefix) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) + for output in output2], dim=0) + # assert forwarding + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(output1.asnumpy(), + output2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + input2grad = mx.nd.concat( + *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), + input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + + +@with_seed() +def test_sync_batchnorm(): + cfgs = [(1, False)] + num_gpus = mx.context.num_gpus() + for i in range(1, num_gpus + 1): + cfgs.append((i, True)) + for ndev, cuda in cfgs: + # check with unsync version + for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=shape), + num_devices=ndev, cuda=cuda) + + @with_seed() def test_instancenorm(): layer = nn.InstanceNorm(in_channels=10) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index c9498ecb0bd2..f3f66c23ac1c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1598,6 +1598,94 @@ def check_batchnorm_training(stype): check_batchnorm_training('default') +@with_seed() +def test_batchnorm(): + momentum = 0.9 + epsilon = 1e-5 + for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: + for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for 
axis in range(len(shape)): + kwargs = dict() + if op == mx.nd.contrib.SyncBatchNorm: + if axis != 1: + continue + key = str(op) + str(shape) + str(axis) + kwargs.update(dict(key=key)) + else: + kwargs.update(dict(axis=axis)) + nch = shape[axis] + + bn_gamma = mx.nd.random.uniform(shape=(nch,)) + bn_gamma.attach_grad() + + bn_beta = mx.nd.random.uniform(shape=(nch,)) + bn_beta.attach_grad() + + bn_running_mean = mx.nd.zeros(nch) + bn_running_var = mx.nd.ones(nch) + + running_mean = mx.nd.zeros(nch) + running_var = mx.nd.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + for _ in range(num_iters): + data = mx.nd.random.uniform(shape=shape) + data.attach_grad() + ograd = mx.nd.random.uniform(shape=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, + momentum=momentum, eps=epsilon, fix_gamma=False, **kwargs) + output.backward(ograd) + mx.nd.waitall() + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, keepdims=True) + + target_output = (data - data_mean) / (data_var + epsilon).sqrt() * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / mx.nd.sqrt(data_var + epsilon) + nx = xsm * nd + m = np.prod(shape) / shape[axis] + dvar = (dnx * xsm).sum(axis=axis, keepdims=True, + exclude=True) * (-0.5) * mx.nd.power(nd, 3.0) + dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \ + dvar * xsm.mean(axis=axis, keepdims=True, + exclude=True) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = (ograd * nx).sum(axis=axis, exclude=True) + db = ograd.sum(axis=axis, exclude=True) + + #assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=1e-3, rtol=1e-3) + + assert_almost_equal(data.grad.asnumpy(), + dX.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal( + bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal( + bn_beta.grad.asnumpy(), db.asnumpy(), atol=1e-3, rtol=1e-3) + + @with_seed() def test_convolution_grouping(): for dim in [1, 2, 3]: From d1a7787f6f250c7de04c4634e534253f2d6a34bf Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 14:49:00 +0800 Subject: [PATCH 14/25] update unittest --- src/operator/contrib/sync_batch_norm-inl.h | 8 ++++---- tests/python/unittest/test_gluon.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index d32804b396b8..1e6ab25db0e2 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -280,8 +280,8 @@ class SyncBatchNorm : public Operator { } else { index_t num_channels = in_data[syncbatchnorm::kData].ndim() > 1 ? 
in_data[syncbatchnorm::kData].shape_[1] : 1; - index_t spatial_size = in_data[syncbatchnorm::kData].Size() / ( - in_data[syncbatchnorm::kData].shape_[0] * num_channels); + index_t spatial_size = in_data[syncbatchnorm::kData].shape_.ProdShape(2, + in_data[syncbatchnorm::kData].ndim()); Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); @@ -364,8 +364,8 @@ class SyncBatchNorm : public Operator { } else { index_t num_channels = out_grad[syncbatchnorm::kOut].ndim() > 1 ? out_grad[syncbatchnorm::kOut].shape_[1] : 1; - index_t spatial_size = out_grad[syncbatchnorm::kOut].Size() / ( - out_grad[syncbatchnorm::kOut].shape_[0] * num_channels); + index_t spatial_size = out_grad[syncbatchnorm::kOut].shape_.ProdShape(2, + out_grad[syncbatchnorm::kOut].ndim()); Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], num_channels, 1, spatial_size); data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index de52bc18e64f..06e4eb767d8b 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -664,7 +664,8 @@ def test_sync_batchnorm(): # check with unsync version for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=shape), + _check_batchnorm_result(mx.nd.random.uniform(shape=shape, + ctx=mx.cpu(0)), num_devices=ndev, cuda=cuda) From 2a3ba52622deaae9f302c23f2f8ad3c90730a824 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 18:17:15 +0800 Subject: [PATCH 15/25] update test --- tests/python/unittest/test_gluon.py | 2 +- tests/python/unittest/test_operator.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 06e4eb767d8b..e398a5e3f0f1 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -662,7 +662,7 @@ def test_sync_batchnorm(): cfgs.append((i, True)) for ndev, cuda in cfgs: # check with unsync version - for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu(0)), diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index f3f66c23ac1c..21b8b40e7aae 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1603,7 +1603,7 @@ def test_batchnorm(): momentum = 0.9 epsilon = 1e-5 for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: - for shape in [(4, 2), (4, 3, 4), (4, 4, 4, 4), (4, 5, 6, 4, 4)]: + for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): kwargs = dict() if op == mx.nd.contrib.SyncBatchNorm: @@ -1672,7 +1672,7 @@ def test_batchnorm(): dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) - #assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3) assert_almost_equal(bn_running_mean.asnumpy( ), running_mean.asnumpy(), atol=1e-3, rtol=1e-3) assert_almost_equal(bn_running_var.asnumpy( From 6d6142f04cc5ea98f6d1e070be74af2e2ac51d0d Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 
Mar 2019 19:27:24 +0800 Subject: [PATCH 16/25] fix test --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 2 +- tests/python/unittest/test_gluon.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 3126988a12e7..6f8b82352853 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -281,7 +281,7 @@ class CuDNNBatchNormOp { shape_[0] = in_data.shape_[0]; shape_[1] = in_data.ndim() > 1 ? in_data.shape_[1] : 1; shape_[2] = 1; - shape_[3] = in_data.Size() / (shape_[0] * shape_[1]); + shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); } CUDNN_CALL(cudnnSetTensor4dDescriptor(io_desc_, diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index e398a5e3f0f1..e22bfdc21e77 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -611,10 +611,9 @@ def _syncParameters(bn1, bn2, ctx): ctx_list = [mx.cpu(0) for _ in range(num_devices)] nch = input.shape[1] if input.ndim > 1 else 1 - prefix = str(input.shape) - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch, prefix='bn_' + prefix) + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) bn2 = mx.gluon.contrib.nn.SyncBatchNorm( - in_channels=nch, num_devices=num_devices, prefix='sync_bn_' + prefix) + in_channels=nch, num_devices=num_devices) bn1.initialize(ctx=ctx_list[0]) bn2.initialize(ctx=ctx_list) From 659f1dbde0f3f3c4802057e81ae0b4dcf02d5d64 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 19:31:43 +0800 Subject: [PATCH 17/25] change atol and rtol --- tests/python/unittest/test_operator.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 21b8b40e7aae..0adf00452100 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1672,18 +1672,20 @@ def test_batchnorm(): dW = (ograd * nx).sum(axis=axis, exclude=True) db = ograd.sum(axis=axis, exclude=True) - assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=1e-3, rtol=1e-3) + atol = 1e-2 + rtol = 1e-2 + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(bn_running_mean.asnumpy( - ), running_mean.asnumpy(), atol=1e-3, rtol=1e-3) + ), running_mean.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(bn_running_var.asnumpy( - ), running_var.asnumpy(), atol=1e-3, rtol=1e-3) + ), running_var.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(data.grad.asnumpy(), - dX.asnumpy(), atol=1e-3, rtol=1e-3) + dX.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=1e-3, rtol=1e-3) + bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal( - bn_beta.grad.asnumpy(), db.asnumpy(), atol=1e-3, rtol=1e-3) + bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) @with_seed() From 36f930d45d3b73cd6af6f7da6833a1dde2f56eb3 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 20:38:03 +0800 Subject: [PATCH 18/25] BN(cudnn) 5d --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 6f8b82352853..820f8504d74c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -84,7 +84,6 @@ class 
CuDNNBatchNormOp { } CHECK_EQ(req[cudnnbatchnorm::kOut], kWriteTo); CHECK_GE(in_data[cudnnbatchnorm::kData].ndim(), 2); - CHECK_LE(in_data[cudnnbatchnorm::kData].ndim(), 4); Init(in_data[cudnnbatchnorm::kData]); Stream *s = ctx.get_stream(); From 9ec1f51161b896124994bcc92af0b5f45d8aa81b Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 21:38:54 +0800 Subject: [PATCH 19/25] update test --- tests/python/unittest/test_gluon.py | 45 +++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index e22bfdc21e77..e163593ab4be 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -636,21 +636,56 @@ def _syncParameters(bn1, bn2, ctx): output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # check bn1 + + momentum = 0.9 + epsilon = 1e-5 + axis = 1 + running_mean = mx.nd.zeros(nch) + running_var = mx.nd.ones(nch) + data = input1 + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, keepdims=True) + + target_output = (data - data_mean) / (data_var + epsilon).sqrt() + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + atol = 1e-2 + rtol = 1e-2 + assert_almost_equal(output1.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + running_mean.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + running_var.asnumpy(), + atol=atol, rtol=rtol) # assert forwarding assert_almost_equal(input1.asnumpy(), input2.asnumpy(), - atol=1e-3, rtol=1e-3) + atol=atol, rtol=rtol) assert_almost_equal(output1.asnumpy(), - output2.asnumpy(), atol=1e-3, rtol=1e-3) + output2.asnumpy(), atol=atol, rtol=rtol) assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) + atol=atol, rtol=rtol) assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) + atol=atol, rtol=rtol) input2grad = mx.nd.concat( *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) assert_almost_equal(input1.grad.asnumpy(), - input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + input2grad.asnumpy(), atol=atol, rtol=rtol) @with_seed() From 904f5bd2e0c3c45f808991fc4eed229ab18ec749 Mon Sep 17 00:00:00 2001 From: wkcn Date: Fri, 29 Mar 2019 22:48:50 +0800 Subject: [PATCH 20/25] test --- tests/python/unittest/test_gluon.py | 7 +++++-- tests/python/unittest/test_operator.py | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index e163593ab4be..f962182bf5e4 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -641,9 +641,9 @@ def _syncParameters(bn1, bn2, ctx): momentum = 0.9 epsilon = 1e-5 axis = 1 - running_mean = mx.nd.zeros(nch) - running_var = mx.nd.ones(nch) data = input1 + running_mean = mx.nd.zeros(nch, ctx=data.context) + running_var = mx.nd.ones(nch, ctx=data.context) data_mean = 
data.mean( axis=axis, exclude=True, keepdims=True) @@ -690,6 +690,7 @@ def _syncParameters(bn1, bn2, ctx): @with_seed() def test_sync_batchnorm(): + import logging cfgs = [(1, False)] num_gpus = mx.context.num_gpus() for i in range(1, num_gpus + 1): @@ -697,6 +698,8 @@ def test_sync_batchnorm(): for ndev, cuda in cfgs: # check with unsync version for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: + logging.info(str((ndev, cuda, shape))) + print(str((ndev, cuda, shape))) for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, ctx=mx.cpu(0)), diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0adf00452100..3ed1f3147749 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1600,11 +1600,14 @@ def check_batchnorm_training(stype): @with_seed() def test_batchnorm(): + import logging momentum = 0.9 epsilon = 1e-5 for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): + logging.info(str((op, shape, axis))) + print(str((op, shape, axis))) kwargs = dict() if op == mx.nd.contrib.SyncBatchNorm: if axis != 1: From e238132e6be6b22addd486dafb65a6f14c2f42d0 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Sat, 30 Mar 2019 00:16:11 +0800 Subject: [PATCH 21/25] Testing --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 820f8504d74c..6c5be8b1acb0 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -282,6 +282,7 @@ class CuDNNBatchNormOp { shape_[2] = 1; shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); } + LOG(INFO)< Date: Sat, 30 Mar 2019 11:02:51 +0800 Subject: [PATCH 22/25] Update batch_norm.cu --- src/operator/nn/batch_norm.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 1ff41226005a..68eebbc2abc3 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -673,15 +673,18 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); }) + LOG(INFO) << "cudnn"; } else { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }) + LOG(INFO) << "fwd1"; } #else MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }); + LOG(INFO) << "fwd2"; #endif } @@ -697,7 +700,7 @@ void BatchNormGradCompute(const nnvm::NodeAttrs& attrs, param.axis = mxnet::op::batchnorm::GetRealAxis(shape, param.axis); #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5 - if (!param.use_global_stats && !param.cudnn_off && shape.ndim() <= 4 + if (!param.use_global_stats && !param.cudnn_off && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS) { MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Backward(ctx, inputs, req, outputs); From 784b77dac8239e4739968d261e310f3b12fc1c34 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Sat, 30 Mar 2019 13:23:53 +0800 Subject: [PATCH 23/25] test cudnnoff --- tests/python/unittest/test_operator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_operator.py 
b/tests/python/unittest/test_operator.py index 3ed1f3147749..b0dcc3f3edf6 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1608,7 +1608,7 @@ def test_batchnorm(): for axis in range(len(shape)): logging.info(str((op, shape, axis))) print(str((op, shape, axis))) - kwargs = dict() + kwargs = dict(cudnn_off=True) if op == mx.nd.contrib.SyncBatchNorm: if axis != 1: continue From 767efeb993921aa11765545ad3cb33c4130570ce Mon Sep 17 00:00:00 2001 From: JackieWu Date: Sat, 30 Mar 2019 14:15:38 +0800 Subject: [PATCH 24/25] Update test_operator.py --- tests/python/unittest/test_operator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b0dcc3f3edf6..0945515c0300 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1608,14 +1608,14 @@ def test_batchnorm(): for axis in range(len(shape)): logging.info(str((op, shape, axis))) print(str((op, shape, axis))) - kwargs = dict(cudnn_off=True) + kwargs = dict() if op == mx.nd.contrib.SyncBatchNorm: if axis != 1: continue key = str(op) + str(shape) + str(axis) kwargs.update(dict(key=key)) else: - kwargs.update(dict(axis=axis)) + kwargs.update(dict(axis=axis, cudnn_off=True)) nch = shape[axis] bn_gamma = mx.nd.random.uniform(shape=(nch,)) From 993e8e385db801a353dd6e83a55c87441f041580 Mon Sep 17 00:00:00 2001 From: wkcn Date: Sat, 30 Mar 2019 23:32:01 +0800 Subject: [PATCH 25/25] update BN! : ) --- src/operator/nn/batch_norm.cu | 3 - src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 1 - tests/python/unittest/test_gluon.py | 193 +++++++++--------- tests/python/unittest/test_operator.py | 197 +++++++++++-------- 4 files changed, 208 insertions(+), 186 deletions(-) diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu index 68eebbc2abc3..9fb44e8fae81 100644 --- a/src/operator/nn/batch_norm.cu +++ b/src/operator/nn/batch_norm.cu @@ -673,18 +673,15 @@ void BatchNormCompute(const nnvm::NodeAttrs& attrs, MSHADOW_REAL_TYPE_SWITCH(dtype, DType, { GetCuDNNOp(param).Forward(ctx, in_data, req, outputs, aux_states); }) - LOG(INFO) << "cudnn"; } else { MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DType, AccReal, { BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }) - LOG(INFO) << "fwd1"; } #else MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, { BatchNormForward(ctx, param, in_data, req, outputs, aux_states); }); - LOG(INFO) << "fwd2"; #endif } diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 6c5be8b1acb0..820f8504d74c 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -282,7 +282,6 @@ class CuDNNBatchNormOp { shape_[2] = 1; shape_[3] = in_data.shape_.ProdShape(2, in_data.ndim()); } - LOG(INFO)< 1 else 1 - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm( - in_channels=nch, num_devices=num_devices) + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) + nch = input.shape[1] if input.ndim > 1 else 1 + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm( + in_channels=nch, num_devices=num_devices) - # using the same values for gamma and beta - 
#_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) - for output in output2], dim=0) - # check bn1 - - momentum = 0.9 - epsilon = 1e-5 - axis = 1 - data = input1 - running_mean = mx.nd.zeros(nch, ctx=data.context) - running_var = mx.nd.ones(nch, ctx=data.context) - - data_mean = data.mean( - axis=axis, exclude=True, keepdims=True) - data_var = (data - data_mean).square().mean(axis=axis, - exclude=True, keepdims=True) - - target_output = (data - data_mean) / (data_var + epsilon).sqrt() - - # squeeze data_mean and data_var - data_mean_flat = data_mean.squeeze() - data_var_flat = data_var.squeeze() - - running_mean = running_mean * momentum + \ - data_mean_flat * (1 - momentum) - running_var = running_var * momentum + \ - data_var_flat * (1 - momentum) - - atol = 1e-2 - rtol = 1e-2 - assert_almost_equal(output1.asnumpy(), target_output.asnumpy(), - atol=atol, rtol=rtol) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - running_mean.asnumpy(), - atol=atol, rtol=rtol) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - running_var.asnumpy(), - atol=atol, rtol=rtol) - # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), - atol=atol, rtol=rtol) - assert_almost_equal(output1.asnumpy(), - output2.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=atol, rtol=rtol) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=atol, rtol=rtol) - input2grad = mx.nd.concat( - *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), - input2grad.asnumpy(), atol=atol, rtol=rtol) + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) + for output in output2], dim=0) + # check bn1 + + momentum = 0.9 + epsilon = 1e-5 + axis = 1 + data = input1 + running_mean = mx.nd.zeros(nch, ctx=data.context) + running_var = mx.nd.ones(nch, ctx=data.context) + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, keepdims=True) + + target_output = (data - data_mean) / (data_var + epsilon).sqrt() + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + 
data_var_flat * (1 - momentum) + + atol = 1e-2 + rtol = 1e-2 + assert_almost_equal(output1.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + running_mean.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + running_var.asnumpy(), + atol=atol, rtol=rtol) + # assert forwarding + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output1.asnumpy(), + output2.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=atol, rtol=rtol) + input2grad = mx.nd.concat( + *[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), + input2grad.asnumpy(), atol=atol, rtol=rtol) -@with_seed() -def test_sync_batchnorm(): - import logging cfgs = [(1, False)] num_gpus = mx.context.num_gpus() for i in range(1, num_gpus + 1): @@ -698,7 +696,6 @@ def test_sync_batchnorm(): for ndev, cuda in cfgs: # check with unsync version for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: - logging.info(str((ndev, cuda, shape))) print(str((ndev, cuda, shape))) for i in range(10): _check_batchnorm_result(mx.nd.random.uniform(shape=shape, diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 0945515c0300..845ae113c218 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1600,95 +1600,124 @@ def check_batchnorm_training(stype): @with_seed() def test_batchnorm(): - import logging momentum = 0.9 epsilon = 1e-5 + + def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): + print(str((op, shape, axis, cudnn_off))) + + kwargs = dict(output_mean_var=output_mean_var) + if op == mx.nd.contrib.SyncBatchNorm: + if axis != 1: + return + key = str(op) + str(shape) + str(axis) + kwargs.update(dict(key=key)) + if cudnn_off: + return + else: + kwargs.update(dict(axis=axis, cudnn_off=cudnn_off)) + nch = shape[axis] + + bn_gamma = mx.nd.random.uniform(shape=(nch,)) + bn_gamma.attach_grad() + + bn_beta = mx.nd.random.uniform(shape=(nch,)) + bn_beta.attach_grad() + + bn_running_mean = mx.nd.zeros(nch) + bn_running_var = mx.nd.ones(nch) + + running_mean = mx.nd.zeros(nch) + running_var = mx.nd.ones(nch) + num_iters = 10 + expand_shape = [1] * len(shape) + expand_shape[axis] = shape[axis] + for _ in range(num_iters): + data = mx.nd.random.uniform(shape=shape) + data.attach_grad() + ograd = mx.nd.random.uniform(shape=shape) + with mx.autograd.record(): + output = op(data, bn_gamma, bn_beta, + bn_running_mean, bn_running_var, + momentum=momentum, eps=epsilon, + fix_gamma=False, **kwargs) + if output_mean_var: + output, output_mean, output_std = output + output.backward(ograd) + mx.nd.waitall() + + data_mean = data.mean( + axis=axis, exclude=True, keepdims=True) + data_var = (data - data_mean).square().mean(axis=axis, + exclude=True, + keepdims=True) + + target_output = (data - data_mean) / \ + (data_var + epsilon).sqrt() * \ + bn_gamma.reshape(expand_shape) + \ + bn_beta.reshape(expand_shape) + + # squeeze data_mean and data_var + data_mean_flat = data_mean.squeeze() + data_var_flat = 
data_var.squeeze() + + running_mean = running_mean * momentum + \ + data_mean_flat * (1 - momentum) + running_var = running_var * momentum + \ + data_var_flat * (1 - momentum) + + W = bn_gamma.reshape(expand_shape) + dnx = ograd * W + xsm = data - data_mean + nd = 1.0 / mx.nd.sqrt(data_var + epsilon) + nx = xsm * nd + m = np.prod(shape) / shape[axis] + dvar = (dnx * xsm).sum(axis=axis, keepdims=True, + exclude=True) * (-0.5) * mx.nd.power(nd, 3) + dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \ + dvar * xsm.mean(axis=axis, keepdims=True, + exclude=True) * 2.0 + dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) + dW = (ograd * nx).sum(axis=axis, exclude=True) + db = ograd.sum(axis=axis, exclude=True) + + atol = 1e-2 + rtol = 1e-2 + + if output_mean_var: + assert_almost_equal(output_mean.asnumpy(), + data_mean_flat.asnumpy(), + atol=atol, rtol=rtol) + if op != mx.nd.contrib.SyncBatchNorm: + assert_almost_equal(output_std.asnumpy(), + (1.0 / (data_var_flat + + epsilon).sqrt()).asnumpy(), + atol=atol, rtol=rtol) + else: + assert_almost_equal(output_std.asnumpy(), + data_var_flat.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(output.asnumpy(), target_output.asnumpy(), + atol=atol, rtol=rtol) + assert_almost_equal(bn_running_mean.asnumpy( + ), running_mean.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal(bn_running_var.asnumpy( + ), running_var.asnumpy(), atol=atol, rtol=rtol) + + assert_almost_equal(data.grad.asnumpy(), + dX.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal( + bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) + assert_almost_equal( + bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) + for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]: for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]: for axis in range(len(shape)): - logging.info(str((op, shape, axis))) - print(str((op, shape, axis))) - kwargs = dict() - if op == mx.nd.contrib.SyncBatchNorm: - if axis != 1: - continue - key = str(op) + str(shape) + str(axis) - kwargs.update(dict(key=key)) - else: - kwargs.update(dict(axis=axis, cudnn_off=True)) - nch = shape[axis] - - bn_gamma = mx.nd.random.uniform(shape=(nch,)) - bn_gamma.attach_grad() - - bn_beta = mx.nd.random.uniform(shape=(nch,)) - bn_beta.attach_grad() - - bn_running_mean = mx.nd.zeros(nch) - bn_running_var = mx.nd.ones(nch) - - running_mean = mx.nd.zeros(nch) - running_var = mx.nd.ones(nch) - num_iters = 10 - expand_shape = [1] * len(shape) - expand_shape[axis] = shape[axis] - for _ in range(num_iters): - data = mx.nd.random.uniform(shape=shape) - data.attach_grad() - ograd = mx.nd.random.uniform(shape=shape) - with mx.autograd.record(): - output = op(data, bn_gamma, bn_beta, bn_running_mean, bn_running_var, - momentum=momentum, eps=epsilon, fix_gamma=False, **kwargs) - output.backward(ograd) - mx.nd.waitall() - - data_mean = data.mean( - axis=axis, exclude=True, keepdims=True) - data_var = (data - data_mean).square().mean(axis=axis, - exclude=True, keepdims=True) - - target_output = (data - data_mean) / (data_var + epsilon).sqrt() * \ - bn_gamma.reshape(expand_shape) + \ - bn_beta.reshape(expand_shape) - - # squeeze data_mean and data_var - data_mean_flat = data_mean.squeeze() - data_var_flat = data_var.squeeze() - - running_mean = running_mean * momentum + \ - data_mean_flat * (1 - momentum) - running_var = running_var * momentum + \ - data_var_flat * (1 - momentum) - - W = bn_gamma.reshape(expand_shape) - dnx = ograd * W - xsm = data - data_mean - nd = 1.0 / 
mx.nd.sqrt(data_var + epsilon) - nx = xsm * nd - m = np.prod(shape) / shape[axis] - dvar = (dnx * xsm).sum(axis=axis, keepdims=True, - exclude=True) * (-0.5) * mx.nd.power(nd, 3.0) - dmean = -nd * dnx.sum(axis=axis, keepdims=True, exclude=True) - \ - dvar * xsm.mean(axis=axis, keepdims=True, - exclude=True) * 2.0 - dX = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m) - dW = (ograd * nx).sum(axis=axis, exclude=True) - db = ograd.sum(axis=axis, exclude=True) - - atol = 1e-2 - rtol = 1e-2 - assert_almost_equal(output.asnumpy(), target_output.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_mean.asnumpy( - ), running_mean.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal(bn_running_var.asnumpy( - ), running_var.asnumpy(), atol=atol, rtol=rtol) - - assert_almost_equal(data.grad.asnumpy(), - dX.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal( - bn_gamma.grad.asnumpy(), dW.asnumpy(), atol=atol, rtol=rtol) - assert_almost_equal( - bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol) + for cudnn_off in [False, True]: + for output_mean_var in [False, True]: + _test_batchnorm_impl(op, shape, axis, + cudnn_off, output_mean_var) @with_seed()
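
Editor's note: for reference, below is a minimal NumPy-only sketch of the per-channel statistics and closed-form backward pass that the rewritten _test_batchnorm_impl compares BatchNorm/SyncBatchNorm against (reduce over every axis except `axis`, then dX, dgamma, dbeta). It is an illustration under stated assumptions, not part of the patch: the helper names (batchnorm_ref, batchnorm_ref_grad), the 5-D example shape, and the finite-difference spot check are all invented here; the running-stat update (new = old * momentum + batch * (1 - momentum)) is described in the test itself and omitted below.

import numpy as np

def batchnorm_ref(x, gamma, beta, axis=1, eps=1e-5):
    """Training-mode forward pass; statistics over all axes except `axis`."""
    red = tuple(i for i in range(x.ndim) if i != axis)
    shp = [1] * x.ndim
    shp[axis] = x.shape[axis]
    mean = x.mean(axis=red, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=red, keepdims=True)
    y = (x - mean) / np.sqrt(var + eps) * gamma.reshape(shp) + beta.reshape(shp)
    return y, mean, var

def batchnorm_ref_grad(x, gamma, ograd, axis=1, eps=1e-5):
    """Backward pass using the same closed-form expressions as the test."""
    red = tuple(i for i in range(x.ndim) if i != axis)
    shp = [1] * x.ndim
    shp[axis] = x.shape[axis]
    mean = x.mean(axis=red, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=red, keepdims=True)
    w = gamma.reshape(shp)
    dnx = ograd * w                       # dL/d(normalized x)
    xsm = x - mean
    nd = 1.0 / np.sqrt(var + eps)
    nx = xsm * nd
    m = x.size / x.shape[axis]            # elements reduced per channel
    dvar = (dnx * xsm).sum(axis=red, keepdims=True) * (-0.5) * nd ** 3
    dmean = -nd * dnx.sum(axis=red, keepdims=True) - \
        dvar * xsm.mean(axis=red, keepdims=True) * 2.0
    dx = dnx * nd + dvar * xsm * (2.0 / m) + dmean * (1.0 / m)
    dgamma = (ograd * nx).sum(axis=red)
    dbeta = ograd.sum(axis=red)
    return dx, dgamma, dbeta

if __name__ == '__main__':
    rng = np.random.RandomState(0)
    x = rng.uniform(size=(4, 3, 2, 2, 2))   # a 5D input, the case this PR adds
    gamma = rng.uniform(size=(3,))
    beta = rng.uniform(size=(3,))
    ograd = rng.uniform(size=x.shape)
    y, mean, var = batchnorm_ref(x, gamma, beta)
    dx, dgamma, dbeta = batchnorm_ref_grad(x, gamma, ograd)
    # Spot-check dx against a central finite difference of L = sum(ograd * y).
    h = 1e-5
    idx = (1, 2, 0, 1, 1)
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    lp = (ograd * batchnorm_ref(xp, gamma, beta)[0]).sum()
    lm = (ograd * batchnorm_ref(xm, gamma, beta)[0]).sum()
    assert abs((lp - lm) / (2 * h) - dx[idx]) < 1e-3

This mirrors why the test reduces with exclude=True over the channel axis: the same reference applies to 2D, 4D, and the new 5D inputs without special-casing the layout.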