Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Implement LARS optimizer #50

Merged
merged 13 commits on Feb 1, 2022
Previous commit
update: LARS
  • Loading branch information
kozistr committed Feb 1, 2022
commit 9fadf571acfe757a708bdd65c40299516fe69ca1
7 changes: 6 additions & 1 deletion pytorch_optimizer/lars.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,18 +27,21 @@ def __init__(
weight_decay: float = 0.0,
momentum: float = 0.9,
trust_coefficient: float = 0.001,
eps: float = 1e-6,
):
"""LARS optimizer, no rate scaling or weight decay for parameters <= 1D
:param params: PARAMETERS. iterable of parameters to optimize or dicts defining parameter groups
:param lr: float. learning rate
:param weight_decay: float. weight decay (L2 penalty)
:param momentum: float. momentum
:param trust_coefficient: float. trust_coefficient
:param eps: float. epsilon
"""
self.lr = lr
self.weight_decay = weight_decay
self.momentum = momentum
self.trust_coefficient = trust_coefficient
self.eps = eps

self.check_valid_parameters()

Expand All @@ -59,6 +62,8 @@ def check_valid_parameters(self):
raise ValueError(f'Invalid momentum : {self.momentum}')
if self.trust_coefficient < 0.0:
raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')
if self.eps < 0.0:
raise ValueError(f'Invalid eps : {self.eps}')

@torch.no_grad()
def step(self, closure: CLOSURE = None) -> LOSS:
Expand All @@ -84,7 +89,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:

q = torch.where(
param_norm > 0.0,
torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one),
torch.where(update_norm > 0.0, (g['trust_coefficient'] * param_norm / update_norm), one),
one,
)
dp = dp.mul(q)
Expand Down