Lion Optimizer #1062

Merged: 3 commits, Oct 20, 2023
Changes from 1 commit
Refactor Lion implementation
andylolu2 committed Oct 20, 2023
commit f3a2dd6ec908dd9bddd63d5c1fccd2aa27a3a20d
115 changes: 52 additions & 63 deletions megatron/optimizers.py
@@ -271,12 +271,12 @@ class madgrad_wd(torch.optim.Optimizer):
    """

    def __init__(
        self,
        params: _params_t,
        lr: float = 1e-2,
        momentum: float = 0.9,
        weight_decay: float = 0,
        eps: float = 1e-6,
    ):
        if momentum < 0 or momentum >= 1:
            raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -417,7 +417,7 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

class Lion(Optimizer):
    """
-    Implementes the Lion Algorithm
+    Implements the Lion Algorithm

    .. / _Lion: https://arxiv.org/abs/2302.06675

@@ -436,73 +436,62 @@ class Lion(Optimizer):
        Weight decay, i.e. a L2 penalty (default: 0).

    """
-    def exists(val):
-        return val is not None
-
-    def update_fn(self, p, grad, exp_avg, lr, wd, beta1, beta2):
-        # stepweight decay
-
-        p.data.mul_(1 - lr * wd)
-
-        # weight update
-
-        update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
-        p.add_(update, alpha=-lr)
-
-        # decay the momentum running average coefficient
-
-        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
-
    def __init__(
-        self,
-        params,
-        lr: float = 1e-4,
-        betas: Tuple[float, float] = (0.9, 0.99),
-        weight_decay: float = 0.0
+        self,
+        params,
+        lr: float = 1e-4,
+        betas: Tuple[float, float] = (0.9, 0.99),
+        weight_decay: float = 0.0,
    ):
-        assert lr > 0.
-        assert all([0. <= beta <= 1. for beta in betas])
-
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            weight_decay=weight_decay
-        )
+        if lr <= 0:
+            raise ValueError(f"Learning rate {lr} must be positive")
+        if weight_decay < 0:
+            raise ValueError(f"Weight decay {weight_decay} must be non-negative")
+        if not (0 <= betas[0] <= 1 and 0 <= betas[1] <= 1):
+            raise ValueError(f"Betas {betas} must be in range [0, 1)")
+
+        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

-    @torch.no_grad()
-    def step(
-        self,
-        closure: Optional[Callable] = None
-    ):
+    def update(self, p, grad, exp_avg, lr, wd, beta1, beta2):
+        """https://arxiv.org/pdf/2302.06675.pdf#appendix.A"""

-        loss = None
-        if self.exists(closure):
-            with torch.enable_grad():
-                loss = closure()
+        # update model parameters
+        p.mul_(1 - lr * wd)
+        sign = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
+        p.add_(sign, alpha=-lr)

-        for group in self.param_groups:
-            for p in filter(lambda p: self.exists(p.grad), group['params']):
+        # update EMA
+        exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

-                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
-                    self.state[p]
+    @torch.no_grad()
+    def step(self, closure: Optional[Callable] = None):

-                # init state - exponential moving average of gradient values
+        loss = None
+        if closure is not None:
+            with torch.enable_grad():
+                loss = closure()

-                if len(state) == 0:
-                    state['exp_avg'] = torch.zeros_like(p)
+        for group in self.param_groups:
+            for p in group["params"]:
+                if p.grad is None:
+                    continue

-                exp_avg = state['exp_avg']
+                state = self.state[p]

-                self.update_fn(
-                    p,
-                    grad,
-                    exp_avg,
-                    lr,
-                    wd,
-                    beta1,
-                    beta2
-                )
+                # init state - exponential moving average of gradient values
+                if len(state) == 0:
+                    state["exp_avg"] = torch.zeros_like(p.data).detach()
+
+                self.update(
+                    p,
+                    p.grad,
+                    state["exp_avg"],
+                    group["lr"],
+                    group["weight_decay"],
+                    group["betas"][0],
+                    group["betas"][1],
+                )

        return loss
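
For readers skimming the diff, here is a minimal, self-contained sketch of the update rule that the refactored update method implements (Appendix A of the Lion paper, https://arxiv.org/abs/2302.06675). The standalone lion_update helper and the toy loop are illustrative only; they are not part of this PR or of megatron/optimizers.py.

# Hypothetical standalone restatement of Lion's update rule; it mirrors the
# `update` method shown in the diff above but is not code from this PR.
import torch


def lion_update(p, grad, exp_avg, lr, wd, beta1, beta2):
    """Apply one Lion step to parameter `p` in place."""
    # decoupled weight decay: p <- p * (1 - lr * wd)
    p.mul_(1 - lr * wd)
    # step in the direction of sign(beta1 * m + (1 - beta1) * g)
    update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
    p.add_(update, alpha=-lr)
    # afterwards refresh the gradient EMA: m <- beta2 * m + (1 - beta2) * g
    exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)


if __name__ == "__main__":
    torch.manual_seed(0)
    # toy problem: minimize ||w||^2, whose gradient is 2 * w
    w = torch.randn(4)
    m = torch.zeros_like(w)
    for _ in range(500):
        lion_update(w, grad=2 * w, exp_avg=m, lr=1e-2, wd=0.0, beta1=0.9, beta2=0.99)
    print(w)  # entries should end up close to zero, oscillating within a few multiples of lr

With the class itself, usage would presumably look like opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0) followed by loss.backward() and opt.step(), matching the defaults and signature introduced in this commit.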