
Aggregate SGD #13346

Merged
eric-haibin-lin merged 16 commits into apache:master on Jan 24, 2019

Conversation

@ptrendx (Member) commented Nov 21, 2018

Description

Currently, MXNet optimizers are invoked one weight at a time. This leads to a lot of synchronization overhead, as updates (especially for convolutions and batchnorm) tend to be small, but each one needs to be synchronized on.

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • To my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • ability to control the update_on_kvstore value via the environment variable MXNET_UPDATE_ON_KVSTORE (default is 1, consistent with the current behavior)
  • if update_on_kvstore is False, the SGD optimizer attempts to bundle updates of multiple weights together and launch a single kernel to perform them all, reducing the number of kernel calls and synchronizations (see the usage sketch below)
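
A minimal usage sketch of enabling the new code path from Python follows. The tiny network, context, and hyper-parameter values are illustrative assumptions (not taken from this PR), and the exact point at which MXNet reads these variables may vary; setting them before creating the optimizer/trainer is the safe choice.

# Hedged sketch: enabling aggregated SGD via the new environment variables.
# The network and hyper-parameters below are placeholders, not from this PR.
import os
os.environ['MXNET_UPDATE_ON_KVSTORE'] = '0'           # perform updates outside the kvstore
os.environ['MXNET_OPTIMIZER_AGGREGATION_SIZE'] = '4'  # bundle up to 4 weights per update kernel

import mxnet as mx
from mxnet import gluon

net = gluon.nn.Dense(10)
net.initialize(ctx=mx.gpu(0))  # assumes a GPU is available; that is where aggregation helps most
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.1, 'momentum': 0.9})

with mx.autograd.record():
    loss = net(mx.nd.ones((2, 3), ctx=mx.gpu(0))).sum()
loss.backward()
# trainer.step() can use the aggregated multi-weight SGD kernels as long as
# every array in a bundle uses default (dense) storage.
trainer.step(batch_size=2)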

Comments

  • The current test_sgd automatically exercises the new code paths, so no new tests are needed.
  • The code does not support sparse arrays; it falls back to non-aggregated calls when it encounters a sparse array in the bundle of weights/gradients.

@stu1130 (Contributor) commented Nov 21, 2018

@mxnet-label-bot add [pr-awaiting-review]
Thanks for your contribution @ptrendx

@marcoabreu marcoabreu added the pr-awaiting-review PR is waiting for code review label Nov 21, 2018
@anirudhacharya (Member) commented Nov 21, 2018

@ptrendx Can you share a benchmark comparing SGD performance with MXNET_UPDATE_ON_KVSTORE set for aggregate SGD vs. when it is not?

@ptrendx (Member, Author) commented Nov 21, 2018

This PR is part of upstreaming improvements to MXNet that are available in NVIDIA's NGC 18.11 MXNet container. I will use results from that container to show the impact once all the other improvements are in place. The benchmark shown is ResNet v1.5 training on a single V100 32 GB in a DGX-1V, batch size 32.

  1. MXNET_UPDATE_ON_KVSTORE=1 (default)
root@dgx1v-loki-19:/opt/mxnet/example/image-classification# numactl --physcpubind=0-4 ./train_imagenet_runner -n 1 -b 32 --network resnet-v1b --disp-batches 50 -e 1 --no-val -s 12800                             
INFO:root:start with arguments Namespace(batch_size=32, batchnorm_eps=2e-05, batchnorm_layout='NHWC', batchnorm_mom=0.9, benchmark=0, bn_gamma_init0=False, brightness=0.4, contrast=0.4, conv_algo=-1, conv_layout='NHWC', custom_bn_off=0, dali_nvjpeg_memory_padding=16, dali_prefetch_queue=3, dali_threads=3, data_nthreads=40, data_train='/data/imagenet/train-480-val-256-recordio/train.rec', data_train_idx='/data/imagenet/train-480-val-256-recordio/train.idx', data_val=None, data_val_idx='', disp_batches=50, dtype='float16', epoch_size=0, fill_value=127, force_tensor_core=0, fuse_bn_add_relu=1, fuse_bn_relu=1, gc_threshold=0.5, gc_type='none', gpus='0', image_shape='4,224,224', initializer='default', input_layout='NCHW', kv_store='device', load_epoch=None, log='', logging_dir='logs', loss='', lr=0.0125, lr_factor=0.1, lr_step_epochs='30,60,80', macrobatch_size=0, max_crop_size=-1, max_random_area=1.0, max_random_aspect_ratio=1.33, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1.0, max_random_shear_ratio=0.0, min_crop_size=-1, min_random_area=0.08, min_random_aspect_ratio=0.75, min_random_scale=1.0, model_prefix=None, mom=0.9, monitor=0, network='resnet-v1b-fl', num_classes=1000, num_epochs=1, num_examples=12800, num_layers=50, optimizer='sgd', pad_size=0, pca_noise=0.0, pooling_layout='NHWC', profile_server_suffix='', profile_worker_suffix='', random_crop=0, random_mirror=1, random_resized_crop=1, resize=256, rgb_mean='123.68,116.779,103.939', rgb_std='1,1,1', saturation=0.4, save_period=1, seed=None, separ_val=False, set_data_aug_level=None, set_resnet_aug=None, test_io=0, top_k=0, use_dali=True, verbose=0, warmup_epochs=5, warmup_strategy='linear', wd=0.0001)
/opt/mxnet/example/image-classification/common/dali.py:142: UserWarning: 12800 training examples will be used, although full training set contains 1281167 examples
  warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
[17:04:56] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:119: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [50]   Speed: 912.56 samples/sec lr:0.000313   accuracy=0.000613
INFO:root:Epoch[0] Batch [100]  Speed: 922.14 samples/sec lr:0.000625   accuracy=0.000625
INFO:root:Epoch[0] Batch [150]  Speed: 919.71 samples/sec lr:0.000937   accuracy=0.000625
INFO:root:Epoch[0] Batch [200]  Speed: 924.12 samples/sec lr:0.001250   accuracy=0.001875
INFO:root:Epoch[0] Batch [250]  Speed: 922.34 samples/sec lr:0.001563   accuracy=0.000625
INFO:root:Epoch[0] Batch [300]  Speed: 923.93 samples/sec lr:0.001875   accuracy=0.000625
INFO:root:Epoch[0] Batch [350]  Speed: 923.90 samples/sec lr:0.002188   accuracy=0.002500
INFO:root:Epoch[0] Train-accuracy=0.001276
INFO:root:Epoch[0] Time cost=15.579
  2. MXNET_UPDATE_ON_KVSTORE=0
    MXNET_OPTIMIZER_AGGREGATION_SIZE=1 (no aggregation)
    The speedup here comes from eliminating an unnecessary (in the single-GPU case) broadcast call in the kvstore.
root@dgx1v-loki-19:/opt/mxnet/example/image-classification# numactl --physcpubind=0-4 ./train_imagenet_runner -n 1 -b 32 --network resnet-v1b --disp-batches 50 -e 1 --no-val -s 12800                             
INFO:root:start with arguments Namespace(batch_size=32, batchnorm_eps=2e-05, batchnorm_layout='NHWC', batchnorm_mom=0.9, benchmark=0, bn_gamma_init0=False, brightness=0.4, contrast=0.4, conv_algo=-1, conv_layout='NHWC', custom_bn_off=0, dali_nvjpeg_memory_padding=16, dali_prefetch_queue=3, dali_threads=3, data_nthreads=40, data_train='/data/imagenet/train-480-val-256-recordio/train.rec', data_train_idx='/data/imagenet/train-480-val-256-recordio/train.idx', data_val=None, data_val_idx='', disp_batches=50, dtype='float16', epoch_size=0, fill_value=127, force_tensor_core=0, fuse_bn_add_relu=1, fuse_bn_relu=1, gc_threshold=0.5, gc_type='none', gpus='0', image_shape='4,224,224', initializer='default', input_layout='NCHW', kv_store='device', load_epoch=None, log='', logging_dir='logs', loss='', lr=0.0125, lr_factor=0.1, lr_step_epochs='30,60,80', macrobatch_size=0, max_crop_size=-1, max_random_area=1.0, max_random_aspect_ratio=1.33, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1.0, max_random_shear_ratio=0.0, min_crop_size=-1, min_random_area=0.08, min_random_aspect_ratio=0.75, min_random_scale=1.0, model_prefix=None, mom=0.9, monitor=0, network='resnet-v1b-fl', num_classes=1000, num_epochs=1, num_examples=12800, num_layers=50, optimizer='sgd', pad_size=0, pca_noise=0.0, pooling_layout='NHWC', profile_server_suffix='', profile_worker_suffix='', random_crop=0, random_mirror=1, random_resized_crop=1, resize=256, rgb_mean='123.68,116.779,103.939', rgb_std='1,1,1', saturation=0.4, save_period=1, seed=None, separ_val=False, set_data_aug_level=None, set_resnet_aug=None, test_io=0, top_k=0, use_dali=True, verbose=0, warmup_epochs=5, warmup_strategy='linear', wd=0.0001)
/opt/mxnet/example/image-classification/common/dali.py:142: UserWarning: 12800 training examples will be used, although full training set contains 1281167 examples
  warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
[17:12:43] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:119: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [50]   Speed: 959.50 samples/sec lr:0.000313   accuracy=0.000613
INFO:root:Epoch[0] Batch [100]  Speed: 968.80 samples/sec lr:0.000625   accuracy=0.000625
INFO:root:Epoch[0] Batch [150]  Speed: 966.11 samples/sec lr:0.000937   accuracy=0.000625
INFO:root:Epoch[0] Batch [200]  Speed: 969.05 samples/sec lr:0.001250   accuracy=0.001875
INFO:root:Epoch[0] Batch [250]  Speed: 971.04 samples/sec lr:0.001563   accuracy=0.000625
INFO:root:Epoch[0] Batch [300]  Speed: 971.68 samples/sec lr:0.001875   accuracy=0.000625
INFO:root:Epoch[0] Batch [350]  Speed: 971.70 samples/sec lr:0.002188   accuracy=0.002500
INFO:root:Epoch[0] Train-accuracy=0.001276
INFO:root:Epoch[0] Time cost=14.874
  3. MXNET_UPDATE_ON_KVSTORE=0
    MXNET_OPTIMIZER_AGGREGATION_SIZE=4 (default in this PR)
root@dgx1v-loki-19:/opt/mxnet/example/image-classification# numactl --physcpubind=0-4 ./train_imagenet_runner -n 1 -b 32 --network resnet-v1b --disp-batches 50 -e 1 --no-val -s 12800
INFO:root:start with arguments Namespace(batch_size=32, batchnorm_eps=2e-05, batchnorm_layout='NHWC', batchnorm_mom=0.9, benchmark=0, bn_gamma_init0=False, brightness=0.4, contrast=0.4, conv_algo=-1, conv_layout='NHWC', custom_bn_off=0, dali_nvjpeg_memory_padding=16, dali_prefetch_queue=3, dali_threads=3, data_nthreads=40, data_train='/data/imagenet/train-480-val-256-recordio/train.rec', data_train_idx='/data/imagenet/train-480-val-256-recordio/train.idx', data_val=None, data_val_idx='', disp_batches=50, dtype='float16', epoch_size=0, fill_value=127, force_tensor_core=0, fuse_bn_add_relu=1, fuse_bn_relu=1, gc_threshold=0.5, gc_type='none', gpus='0', image_shape='4,224,224', initializer='default', input_layout='NCHW', kv_store='device', load_epoch=None, log='', logging_dir='logs', loss='', lr=0.0125, lr_factor=0.1, lr_step_epochs='30,60,80', macrobatch_size=0, max_crop_size=-1, max_random_area=1.0, max_random_aspect_ratio=1.33, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1.0, max_random_shear_ratio=0.0, min_crop_size=-1, min_random_area=0.08, min_random_aspect_ratio=0.75, min_random_scale=1.0, model_prefix=None, mom=0.9, monitor=0, network='resnet-v1b-fl', num_classes=1000, num_epochs=1, num_examples=12800, num_layers=50, optimizer='sgd', pad_size=0, pca_noise=0.0, pooling_layout='NHWC', profile_server_suffix='', profile_worker_suffix='', random_crop=0, random_mirror=1, random_resized_crop=1, resize=256, rgb_mean='123.68,116.779,103.939', rgb_std='1,1,1', saturation=0.4, save_period=1, seed=None, separ_val=False, set_data_aug_level=None, set_resnet_aug=None, test_io=0, top_k=0, use_dali=True, verbose=0, warmup_epochs=5, warmup_strategy='linear', wd=0.0001)
/opt/mxnet/example/image-classification/common/dali.py:142: UserWarning: 12800 training examples will be used, although full training set contains 1281167 examples
  warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
[17:14:43] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:119: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [50]   Speed: 1005.45 samples/sec lr:0.000313  accuracy=0.000613
INFO:root:Epoch[0] Batch [100]  Speed: 1020.27 samples/sec lr:0.000625  accuracy=0.000625
INFO:root:Epoch[0] Batch [150]  Speed: 1016.28 samples/sec lr:0.000937  accuracy=0.000625
INFO:root:Epoch[0] Batch [200]  Speed: 1020.46 samples/sec lr:0.001250  accuracy=0.001875
INFO:root:Epoch[0] Batch [250]  Speed: 1018.46 samples/sec lr:0.001563  accuracy=0.000625
INFO:root:Epoch[0] Batch [300]  Speed: 1020.25 samples/sec lr:0.001875  accuracy=0.000625
INFO:root:Epoch[0] Batch [350]  Speed: 1020.17 samples/sec lr:0.002188  accuracy=0.002500
INFO:root:Epoch[0] Train-accuracy=0.001276
INFO:root:Epoch[0] Time cost=14.256
  4. MXNET_UPDATE_ON_KVSTORE=0
    MXNET_OPTIMIZER_AGGREGATION_SIZE=60 (max possible)
root@dgx1v-loki-19:/opt/mxnet/example/image-classification# numactl --physcpubind=0-4 ./train_imagenet_runner -n 1 -b 32 --network resnet-v1b --disp-batches 50 -e 1 --no-val -s 12800
INFO:root:start with arguments Namespace(batch_size=32, batchnorm_eps=2e-05, batchnorm_layout='NHWC', batchnorm_mom=0.9, benchmark=0, bn_gamma_init0=False, brightness=0.4, contrast=0.4, conv_algo=-1, conv_layout='NHWC', custom_bn_off=0, dali_nvjpeg_memory_padding=16, dali_prefetch_queue=3, dali_threads=3, data_nthreads=40, data_train='/data/imagenet/train-480-val-256-recordio/train.rec', data_train_idx='/data/imagenet/train-480-val-256-recordio/train.idx', data_val=None, data_val_idx='', disp_batches=50, dtype='float16', epoch_size=0, fill_value=127, force_tensor_core=0, fuse_bn_add_relu=1, fuse_bn_relu=1, gc_threshold=0.5, gc_type='none', gpus='0', image_shape='4,224,224', initializer='default', input_layout='NCHW', kv_store='device', load_epoch=None, log='', logging_dir='logs', loss='', lr=0.0125, lr_factor=0.1, lr_step_epochs='30,60,80', macrobatch_size=0, max_crop_size=-1, max_random_area=1.0, max_random_aspect_ratio=1.33, max_random_h=0, max_random_l=0, max_random_rotate_angle=0, max_random_s=0, max_random_scale=1.0, max_random_shear_ratio=0.0, min_crop_size=-1, min_random_area=0.08, min_random_aspect_ratio=0.75, min_random_scale=1.0, model_prefix=None, mom=0.9, monitor=0, network='resnet-v1b-fl', num_classes=1000, num_epochs=1, num_examples=12800, num_layers=50, optimizer='sgd', pad_size=0, pca_noise=0.0, pooling_layout='NHWC', profile_server_suffix='', profile_worker_suffix='', random_crop=0, random_mirror=1, random_resized_crop=1, resize=256, rgb_mean='123.68,116.779,103.939', rgb_std='1,1,1', saturation=0.4, save_period=1, seed=None, separ_val=False, set_data_aug_level=None, set_resnet_aug=None, test_io=0, top_k=0, use_dali=True, verbose=0, warmup_epochs=5, warmup_strategy='linear', wd=0.0001)
/opt/mxnet/example/image-classification/common/dali.py:142: UserWarning: 12800 training examples will be used, although full training set contains 1281167 examples
  warnings.warn("{} training examples will be used, although full training set contains {} examples".format(args.num_examples, trainpipes[0].epoch_size("Reader")))
[17:15:54] src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:119: Running performance tests to find the best convolution algorithm, this can take a while... (setting env variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
INFO:root:Epoch[0] Batch [50]   Speed: 1035.09 samples/sec lr:0.000313  accuracy=0.000613
INFO:root:Epoch[0] Batch [100]  Speed: 1047.64 samples/sec lr:0.000625  accuracy=0.000625
INFO:root:Epoch[0] Batch [150]  Speed: 1042.25 samples/sec lr:0.000937  accuracy=0.000625
INFO:root:Epoch[0] Batch [200]  Speed: 1047.25 samples/sec lr:0.001250  accuracy=0.001875
INFO:root:Epoch[0] Batch [250]  Speed: 1045.58 samples/sec lr:0.001563  accuracy=0.000625
INFO:root:Epoch[0] Batch [300]  Speed: 1044.48 samples/sec lr:0.001875  accuracy=0.000625
INFO:root:Epoch[0] Batch [350]  Speed: 1045.78 samples/sec lr:0.002188  accuracy=0.002500
INFO:root:Epoch[0] Train-accuracy=0.001276
INFO:root:Epoch[0] Time cost=13.927

@ptrendx ptrendx requested a review from nswamy as a code owner November 21, 2018 19:55
@lupesko (Contributor) commented Dec 5, 2018

Thanks for the contribution @ptrendx !
Adding @nswamy and @sandeep-krishnamurthy to help review/merge.

@@ -98,6 +99,9 @@ def dict_equ(a, b):

@with_seed()
Member:

Is it not tested with test_trainer?

ptrendx (Member, Author):

Are you asking why I am not changing the test_trainer as well since it should fail with the MXNET_UPDATE_ON_KVSTORE=0 option set? Since you made a PR to fix that test, I did not change it. The MXNET_UPDATE_ON_KVSTORE=0 option is not set in CI (although logic for the aggregated SGD itself is tested by the SGD test).

@Roshrini (Member) commented Jan 2, 2019

@ptrendx Can you please rebase this PR?

@anirudhacharya (Member) left a comment:

In the PR description you said that test_sgd covers the new code paths. But in _update_impl there is an if statement on aggregate:

if aggregate:
    ...
else:
    ...

Can you explain how the else block is covered by the current test_sgd code?

docs/faq/env_var.md (review thread resolved)
for weight, grad in zip(weights, grads):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
aggregate = (aggregate and

ptrendx (Member, Author):

@anirudhacharya As you can see, aggregate is set to True at the beginning and changes to False when a non-default storage type is encountered, so testing with both dense and sparse data exercises both branches of the code (see the sketch below).
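
For reference, here is a minimal sketch of the flag behavior described above; can_aggregate is a hypothetical helper for illustration, not the PR's actual _update_impl code.

# Hypothetical helper illustrating the aggregate check; not the actual _update_impl.
def can_aggregate(weights, grads, aggregate_num):
    aggregate = aggregate_num > 0
    for weight, grad in zip(weights, grads):
        # Any sparse (non-default storage) weight or gradient disables aggregation,
        # so the optimizer falls back to one update call per weight.
        aggregate = (aggregate and
                     weight.stype == 'default' and
                     grad.stype == 'default')
    return aggregate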

Member:

thanks!


template<typename DType, typename MPDType>
struct MultiSGDKernelParam {
  static const int N = 60;

ptrendx (Member, Author):

@anirudhacharya This is the reason for 60: I pass this struct as a kernel parameter, and kernel parameters are limited to 4 kB.
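
As a rough back-of-the-envelope check (the field set below is assumed for illustration, not copied from the PR), 60 entries with a handful of 8-byte pointers and a few scalars per weight come to roughly 3.4 kB, just under the 4 kB limit on kernel parameters.

# Rough size estimate for a MultiSGDKernelParam-like struct; the exact fields
# are an assumption used only to show why N is capped around 60.
N = 60
ptr, size_t, flt = 8, 8, 4                 # 64-bit pointers/sizes, 32-bit float scalars
per_weight = 5 * ptr + size_t + 2 * flt    # weight/grad/out/mom/weight32 pointers + size + lr + wd
fixed = 4 + 4 * flt                        # count plus a few scalar hyper-parameters (approximate)
print(N * per_weight + fixed)              # ~3.4 kB, below the 4096-byte kernel-parameter limit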

@ptrendx (Member, Author) commented Jan 4, 2019

Is there anything else needed for this PR?

@anirudhacharya (Member) left a comment:

I think the code LGTM. Some minor doc fixes are needed, I think.

@@ -105,6 +110,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
self._index_update_count = {}
self.clip_gradient = clip_gradient
self.multi_precision = multi_precision
self.aggregate_num = 0

@anirudhacharya (Member) commented Jan 4, 2019:

Please add this to the parameter list in the class doc.

ptrendx (Member, Author):

It is not really a parameter, though; it is up to the optimizer (not the user) to override this value if it supports aggregated execution.

@@ -502,6 +545,7 @@ def __init__(self, momentum=0.0, lazy_update=True, **kwargs):
super(SGD, self).__init__(**kwargs)
self.momentum = momentum
self.lazy_update = lazy_update
self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4"))

Member:

In line 510, can you add a section on aggregate updates, and in line 524, can you also point to these two methods, multi_sgd_mom_update and multi_mp_sgd_update, as optimizer update rules?

ptrendx (Member, Author):

Will do.

ptrendx (Member, Author):

I wrote the section on aggregate updates, but I'm not sure about pointing to the new methods in line 524: they use the same algorithm as the sgd_update and sgd_mom_update functions, so pointing to those functions for the details of the algorithm seems sufficient.

Member:

I still think the 'multi' update methods should show up in the SGD doc description. But I am okay with the code owner/merger making a call on this.

Member:

I don't think it's necessary to point to these two methods, since the algorithm is the same.

@KellenSunderland (Contributor) commented:

LGTM

@eric-haibin-lin Open question raised by @anirudh2290: in your opinion, should the 'multi' update methods show up in the SGD doc description?

@eric-haibin-lin (Member) left a comment:

Looks good, pending some suggestions for documentation. Awesome work!

src/operator/optimizer_op.cc (review thread resolved, outdated)
src/operator/optimizer_op.cc (review thread resolved, outdated)
python/mxnet/gluon/trainer.py (review thread resolved)
@eric-haibin-lin eric-haibin-lin merged commit 0a45e1a into apache:master Jan 24, 2019
jessr92 pushed a commit to jessr92/incubator-mxnet that referenced this pull request Jan 27, 2019
* Aggregate SGD

* Make OpWrapperGenerator understand Tuple<float>

* Trigger

* Add NNVM Tuple to cpp-package op.h

* Trigger

* Fix pylint aggregate SGD

* Update info about new ENV vars and modifying 2 tests that require
update_on_kvstore to be true

* Fix

* Aggregate SGD support for Gluon trainer

* Added text to doc about aggregate update in SGD optimizer

* Docs changes from review
apeforest added a commit to apeforest/incubator-mxnet that referenced this pull request Feb 5, 2019
apeforest added a commit to apeforest/incubator-mxnet that referenced this pull request Feb 13, 2019
apeforest added a commit to apeforest/incubator-mxnet that referenced this pull request Feb 13, 2019
apeforest added a commit to apeforest/incubator-mxnet that referenced this pull request Feb 14, 2019
apeforest added a commit to apeforest/incubator-mxnet that referenced this pull request Feb 14, 2019
stephenrawls pushed a commit to stephenrawls/incubator-mxnet that referenced this pull request Feb 16, 2019

@kice (Contributor) commented Mar 9, 2019

Missing type information for some parameters.

E.g., from here: https://mxnet.incubator.apache.org/api/python/symbol/symbol.html#mxnet.symbol.multi_mp_sgd_mom_update

lrs (tuple of , required) – Learning rates.

This should be tuple of <float>, instead of tuple of .

And OpWrapperGenerator.py also complains about this:

argument "lrs" of operator "multi_sgd_update" has unknown type ", required"
argument "wds" of operator "multi_sgd_update" has unknown type ", required"
argument "lrs" of operator "multi_sgd_mom_update" has unknown type ", required"
argument "wds" of operator "multi_sgd_mom_update" has unknown type ", required"
argument "lrs" of operator "multi_mp_sgd_update" has unknown type ", required"
argument "wds" of operator "multi_mp_sgd_update" has unknown type ", required"
argument "lrs" of operator "multi_mp_sgd_mom_update" has unknown type ", required"
argument "wds" of operator "multi_mp_sgd_mom_update" has unknown type ", required"

haohuanw pushed a commit to haohuanw/incubator-mxnet that referenced this pull request Jun 23, 2019

Labels: pr-awaiting-review (PR is waiting for code review)
Projects: none
Linked issues: none
9 participants