This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes for trainer with update_on_kvstore=False #13721

Merged
merged 13 commits on Dec 29, 2018
fix a bug
Ubuntu committed Dec 23, 2018
commit 9be416251925a9ec0b41fee91ff4267951544f95
8 changes: 5 additions & 3 deletions python/mxnet/gluon/trainer.py
@@ -217,9 +217,11 @@ def _init_kvstore(self):
self._distributed = 'dist' in kvstore.type if kvstore else False
update_on_kvstore = self._distributed
# raise err if user provides unsupported configs
if config['update_on_kvstore'] is False and self._distributed:
raise RuntimeError("Cannot set update_on_kvstore=False on dist kvstore "
"when sparse gradients are present.")
if config['update_on_kvstore'] is not None:
if config['update_on_kvstore'] is False and self._distributed:
raise ValueError("Cannot set update_on_kvstore=False on dist kvstore "
"when sparse gradients are present.")
update_on_kvstore = config['update_on_kvstore']

else:
# Training with dense weight and dense gradients.
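To make the resolution logic in the hunk above easier to follow, here is a standalone sketch; `resolve_update_on_kvstore` is a hypothetical helper written for illustration, not MXNet code. It mirrors the new branch taken when sparse gradients are present: an explicit `update_on_kvstore` in the config now overrides the kvstore-derived default, and the unsupported distributed combination fails fast with a ValueError rather than a RuntimeError.

```python
def resolve_update_on_kvstore(config_value, distributed):
    """Illustration-only sketch of the sparse-gradient branch in Trainer._init_kvstore."""
    update_on_kvstore = distributed           # default: update on kvstore only for dist kvstore
    if config_value is not None:
        if config_value is False and distributed:
            raise ValueError("Cannot set update_on_kvstore=False on dist kvstore "
                             "when sparse gradients are present.")
        update_on_kvstore = config_value      # the explicit user setting is applied
    return update_on_kvstore

assert resolve_update_on_kvstore(None, distributed=False) is False
assert resolve_update_on_kvstore(True, distributed=False) is True
assert resolve_update_on_kvstore(False, distributed=False) is False
# resolve_update_on_kvstore(False, distributed=True) raises ValueError
```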
34 changes: 20 additions & 14 deletions tests/python/unittest/test_gluon_trainer.py
@@ -217,34 +217,40 @@ def check_trainer_reset_kv(kv):

@with_seed()
def test_trainer_sparse_kv():
def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected_update_on_kv):
def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected):
params = gluon.ParameterDict()
x = params.get('x', shape=(10,1), lr_mult=1.0, stype=stype, grad_stype=grad_stype)
params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
kvstore=kv, update_on_kvstore=update_on_kv)
all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows)
with mx.autograd.record():
for w in ws:
y = w + 1
y.backward()
trainer.step(1)
assert trainer._kvstore.type == kv
assert trainer._kv_initialized
assert trainer._update_on_kvstore is expected_update_on_kv
# the updated parameter should be based on the loaded checkpoint
mx.nd.waitall()
updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
assert (updated_w == -0.2).asnumpy().all()
try:
ws = x.list_data() if stype == 'default' else x.list_row_sparse_data(all_rows)
with mx.autograd.record():
for w in ws:
y = w + 1
y.backward()
trainer.step(1)
assert trainer._kvstore.type == kv
assert trainer._kv_initialized
assert trainer._update_on_kvstore is expected
# the updated parameter should be based on the loaded checkpoint
mx.nd.waitall()
updated_w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
assert (updated_w == -0.2).asnumpy().all()
except Exception as err:
assert isinstance(err, expected)

kvs = ['local', 'device']
for kv in kvs:
check_trainer_sparse_kv(kv, 'default', 'default', True, True)
check_trainer_sparse_kv(kv, 'default', 'default', False, False)
check_trainer_sparse_kv(kv, 'default', 'default', None, True)
check_trainer_sparse_kv(kv, 'default', 'row_sparse', None, False)
check_trainer_sparse_kv(kv, 'default', 'row_sparse', True, True)
check_trainer_sparse_kv(kv, 'default', 'row_sparse', False, False)
check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', None, True)
check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', False, ValueError)

@with_seed()
def test_trainer_lr_scheduler():
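Stepping back from the raw diff: the rows added to `test_trainer_sparse_kv` above pin down the user-facing behaviour. The following is a minimal sketch of that behaviour, assuming an MXNet 1.4-era Gluon API (`ParameterDict`, `row_sparse` storage); `run_one_step` is a hypothetical helper written for this note, not part of the PR.

```python
import mxnet as mx
from mxnet import gluon, autograd

def run_one_step(stype, grad_stype, update_on_kv):
    """Hypothetical helper: one SGD step on a single parameter, mirroring the test above."""
    params = gluon.ParameterDict()
    x = params.get('x', shape=(10, 1), stype=stype, grad_stype=grad_stype)
    params.initialize(ctx=mx.cpu(0), init='zeros')
    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1},
                            kvstore='local', update_on_kvstore=update_on_kv)
    all_rows = mx.nd.arange(0, 10, ctx=mx.cpu(0))
    # The kvstore is created lazily, so the configuration checks fire on first
    # use (the sparse pull below or trainer.step), not in the constructor.
    w = x.data(mx.cpu(0)) if stype == 'default' else x.row_sparse_data(all_rows)
    with autograd.record():
        y = w + 1
        y.backward()
    trainer.step(1)
    return trainer._update_on_kvstore

# Dense weights with row_sparse gradients honour an explicit update_on_kvstore=False.
assert run_one_step('default', 'row_sparse', update_on_kv=False) is False

# row_sparse weights must be updated on the kvstore; update_on_kvstore=False is rejected.
try:
    run_one_step('row_sparse', 'row_sparse', update_on_kv=False)
except ValueError as err:
    print('rejected as expected:', err)
```

Note the design choice in the updated test: the `expected` argument does double duty. A bool is compared against `trainer._update_on_kvstore`, while an exception class is matched in the `except` branch, so one helper covers both the supported and the rejected configurations.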