Skip to content

Commit

Permalink
Lion Optimizer (#1062)
Browse files Browse the repository at this point in the history
* initial commit

* test set, fixed readme and docstring

* Refactor Lion implementation

---------

Co-authored-by: kamathis4 <[email protected]>
  • Loading branch information
andylolu2 and adi-kmt committed Oct 20, 2023
1 parent e001a04 commit b02d989
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 4 deletions.
4 changes: 2 additions & 2 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -601,11 +601,11 @@ Optimizer Arguments



- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
- **optimizer_type**: typing.Literal['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion']

Default = adam

Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd', 'lion']
NOTE: sgd will use MuSGD from Mup. Mup must be enabled for this optimizer.


Expand Down
2 changes: 1 addition & 1 deletion megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ class NeoXArgsOptimizer(NeoXArgsTemplate):
"""

optimizer_type: Literal[
"adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd"
"adam", "onebitadam", "cpu_adam", "cpu_torch_adam", "sm3", "madgrad_wd", "sgd", "lion"
] = "adam"
"""
Type of optimizer to use. Choose from ['adam', 'onebitadam', 'cpu_adam', 'cpu_torch_adam', 'sm3', 'madgrad_wd', 'sgd']
Expand Down
84 changes: 83 additions & 1 deletion megatron/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def _max_reduce_except_dim(tensor, dim):
# closure is checked if callable or not since some code passes loss directly, rather than in closure param

import math
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional
from typing import Collection, TYPE_CHECKING, Any, Callable, Optional, Tuple

import torch
import torch.optim
Expand Down Expand Up @@ -413,3 +413,85 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]

self.state["k"] += 1
return loss


class Lion(Optimizer):
"""
Implements the Lion Algorithm
.. / _Lion: https://arxiv.org/abs/2302.06675
Compared to AdamW and various adaptive optimizers that need to save both first and second moments,
Lion only needs the momentum, halving the additional memory footprint. This is beneficial when training large models
and / or with a large batch size.
Arguments:
params (iterable):
Iterable of parameters to optimize or dicts defining parameter groups.
lr (float):
Learning rate (default: 1e-2).
beta (float):
coefficients used for computing running averages of gradient and its square (default: (0.9, 0.99))
weight_decay (float):
Weight decay, i.e. a L2 penalty (default: 0).
"""

def __init__(
self,
params,
lr: float = 1e-4,
betas: Tuple[float, float] = (0.9, 0.99),
weight_decay: float = 0.0,
):
if lr <= 0:
raise ValueError(f"Learning rate {lr} must be positive")
if weight_decay < 0:
raise ValueError(f"Weight decay {weight_decay} must be non-negative")
if not (0 <= betas[0] <= 1 and 0 <= betas[1] <= 1):
raise ValueError(f"Betas {betas} must be in range [0, 1)")

defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
super().__init__(params, defaults)

def update(self, p, grad, exp_avg, lr, wd, beta1, beta2):
"""https://arxiv.org/pdf/2302.06675.pdf#appendix.A"""

# update model parameters
p.mul_(1 - lr * wd)
sign = exp_avg.clone().mul_(beta1).add(grad, alpha=1 - beta1).sign_()
p.add_(sign, alpha=-lr)

# update EMA
exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)

@torch.no_grad()
def step(self, closure: Optional[Callable] = None):

loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

state = self.state[p]

# init state - exponential moving average of gradient values
if len(state) == 0:
state["exp_avg"] = torch.zeros_like(p.data).detach()

self.update(
p,
p.grad,
state["exp_avg"],
group["lr"],
group["weight_decay"],
group["betas"][0],
group["betas"][1],
)

return loss
8 changes: 8 additions & 0 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,14 @@ def get_optimizer(model, neox_args):
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"],
)
elif neox_args.optimizer_type.lower() == "lion":
from .optimizers import Lion

optimizer = Lion(
param_groups,
weight_decay=neox_args.weight_decay,
**neox_args.optimizer["params"]
)
elif neox_args.optimizer_type.lower() == "adam":
# Use Adam
if neox_args.use_mup:
Expand Down
1 change: 1 addition & 0 deletions tests/model/test_model_instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def wrapper():
{"type": "cpu_adam", "params": {"lr": 0.0006}},
{"type": "cpu_torch_adam", "params": {"lr": 0.0006}},
{"type": "sm3", "params": {"lr": 0.0006}},
{"type": "lion", "params": {"lr": 0.0006}},
{"type": "madgrad_wd", "params": {"lr": 0.0006}},
]
}
Expand Down
1 change: 1 addition & 0 deletions tests/model/test_model_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def wrapper():
{"type": "cpu_adam", "params": {"lr": 0.0006}},
{"type": "cpu_torch_adam", "params": {"lr": 0.0006}},
{"type": "sm3", "params": {"lr": 0.0006}},
{"type": "lion", "params": {"lr": 0.0006}},
{"type": "madgrad_wd", "params": {"lr": 0.0006}},
]
}
Expand Down

0 comments on commit b02d989

Please sign in to comment.