Lion Optimizer #1012

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions configs/neox_arguments.md
@@ -601,11 +601,11 @@ Optimizer Arguments



- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion']

Default = adam

Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion']
NOTE: sgd will use MuSGD from Mup. Mup must be enabled for this optimizer.
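With this addition, Lion is selected the same way as the other optimizers, through the optimizer block of a NeoX config. A minimal sketch of such a block in the Python-dict form used by the test configs later in this diff; the learning rate, betas, and weight decay values are purely illustrative, not recommendations:

# Hypothetical optimizer settings for a NeoX config (dict form, as in the tests below).
# "params" is expanded into the Lion constructor by get_optimizer; "weight_decay" is passed separately.
config_overrides = {
    "optimizer": {
        "type": "lion",
        "params": {"lr": 0.0006, "betas": [0.9, 0.99]},
    },
    "weight_decay": 0.1,
}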


2 changes: 1 addition & 1 deletion megatron/neox_arguments/neox_args.py
@@ -387,7 +387,7 @@ class NeoXArgsOptimizer(NeoXArgsTemplate):
"""

optimizer_type: Literal[
"adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd"
"adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd", "lion"
] = "adam"
"""
Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
107 changes: 100 additions & 7 deletions megatron/optimizers.py
@@ -227,7 +227,7 @@ def _max_reduce_except_dim(tensor, dim):
# closure is checked if callable or not since some code passes loss directly, rather than in closure param

import math
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional, Tuple

import torch
import torch.optim
@@ -271,12 +271,12 @@ class madgrad_wd(torch.optim.Optimizer):
"""

def __init__(
self,
params: _params_t,
lr: float = 1e-2,
momentum: float = 0.9,
weight_decay: float = 0,
eps: float = 1e-6,
self,
params: _params_t,
lr: float = 1e-2,
momentum: float = 0.9,
weight_decay: float = 0,
eps: float = 1e-6,
):
if momentum < 0 or momentum >= 1:
raise ValueError(f"Momentum {momentum} must be in the range [0,1]")
@@ -413,3 +413,96 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

self.state["k"] += 1
return loss


class Lion(torch.optim.Optimizer):
"""
Implementes the Lion Algorithm

.. / _Lion: https://arxiv.org/abs/2302.06675

Compared to AdamW and various adaptive optimizers that need to save both first and second moments,
Lion only needs the momentum, halving the additional memory footprint. This is beneficial when training large models
and / or with a large batch size.

Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate (default: 1e-2).
beta (float):
coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).

"""
    @staticmethod
    def exists(val):
        # Helper: True when a value (e.g. a closure or a parameter's gradient) is present.
        return val is not None

def update_fn(self, p, grad, exp_avg, lr, wd, beta1, beta2):
        # decoupled weight decay, applied directly to the weights

p.data.mul_(1 - lr * wd)

# weight update

update = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
p.add_(update, alpha=-lr)

        # update the exponential moving average of the gradient (momentum)

exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0
):
assert lr > 0.
assert all([0. <= beta <= 1. for beta in betas])

defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay
)

super().__init__(params, defaults)

@torch.no_grad()
def step(
self,
closure: Optional[Callable] = None
):

loss = None
if self.exists(closure):
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in filter(lambda p: self.exists(p.grad), group['params']):

                grad, lr, wd, beta1, beta2, state = (
                    p.grad,
                    group['lr'],
                    group['weight_decay'],
                    *group['betas'],
                    self.state[p],
                )

# init state - exponential moving average of gradient values

if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p)

exp_avg = state['exp_avg']

self.update_fn(
p,
grad,
exp_avg,
lr,
wd,
beta1,
beta2
)

return loss
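Taken together, the class performs the update described in its docstring: interpolate between the momentum and the gradient, take the sign, apply it scaled by the learning rate after a decoupled weight-decay step, then refresh the momentum. A minimal standalone sketch of exercising the class on a toy regression problem; the import path assumes this file's location, and the model, data, and hyperparameters are illustrative only:

import torch
from megatron.optimizers import Lion  # assumes the class above is importable from this module

model = torch.nn.Linear(16, 4)
opt = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.1)

x, y = torch.randn(8, 16), torch.randn(8, 4)
for _ in range(10):
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    # p *= (1 - lr * wd); p -= lr * sign(beta1 * m + (1 - beta1) * g); m = beta2 * m + (1 - beta2) * g
    opt.step()

Only exp_avg is stored per parameter, which is where the halved optimizer state relative to Adam-style optimizers comes from.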
8 changes: 8 additions & 0 deletions megatron/training.py
@@ -516,6 +516,14 @@ def get_optimizer(model, neox_args):
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "lion":
from .optimizers import Lion

optimizer = Lion(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"]
)
elif neox_args.optimizer_type.lower() == "adam":
# Use Adam
if neox_args.use_mup:
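For reference, with the config block sketched earlier, the **neox_args.optimizer["params"] expansion makes this branch equivalent to a call of the following form (illustrative values, not code from the diff):

# neox_args.optimizer["params"] == {"lr": 0.0006, "betas": [0.9, 0.99]} in the earlier sketch, so:
optimizer = Lion(param_groups, weight_decay=0.1, lr=0.0006, betas=[0.9, 0.99])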
1 change: 1 addition & 0 deletions tests/model/test_model_instantiation.py
@@ -85,6 +85,7 @@ def wrapper():
{"type": "cpu_adam", "params": {"lr": 0.0006}},
{"type": "cpu_torch_adam", "params": {"lr": 0.0006}},
{"type": "sm3", "params": {"lr": 0.0006}},
{"type": "lion", "params": {"lr": 0.0006}},
{"type": "madgrad_wd", "params": {"lr": 0.0006}},
]
}
1 change: 1 addition & 0 deletions tests/model/test_model_train.py
@@ -119,6 +119,7 @@ def wrapper():
{"type": "cpu_adam", "params": {"lr": 0.0006}},
{"type": "cpu_torch_adam", "params": {"lr": 0.0006}},
{"type": "sm3", "params": {"lr": 0.0006}},
{"type": "lion", "params": {"lr": 0.0006}},
{"type": "madgrad_wd", "params": {"lr": 0.0006}},
]
}