This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

AdamW operator (Fixing Weight Decay Regularization in Adam) #13728

Merged · 3 commits · Dec 28, 2018
Changes from 1 commit
68 changes: 66 additions & 2 deletions python/mxnet/optimizer/optimizer.py
@@ -27,14 +27,14 @@
from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply)
from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update,
mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update,
signsgd_update, signum_update)
signsgd_update, signum_update, adamw_update)
from ..ndarray import sparse
from ..random import normal

__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LBSGD',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register', 'AdamW'
]


@@ -1018,6 +1018,70 @@ class ccSGD(SGD):
def __init__(self, *args, **kwargs):
super(ccSGD, self).__init__(*args, **kwargs)

@register
class AdamW(Optimizer):
"""The Adam optimizer with fixed weight decay regularization.

This class implements the optimizer described in *Fixing Weight Decay
Regularization in Adam*, available at https://arxiv.org/abs/1711.05101.

Note that this is different from the original Adam optimizer, which adds L2
regularization on the weights to the loss: decoupled weight decay regularizes
weights with large gradients more than L2 regularization would, which was shown
in the paper above to yield better training loss and generalization error.

Updates are applied by::

rescaled_grad = clip(grad * rescale_grad, clip_gradient)
m = beta1 * m + (1 - beta1) * rescaled_grad
v = beta2 * v + (1 - beta2) * (rescaled_grad**2)
w = w - learning_rate * (m / (sqrt(v) + epsilon) + wd * w)
Contributor:

According to the paper, there are two learning rates: an alpha in front of m / (sqrt(v) + epsilon), in addition to the schedule multiplier.

Member Author — @eric-haibin-lin, Dec 26, 2018:

Good point. The issue is that the learning rate and the schedule multiplier are not decoupled in MXNet. Here `learning_rate` is effectively eta_t * alpha from the paper, and `wd` actually needs to be set to the paper's decay factor w divided by alpha. In other words, `wd` can be rescaled so that the update does exactly the same thing as in the paper. Would this be acceptable? If so, maybe I can move this to contrib for the moment.
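
A small numpy sketch of that rescaling (hypothetical values; `lambda_` stands for the paper's decay factor and `eta_t` for its schedule multiplier):

    import numpy as np

    # Paper:  w <- w - eta_t * (alpha * m_hat / (sqrt(v_hat) + eps) + lambda_ * w)
    # MXNet:  w <- w - lr    * (m_hat / (sqrt(v_hat) + eps) + wd * w)
    # Setting lr = eta_t * alpha and wd = lambda_ / alpha makes the two identical.
    alpha, eta_t, lambda_, eps = 1e-3, 0.5, 1e-2, 1e-8
    w, m_hat, v_hat = 2.0, 0.1, 0.04

    paper_w = w - eta_t * (alpha * m_hat / (np.sqrt(v_hat) + eps) + lambda_ * w)
    mxnet_w = w - (eta_t * alpha) * (m_hat / (np.sqrt(v_hat) + eps) + (lambda_ / alpha) * w)
    assert np.isclose(paper_w, mxnet_w)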

Member:

I think it's acceptable as long as the wd is set correctly.

Member Author:

On second thought, I think it's better to keep it consistent with the paper.


This optimizer accepts the following parameters in addition to those accepted
by :class:`.Optimizer`.

For details of the update algorithm, see :class:`~mxnet.ndarray.adamw_update`.

Parameters
----------
beta1 : float, optional
Exponential decay rate for the first moment estimates.
beta2 : float, optional
Exponential decay rate for the second moment estimates.
epsilon : float, optional
Small value to avoid division by 0.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
**kwargs):
super(AdamW, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon

def create_state(self, index, weight):
return (zeros(weight.shape, weight.context, dtype=weight.dtype), #mean
zeros(weight.shape, weight.context, dtype=weight.dtype)) #variance

def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)

t = self._index_update_count[index]
coef1 = 1. - self.beta1**t
coef2 = 1. - self.beta2**t
lr *= math.sqrt(coef2)/coef1

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'rescale_grad': self.rescale_grad}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
adamw_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
Member:

Should we set wd to something like wd / self._original_lr?
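
For reference, a minimal usage sketch of the new optimizer through Gluon, assuming it is looked up under the lower-cased registered name 'adamw' (illustrative values only):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize()
    # Per the thread above, `wd` here plays the role of the paper's decay factor divided by alpha.
    trainer = gluon.Trainer(net.collect_params(), 'adamw',
                            {'learning_rate': 1e-3, 'wd': 1e-2})
    # trainer.step(batch_size) then applies one AdamW update per parameter.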


@register
class Adam(Optimizer):
"""The Adam optimizer.
30 changes: 22 additions & 8 deletions src/operator/optimizer_op-inl.h
@@ -837,7 +837,10 @@ struct AdamParam : public dmlc::Parameter<AdamParam> {
}
};

template<typename xpu>
/*
 * \brief Adam and AdamW update. Set decoupled=true for the adamw_update variant.
*/
template<typename xpu, bool decoupled>
inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
@@ -855,9 +858,12 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

grad = scalar<DType>(param.rescale_grad) * grad +
scalar<DType>(param.wd) * weight;

if (decoupled) {
grad = scalar<DType>(param.rescale_grad) * grad;
} else {
grad = scalar<DType>(param.rescale_grad) * grad +
scalar<DType>(param.wd) * weight;
}
if (param.clip_gradient >= 0.0f) {
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) *
F<clip>(grad, DType(param.clip_gradient));
@@ -867,10 +873,18 @@ inline void AdamUpdate(const nnvm::NodeAttrs& attrs,
mean = scalar<DType>(param.beta1)*mean + scalar<DType>(1.f-param.beta1) * grad;
var = scalar<DType>(param.beta2)*var + scalar<DType>(1.f-param.beta2) * F<square>(grad);
}
Assign(out, req[0],
weight -
scalar<DType>(param.lr) * mean /
(F<square_root>(var) + scalar<DType>(param.epsilon)));
if (decoupled) {
Assign(out, req[0],
weight -
scalar<DType>(param.lr) * (mean /
(F<square_root>(var) + scalar<DType>(param.epsilon)) +
(scalar<DType>(param.wd) * weight)));
} else {
Assign(out, req[0],
weight -
scalar<DType>(param.lr) * mean /
(F<square_root>(var) + scalar<DType>(param.epsilon)));
}
});
}

45 changes: 41 additions & 4 deletions src/operator/optimizer_op.cc
@@ -472,15 +472,16 @@ are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
g_t = \nabla J(W_{t-1}) + wd W_{t-1}\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \alpha \frac{ m_t }{ \sqrt{ v_t } + \epsilon }

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
g = grad + wd*w
m = beta1*m + (1-beta1)*g
v = beta2*v + (1-beta2)*(g**2)
w += - learning_rate * m / (sqrt(v) + epsilon)

However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and the storage
@@ -507,14 +508,50 @@ only the row slices whose indices appear in grad.indices are updated (for w, m a
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu>)
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu, false>)
.set_attr<FComputeEx>("FComputeEx<cpu>", AdamUpdateEx<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(AdamParam::__FIELDS__());
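
For contrast with the decoupled variant registered next, a minimal numpy sketch of one update step under each wd treatment, mirroring the two pseudocode blocks (illustrative values only):

    import numpy as np

    lr, beta1, beta2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 1e-2
    w, grad, m, v = 2.0, 0.3, 0.0, 0.0

    # adam_update: wd folded into the gradient (L2 regularization)
    g = grad + wd * w
    m_l2 = beta1 * m + (1 - beta1) * g
    v_l2 = beta2 * v + (1 - beta2) * g ** 2
    w_l2 = w - lr * m_l2 / (np.sqrt(v_l2) + eps)

    # adamw_update: wd applied directly to the weight (decoupled)
    m_dw = beta1 * m + (1 - beta1) * grad
    v_dw = beta2 * v + (1 - beta2) * grad ** 2
    w_dw = w - lr * (m_dw / (np.sqrt(v_dw) + eps) + wd * w)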

NNVM_REGISTER_OP(adamw_update)
.describe(R"code(Update function for AdamW optimizer. AdamW is a modification of
Adam that decouples the weight decay from the optimization steps taken w.r.t. the loss function.

The AdamW update consists of the following steps, where g represents the gradient and m, v
are 1st and 2nd order moment estimates (mean and variance).

.. math::

g_t = \nabla J(W_{t-1})\\
m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t\\
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2\\
W_t = W_{t-1} - \alpha (\frac{ m_t }{ \sqrt{ v_t } + \epsilon } + wd W_{t-1})

It updates the weights using::

m = beta1*m + (1-beta1)*grad
v = beta2*v + (1-beta2)*(grad**2)
w += - learning_rate * (m / (sqrt(v) + epsilon) + w*wd)

)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<AdamParam>)
.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.set_attr<FCompute>("FCompute<cpu>", AdamUpdate<cpu, true>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(AdamParam::__FIELDS__());
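
A minimal imperative sketch of invoking this op from Python, assuming it is exposed as mx.nd.adamw_update with the same AdamParam keyword arguments as adam_update (this is how the AdamW optimizer class above calls it):

    import mxnet as mx

    shape = (3, 4)
    weight = mx.nd.random.uniform(shape=shape)
    grad = mx.nd.random.uniform(shape=shape)
    mean = mx.nd.zeros(shape)  # 1st moment state (m)
    var = mx.nd.zeros(shape)   # 2nd moment state (v)

    # One decoupled-weight-decay step, written in place into `weight`.
    mx.nd.adamw_update(weight, grad, mean, var, out=weight,
                       lr=1e-3, beta1=0.9, beta2=0.999, epsilon=1e-8,
                       wd=1e-2, rescale_grad=1.0)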

NNVM_REGISTER_OP(rmsprop_update)
.describe(R"code(Update function for `RMSProp` optimizer.
5 changes: 4 additions & 1 deletion src/operator/optimizer_op.cu
@@ -246,9 +246,12 @@ NNVM_REGISTER_OP(ftml_update)
.set_attr<FCompute>("FCompute<gpu>", FTMLUpdate<gpu>);

NNVM_REGISTER_OP(adam_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu>)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu, false>)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdamUpdateEx<gpu>);

NNVM_REGISTER_OP(adamw_update)
.set_attr<FCompute>("FCompute<gpu>", AdamUpdate<gpu, true>);

NNVM_REGISTER_OP(rmsprop_update)
.set_attr<FCompute>("FCompute<gpu>", RMSPropUpdate<gpu>);

89 changes: 87 additions & 2 deletions tests/python/unittest/test_optimizer.py
@@ -506,12 +506,11 @@ def test_ftml():
class PyAdam(mx.optimizer.Optimizer):
"""python reference implementation of adam"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
decay_factor=(1 - 1e-8), lazy_update=True, **kwargs):
lazy_update=True, **kwargs):
super(PyAdam, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.decay_factor = decay_factor
self.lazy_update = lazy_update

def create_state(self, index, weight):
@@ -614,6 +613,92 @@ def test_adam():
dtype, w_stype='default', g_stype='row_sparse',
rtol=1e-4, atol=2e-5)

# ADAMW
class PyAdamW(mx.optimizer.Optimizer):
"""python reference implementation of AdamW"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
**kwargs):
super(PyAdamW, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon

def create_state(self, index, weight):
"""Create additional optimizer state: mean, variance

Parameters
----------
weight : NDArray
The weight data

"""
return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype), # mean
mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype)) # variance

def update(self, index, weight, grad, state):
"""Update the parameters.

Parameters
----------
index : int
A unique integer key used to index the parameters

weight : NDArray
weight ndarray

grad : NDArray
grad ndarray

state : NDArray or other objects returned by init_state
The auxiliary state used in optimization.
"""
lr = self._get_lr(index)
self._update_count(index)

t = self._index_update_count[index]
mean, variance = state

wd = self._get_wd(index)
coef1 = 1. - self.beta1**t
coef2 = 1. - self.beta2**t
lr *= math.sqrt(coef2)/coef1

grad *= self.rescale_grad
# clip gradients
if self.clip_gradient is not None:
mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient, out=grad)
# update mean
mean *= self.beta1
mean += grad * (1. - self.beta1)
# update variance
variance *= self.beta2
variance += (1 - self.beta2) * mx.nd.square(grad, out=grad)
# update weight
weight -= lr * (mean/(mx.nd.sqrt(variance) + self.epsilon) + wd * weight)

@with_seed()
def test_adamw():
opt1 = PyAdamW
opt2 = mx.optimizer.AdamW
shape = (3, 4, 5)
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
mp_options = [{}, {'multi_precision': False}, {'multi_precision': True}]
for dtype in [np.float16, np.float32, np.float64]:
for cg_option in cg_options:
for rg_option in rg_options:
for wd_option in wd_options:
for mp_option in mp_options:
kwarg = {}
kwarg.update(cg_option)
kwarg.update(rg_option)
kwarg.update(wd_option)
kwarg.update(mp_option)
if (dtype == np.float16 and
('multi_precision' not in kwarg or not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype)

# AdaMax
class PyAdamax(mx.optimizer.Optimizer):