Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Implement LARS optimizer #50

Merged
merged 13 commits into from
Feb 1, 2022
Prev Previous commit
Next Next commit
update: LARS
  • Loading branch information
kozistr committed Feb 1, 2022
commit 6c23f3d3b55fd56c6ec0775bf9836c4400e87748
17 changes: 12 additions & 5 deletions pytorch_optimizer/lars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
from torch.optim import Optimizer

from pytorch_optimizer.types import DEFAULTS, PARAMETERS
from pytorch_optimizer.types import CLOSURE, DEFAULTS, LOSS, PARAMETERS


class LARS(Optimizer):
Expand Down Expand Up @@ -61,19 +61,24 @@ def check_valid_parameters(self):
raise ValueError(f'Invalid trust_coefficient : {self.trust_coefficient}')

@torch.no_grad()
def step(self):
def step(self, closure: CLOSURE = None) -> LOSS:
loss: LOSS = None
if closure is not None:
loss = closure()

for g in self.param_groups:
for p in g['params']:
dp = p.grad

if dp is None:
if p.grad is None:
continue

dp = p.grad

if p.ndim > 1: # if not normalization gamma/beta or bias
dp = dp.add(p, alpha=g['weight_decay'])
param_norm = torch.norm(p)
update_norm = torch.norm(dp)
one = torch.ones_like(param_norm)

q = torch.where(
param_norm > 0.0,
torch.where(update_norm > 0, (g['trust_coefficient'] * param_norm / update_norm), one),
Expand All @@ -89,3 +94,5 @@ def step(self):
mu.mul_(g['momentum']).add_(dp)

p.add_(mu, alpha=-g['lr'])

return loss