[Feature] Improve MADGRAD optimizer #24

Merged 2 commits on Sep 22, 2021
2 changes: 1 addition & 1 deletion pytorch_optimizer/__init__.py
@@ -13,4 +13,4 @@
from pytorch_optimizer.sam import SAM
from pytorch_optimizer.sgdp import SGDP

__VERSION__ = '0.0.6'
__VERSION__ = '0.0.7'
21 changes: 10 additions & 11 deletions pytorch_optimizer/madgrad.py
@@ -13,7 +13,8 @@

class MADGRAD(Optimizer):
"""
Reference : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
Reference 1 : https://github.com/facebookresearch/madgrad/blob/main/madgrad/madgrad.py
Reference 2 : https://github.com/lessw2020/Best-Deep-Learning-Optimizers/blob/master/madgrad/madgrad_wd.py
Example :
from pytorch_optimizer import MADGRAD
...
@@ -35,11 +36,13 @@ def __init__(
weight_decay: float = 0.0,
eps: float = 1e-6,
):
"""A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic
"""A Momentumized, Adaptive, Dual Averaged Gradient Method for Stochastic (slightly modified)
:param params: PARAMS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate.
:param eps: float. term added to the denominator to improve numerical stability
:param weight_decay: float. weight decay (L2 penalty)
MADGRAD optimizer requires less weight decay than other methods, often as little as zero
On sparse problems both weight_decay and momentum should be set to 0.
"""
self.lr = lr
self.momentum = momentum
@@ -72,11 +75,6 @@ def supports_flat_params(self) -> bool:
return True

def step(self, closure: CLOSURE = None) -> LOSS:
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss: LOSS = None
if closure is not None:
loss = closure()
@@ -124,7 +122,11 @@ def step(self, closure: CLOSURE = None) -> LOSS:
'weight_decay option is not compatible with sparse gradients'
)

grad.add_(p.data, alpha=decay)
# original implementation
# grad.add_(p.data, alpha=decay)

# Apply weight decay - L2 / AdamW style
p.data.mul_(1 - lr * decay)

if grad.is_sparse:
grad = grad.coalesce()
@@ -174,16 +176,13 @@ def step(self, closure: CLOSURE = None) -> LOSS:
grad_sum_sq.addcmul_(grad, grad, value=_lambda)
rms = grad_sum_sq.pow(1 / 3).add_(eps)

# Update s
s.data.add_(grad, alpha=_lambda)

# Step
if momentum == 0:
p.data.copy_(x0.addcdiv(s, rms, value=-1))
else:
z = x0.addcdiv(s, rms, value=-1)

# p is a moving average of z
p.data.mul_(1 - ck).add_(z, alpha=ck)

self.state['k'] += 1
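The substantive change in this PR is how weight decay is applied. Previously the decay term was folded into the gradient (L2 style), so it also flowed through MADGRAD's adaptive denominator; now the parameters are shrunk directly (decoupled, AdamW style) and the gradient is left untouched. A minimal sketch of the two styles, using made-up tensors rather than the optimizer's own state:

```python
import torch

# Hypothetical parameter, gradient, and hyper-parameters, only to contrast the two styles.
lr, decay = 1e-2, 1e-4
p = torch.randn(3)
grad = torch.randn(3)

# L2-style decay (previous behaviour, now kept only as a comment in the diff):
# the decay term becomes part of the gradient and therefore also feeds
# the adaptive denominator.
grad_l2 = grad.add(p, alpha=decay)

# Decoupled (AdamW-style) decay, as introduced by this PR:
# the parameters are scaled down directly and the gradient is unchanged.
p.mul_(1 - lr * decay)
```

The decoupled form appears to follow the madgrad_wd.py reference added to the class docstring.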
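For context, here is a hypothetical single-tensor walk-through of the dense, momentum > 0 path of step() after this change. It assumes `_lambda = lr * sqrt(k + 1)` and `ck = 1 - momentum`, as in the referenced implementations; this is a sketch of the update, not the library's code:

```python
import math

import torch

# Made-up state mirroring what the optimizer keeps per parameter (names are illustrative).
lr, momentum, decay, eps = 1e-2, 0.9, 1e-4, 1e-6
k = 0                                   # global step counter
p = torch.randn(3)                      # parameter
grad = torch.randn(3)                   # its (dense) gradient
x0 = p.clone()                          # initial parameter value
s = torch.zeros_like(p)                 # dual-averaged gradient sum
grad_sum_sq = torch.zeros_like(p)       # weighted sum of squared gradients

_lambda = lr * math.sqrt(k + 1)         # assumed step-size schedule
ck = 1 - momentum

p.mul_(1 - lr * decay)                  # decoupled weight decay (this PR)

grad_sum_sq.addcmul_(grad, grad, value=_lambda)
rms = grad_sum_sq.pow(1 / 3).add_(eps)  # cube-root adaptive denominator
s.add_(grad, alpha=_lambda)             # update the dual average s

z = x0.addcdiv(s, rms, value=-1)        # z = x0 - s / rms
p.mul_(1 - ck).add_(z, alpha=ck)        # p is a moving average of z
```

With momentum == 0 the moving-average step collapses to copying z into p, matching the if/else branch in the diff.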