Commit

Merge pull request #887 from fdlm/master
Implemented amsgrad updates
f0k committed Feb 21, 2018
2 parents ffc8b8a + 5d968ba commit 8978b1d
Showing 3 changed files with 74 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/modules/updates.rst
@@ -15,6 +15,7 @@ Update functions
.. autofunction:: adadelta
.. autofunction:: adam
.. autofunction:: adamax
.. autofunction:: amsgrad


Update modification functions
8 changes: 8 additions & 0 deletions lasagne/tests/test_updates.py
@@ -35,6 +35,9 @@ class TestUpdateFunctions(object):
                    'adamax': [0.90211749000754,
                               0.90211748762402,
                               0.90211748682951],
                    'amsgrad': [0.90034979581833,
                                0.90034979581833,
                                0.90034979581833],
                    }

    def f(self, X):
@@ -49,6 +52,7 @@ def f(self, X):
        ['adadelta', {}],
        ['adam', {'learning_rate': 0.01}],
        ['adamax', {'learning_rate': 0.01}],
        ['amsgrad', {'learning_rate': 0.01}],
        ])
    def test_updates(self, method, kwargs):
        A = theano.shared(lasagne.utils.floatX([1, 1, 1]))
@@ -87,6 +91,10 @@ def test_updates(self, method, kwargs):
                    'beta1': 0.9,
                    'beta2': 0.999,
                    'epsilon': 1e-8}],
        ['amsgrad', {'learning_rate': 0.01,
                     'beta1': 0.9,
                     'beta2': 0.999,
                     'epsilon': 1e-8}],
        ])
    def test_update_returntype(self, method, kwargs):
        '''Checks whether lasagne.updates handles float32 inputs correctly'''
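For reference, the update rule implemented by the new amsgrad function in the lasagne/updates.py diff below is, in the notation of the referenced paper (g_t is the gradient of parameter theta at step t, eta the learning rate; a_t folds in the Adam-style bias correction used by the code):

    a_t = \eta \sqrt{1 - \beta_2^t} / (1 - \beta_1^t)
    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
    \hat{v}_t = \max(\hat{v}_{t-1}, v_t)
    \theta_t = \theta_{t-1} - a_t \, m_t / (\sqrt{\hat{v}_t} + \epsilon)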
65 changes: 65 additions & 0 deletions lasagne/updates.py
@@ -18,6 +18,7 @@
    adadelta
    adam
    adamax
    amsgrad
Two functions can be used to further modify the updates to include momentum:
@@ -99,6 +100,7 @@
"adadelta",
"adam",
"adamax",
"amsgrad",
"norm_constraint",
"total_norm_constraint"
]
@@ -691,6 +693,69 @@ def adamax(loss_or_grads, params, learning_rate=0.002, beta1=0.9,
    return updates


def amsgrad(loss_or_grads, params, learning_rate=0.001, beta1=0.9,
            beta2=0.999, epsilon=1e-8):
    """AMSGrad updates

    AMSGrad updates implemented as in [1]_.

    Parameters
    ----------
    loss_or_grads : symbolic expression or list of expressions
        A scalar loss expression, or a list of gradient expressions
    params : list of shared variables
        The variables to generate update expressions for
    learning_rate : float or symbolic scalar
        Learning rate
    beta1 : float or symbolic scalar
        Exponential decay rate for the first moment estimates.
    beta2 : float or symbolic scalar
        Exponential decay rate for the second moment estimates.
    epsilon : float or symbolic scalar
        Constant for numerical stability.

    Returns
    -------
    OrderedDict
        A dictionary mapping each parameter to its update expression

    References
    ----------
    .. [1] https://openreview.net/forum?id=ryQu7f-RZ
    """
    all_grads = get_or_compute_grads(loss_or_grads, params)
    t_prev = theano.shared(utils.floatX(0.))
    updates = OrderedDict()

    # Using theano constant to prevent upcasting of float32
    one = T.constant(1)

    t = t_prev + 1
    a_t = learning_rate*T.sqrt(one-beta2**t)/(one-beta1**t)

    for param, g_t in zip(params, all_grads):
        value = param.get_value(borrow=True)
        m_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                               broadcastable=param.broadcastable)
        v_hat_prev = theano.shared(np.zeros(value.shape, dtype=value.dtype),
                                   broadcastable=param.broadcastable)

        m_t = beta1*m_prev + (one-beta1)*g_t
        v_t = beta2*v_prev + (one-beta2)*g_t**2
        v_hat_t = T.maximum(v_hat_prev, v_t)
        step = a_t*m_t/(T.sqrt(v_hat_t) + epsilon)

        updates[m_prev] = m_t
        updates[v_prev] = v_t
        updates[v_hat_prev] = v_hat_t
        updates[param] = param - step

    updates[t_prev] = t
    return updates


def norm_constraint(tensor_var, max_norm, norm_axes=None, epsilon=1e-7):
"""Max weight norm constraints and gradient clipping
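As a usage note for the new function: below is a minimal sketch of how lasagne.updates.amsgrad could be dropped into a typical Lasagne/Theano training setup, in the same way as the existing adam or adamax updates. The network architecture, variable names (x, y, l_in, l_out) and shapes are illustrative only, not part of this commit.

    import theano
    import theano.tensor as T
    import lasagne

    # Symbolic inputs (illustrative names and shapes)
    x = T.matrix('inputs')
    y = T.ivector('targets')

    # A small illustrative network
    l_in = lasagne.layers.InputLayer(shape=(None, 100), input_var=x)
    l_out = lasagne.layers.DenseLayer(l_in, num_units=10,
                                      nonlinearity=lasagne.nonlinearities.softmax)

    # Loss and trainable parameters
    prediction = lasagne.layers.get_output(l_out)
    loss = lasagne.objectives.categorical_crossentropy(prediction, y).mean()
    params = lasagne.layers.get_all_params(l_out, trainable=True)

    # The new update function, used like adam/adamax
    updates = lasagne.updates.amsgrad(loss, params, learning_rate=0.001)

    # Compile a training function that applies one AMSGrad step per call
    train_fn = theano.function([x, y], loss, updates=updates)

As with adam and adamax, beta1, beta2 and epsilon can be passed explicitly if the defaults from the signature above are not wanted.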
