This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fixes for trainer with update_on_kvstore=False #13721

Merged
merged 13 commits into from
Dec 29, 2018
revert optimizer list
Ubuntu committed Dec 23, 2018
commit 634bca2698d9db6a1079ab93800c75cf40493cc4
78 changes: 32 additions & 46 deletions python/mxnet/gluon/trainer.py
@@ -59,6 +59,7 @@ class Trainer(object):
`update_on_kvstore=False` is not supported in the following cases:
- dist kvstore with sparse weights or sparse gradients
- dist async kvstore
- optimizer.lr_scheduler is not None

Properties
----------
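The constraints listed above only bite in distributed or scheduler-driven setups; a minimal sketch of the supported local multi-device case (illustrative only, not part of the diff, assuming MXNet 1.x Gluon):

```python
import mxnet as mx
from mxnet import gluon

# Local kvstore, dense weights, no LRScheduler: update_on_kvstore=False is allowed.
x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1},
                        update_on_kvstore=False)

with mx.autograd.record():
    for w in x.list_data():
        (w + 1).backward()          # grad of (w + 1) w.r.t. w is all ones
trainer.step(1)                     # gradients are reduced, then applied per context
```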
@@ -94,13 +95,7 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
self._compression_params = compression_params
optimizer_params = optimizer_params if optimizer_params else {}
self._scale = float(optimizer_params.get('rescale_grad', 1.0))
# one optimizer / updater per context
# If self._update_on_kvstore is set to `True` in `_init_kvstore()`, then:
# - updaters[:] are never used.
# - optimizer[0] is registered with kvstore. optimizer[1:] are never used.
self._contexts = self._check_contexts()
self._optimizers = []
self._updaters = []
self._init_optimizer(optimizer, optimizer_params)
self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore}
self._kv_initialized = False
@@ -127,19 +122,15 @@ def _init_optimizer(self, optimizer, optimizer_params):
assert not optimizer_params, \
"optimizer_params must be None if optimizer is an instance of " \
"Optimizer instead of str"
else:
optimizer = opt.create(optimizer, **optimizer_params)
optimizer.param_dict = param_dict
self._optimizers = [optimizer]
self._updaters = [opt.get_updater(optimizer)]
# create a deep copy of the optimizer per context
for _ in range(len(self._contexts) - 1):
optim = copy.deepcopy(optimizer)
self._optimizer = optimizer
# param_dict must not be deep copied, so that if user mutate the lr_mult
# or wd_mult of some parameters, it takes effect.
optim.param_dict = param_dict
self._optimizers.append(optim)
self._updaters.append(opt.get_updater(optim))
self._optimizer.param_dict = param_dict
else:
self._optimizer = opt.create(optimizer, param_dict=param_dict,
**optimizer_params)
self._updaters = [opt.get_updater(self._optimizer) \
for _ in self._contexts]
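Because `param_dict` is assigned to the single optimizer rather than deep-copied, per-parameter multipliers mutated after the Trainer is built are still picked up at update time. A rough sketch of that behavior (illustrative, not part of the diff):

```python
import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=mx.cpu(0), init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0})

x.lr_mult = 0.5                     # mutated after Trainer construction
with mx.autograd.record():
    (x.data() + 1).backward()       # grad is all ones
trainer.step(1)
# the optimizer reads lr_mult through the shared param_dict,
# so the update is -lr * lr_mult * grad = -0.5
print(x.data().asnumpy())
```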

def _init_params(self):
"""Initialize parameters in the KVStore.
@@ -245,9 +236,13 @@ def _init_kvstore(self):
kvstore.set_gradient_compression(self._compression_params)
if update_on_kvstore:
# optimizer preferably needs to be set before init for multiprecision
kvstore.set_optimizer(self._optimizers[0])
kvstore.set_optimizer(self._optimizer)
self._kvstore = kvstore
self._update_on_kvstore = update_on_kvstore
if self._optimizer.lr_scheduler is not None:
assert self._update_on_kvstore, "update_on_kvstore=False does not support " \
"optimizer with LRScheduler. Please " \
"consider setting learning rate manually."
else:
self._kvstore = None
self._update_on_kvstore = None
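A rough sketch of how the new check behaves (illustrative only; the kvstore is initialized lazily, so the assertion surfaces on the first `step`):

```python
import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
sched = mx.lr_scheduler.FactorScheduler(step=2, factor=0.1, base_lr=1.0)

# An LRScheduler together with update_on_kvstore=False is now rejected.
trainer = gluon.Trainer([x], 'sgd',
                        {'learning_rate': 1.0, 'lr_scheduler': sched},
                        update_on_kvstore=False)
with mx.autograd.record():
    for w in x.list_data():
        (w + 1).backward()
try:
    trainer.step(1)                 # _init_kvstore runs here and trips the assert
except AssertionError as err:
    print('rejected:', err)
```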
@@ -256,16 +251,11 @@

@property
def learning_rate(self):
if not isinstance(self._optimizers[0], opt.Optimizer):
if not isinstance(self._optimizer, opt.Optimizer):
raise UserWarning("Optimizer has to be defined before its learning "
"rate can be accessed.")
else:
lr = self._optimizers[0].learning_rate
for i in range(self._contexts):
if self._optimizers[i].learning_rate != lr:
raise UserWarning("The optimizer on %s has a different learning rate"
" from that on %s. Cannot return learning rate")
return lr
return self._optimizer.learning_rate

def set_learning_rate(self, lr):
"""Sets a new learning rate of the optimizer.
@@ -275,14 +265,11 @@ def set_learning_rate(self, lr):
lr : float
The new learning rate of the optimizer.
"""
if not self._optimizers:
if not isinstance(self._optimizer, opt.Optimizer):
raise UserWarning("Optimizer has to be defined before its learning "
"rate is mutated.")
for optim in self._optimizers:
if not isinstance(optim, opt.Optimizer):
raise UserWarning("Optimizer has to be defined before its learning "
"rate is mutated.")
optim.set_learning_rate(lr)
else:
self._optimizer.set_learning_rate(lr)
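For completeness, a small sketch of the two accessors after the change (illustrative; both now go through the single `_optimizer`):

```python
import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=mx.cpu(0), init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1})

print(trainer.learning_rate)        # 0.1, read from the single optimizer
trainer.set_learning_rate(0.01)     # forwarded to optimizer.set_learning_rate
print(trainer.learning_rate)        # 0.01
```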

def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
"""Internal method to invoke pull operations on KVStore. If `full_idx` is set to True,
@@ -301,15 +288,14 @@ def _row_sparse_pull(self, parameter, out, row_id, full_idx=False):
self._kvstore.row_sparse_pull(idx, out=out, row_ids=row_id, priority=-idx)

def _check_and_rescale_grad(self, scale):
for optim in self._optimizers:
if self._update_on_kvstore and self._distributed and self._kv_initialized:
if optim.rescale_grad != scale:
raise UserWarning('Possible change in the `batch_size` from previous '
'`step` detected. Optimizer gradient normalizing '
'factor will not change w.r.t new batch_size when '
'update_on_kvstore=True and when distributed kvstore '
'is used.')
optim.rescale_grad = scale
if self._update_on_kvstore and self._distributed and self._kv_initialized:
if self._optimizer.rescale_grad != scale:
raise UserWarning('Possible change in the `batch_size` from previous '
'`step` detected. Optimizer gradient normalizing '
'factor will not change w.r.t new batch_size when '
'update_on_kvstore=True and when distributed kvstore '
'is used.')
self._optimizer.rescale_grad = scale
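`step(batch_size)` calls this helper with the original `rescale_grad` divided by `batch_size`, so gradients summed over a batch get averaged; with a distributed kvstore and `update_on_kvstore=True` the factor cannot change after initialization, which is what the warning above guards. A rough sketch of the local case (illustrative, not part of the diff):

```python
import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(4,))
x.initialize(ctx=mx.cpu(0), init='zeros')
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0})

batch_size = 8
with mx.autograd.record():
    # pretend a loss summed over 8 samples: grad w.r.t. x is 8 everywhere
    (8 * (x.data() + 1)).sum().backward()
trainer.step(batch_size)            # sets optimizer.rescale_grad = 1.0 / batch_size
# effective grad is 8 / 8 = 1, so x moves by -lr * 1 = -1
print(x.data().asnumpy())
```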

def step(self, batch_size, ignore_stale_grad=False):
"""Makes one step of parameter update. Should be called after
@@ -448,7 +434,7 @@ def save_states(self, fname):
`optimizer.param_dict`, which contains Parameter information (such as
`lr_mult` and `wd_mult`) will not be saved.
"""
assert self._optimizers and self._optimizers[0] is not None
assert self._optimizer is not None

if not self._kv_initialized:
self._init_kvstore()
@@ -483,14 +469,14 @@ def load_states(self, fname):
self._init_params()

if self._update_on_kvstore:
self._kvstore.load_optimizer_states(fname)
optimizer = self._kvstore._updater.optimizer
self._init_optimizer(optimizer, None)
self._optimizer = self._kvstore._updater.optimizer
param_dict = {i: param for i, param in enumerate(self._params)}
else:
with open(fname, 'rb') as f:
states = f.read()
param_dict = {i: param for i, param in enumerate(self._params)}
for updater in self._updaters:
updater.set_states(states)
updater.optimizer.param_dict = param_dict
self._optimizers = [updater.optimizer for updater in self._updaters]
updater.optimizer = self._updaters[0].optimizer
self._optimizer = self._updaters[0].optimizer
self._optimizer.param_dict = param_dict
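A minimal round trip of the two methods touched above (illustrative; the file name is an arbitrary choice, not from the diff):

```python
import mxnet as mx
from mxnet import gluon

x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
trainer = gluon.Trainer([x], 'adam', {'learning_rate': 0.1},
                        update_on_kvstore=False)
with mx.autograd.record():
    for w in x.list_data():
        (w + 1).backward()
trainer.step(1)

trainer.save_states('trainer.states')
trainer.load_states('trainer.states')
# after loading, every updater shares the one restored optimizer
assert all(u.optimizer is trainer._optimizer for u in trainer._updaters)
```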
29 changes: 19 additions & 10 deletions tests/python/unittest/test_gluon_trainer.py
@@ -55,9 +55,7 @@ def dict_equ(a, b):
y.backward()
trainer.step(1)

assert len(trainer._optimizers) == 2
assert len(trainer._updaters) == 2
assert trainer._optimizers[0].param_dict == trainer._optimizers[1].param_dict
assert all(updater.optimizer is trainer._optimizer for updater in trainer._updaters)
assert (x.data(mx.cpu(1)).asnumpy() == -2).all()

x.lr_mult = 0.5
@@ -74,18 +72,14 @@ def dict_equ(a, b):
trainer.load_states('test_trainer.states')
if trainer._update_on_kvstore:
dict_equ(trainer._kvstore._updater.states, states)
assert trainer._optimizers[0] == trainer._kvstore._updater.optimizer
assert len(trainer._optimizers) == 2
assert len(trainer._updaters) == 2
assert trainer._optimizer == trainer._kvstore._updater.optimizer
# invalid usage of update and allreduce_grads if update_on_kvstore
assert_raises(AssertionError, trainer.update, 1)
assert_raises(AssertionError, trainer.allreduce_grads)
else:
for updater in trainer._updaters:
dict_equ(updater.states, states)
assert trainer._optimizers[0] == trainer._updaters[0].optimizer
assert len(trainer._optimizers) == 2
assert len(trainer._updaters) == 2
assert trainer._optimizer == trainer._updaters[0].optimizer

x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
@@ -254,4 +248,19 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv, expected):

@with_seed()
def test_trainer_lr_scheduler():
pass
x = gluon.Parameter('x', shape=(10,))
x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
freq = 2
factor = 0.1
lr = 1
lr_sched = mx.lr_scheduler.FactorScheduler(freq, factor=factor, base_lr=lr)
trainer = gluon.Trainer([x], 'sgd', {'learning_rate': lr, 'lr_scheduler': lr_sched})
for i in range(10):
with mx.autograd.record():
for w in x.list_data():
y = w + 1
y.backward()
trainer.step(1)
if i % freq == 0:
assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
lr *= factor
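For reference, the schedule the loop above tracks is (assuming FactorScheduler's usual behavior of decaying once every `freq` updates) roughly `lr_t = base_lr * factor ** ((t - 1) // freq)` after `t` calls to `step(1)`, which is why `lr` is multiplied by `factor` right after each check.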