Added infinite lr schedules #1194

Open
wants to merge 11 commits into base: main
81 changes: 77 additions & 4 deletions configs/neox_arguments.md
@@ -7,7 +7,7 @@ LR Scheduler Arguments



- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential']
- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential', 'infinite_cosine', 'infinite_inv_sqrt']

Default = linear

@@ -31,6 +31,14 @@ LR Scheduler Arguments



- **constant_lr**: float

Default = 0.0

Constant learning rate when using infinite cosine or infinite inv sqrt decay styles.



- **warmup**: float

Default = 0.01
@@ -39,6 +47,30 @@ LR Scheduler Arguments



- **cooldown_iters_perc**: float

Default = 0.0

Percentage of total iterations to spend cooling down from the peak learning rate to constant_lr (.1 = 10 percent of all training iters).



- **constant_iters_perc**: float

Default = 0.0

Percentage of total iterations to hold the learning rate constant at constant_lr (.1 = 10 percent of all training iters).



- **timescale**: float

Default = 1.0

Timescale for the steepness of the inverse square root cooldown.
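
Taken together, the new entries might be set roughly like this. The snippet below is only an illustration of how the knobs relate (values assumed, not taken from the PR), written as a Python dict whose keys mirror the YAML config fields:

```python
# Hypothetical scheduler settings; in a NeoX YAML config these would be top-level keys.
infinite_cosine_schedule = {
    "lr_decay_style": "infinite_cosine",
    "lr": 3.0e-4,                 # peak LR reached after warmup
    "constant_lr": 1.0e-4,        # LR held after the cooldown phase
    "min_lr": 1.0e-5,             # floor of the final exponential decay
    "warmup": 0.01,               # 1% of train_iters spent warming up
    "cooldown_iters_perc": 0.10,  # 10% of train_iters annealing from lr to constant_lr
    "constant_iters_perc": 0.80,  # 80% of train_iters at constant_lr
    "timescale": 1.0,             # only used by infinite_inv_sqrt
}
```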



- **override_lr_scheduler**: bool

Default = False
@@ -111,7 +143,7 @@ Logging Arguments

- **git_hash**: str

Default = 6fb840e
Default = 924d2e5

current git hash of repository

@@ -1201,7 +1233,7 @@ Text Generation arguments



- **num_experts**: int
- **moe_num_experts**: int

Default = 1

@@ -1243,7 +1275,7 @@ Text Generation arguments

- **moe_token_dropping**: bool

Default = True
Default = False

Whether to drop tokens when exceeding capacity

@@ -1273,6 +1305,47 @@ Text Generation arguments



- **moe_type**: str

Default = megablocks

Either `deepspeed` or `megablocks`



- **moe_glu**: bool

Default = False

Use gated linear units in MoE



- **moe_lbl_in_fp32**: bool

Default = False

Whether to compute the load balancing loss in fp32.



- **moe_jitter_eps**: float

Default = None

Coefficient for MoE routing jitter. Jitter is not used if set to None.



- **enable_expert_tensor_parallelism**: bool

Default = False

Enable expert tensor parallelism
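
These MoE entries come from regenerating the argument docs rather than from the scheduler change itself. Purely as an illustration of how a few of them might be set together (values are assumptions, not from the PR):

```python
# Hypothetical MoE settings mirroring the regenerated docs above; values are illustrative.
moe_settings = {
    "moe_num_experts": 8,         # renamed from num_experts
    "moe_type": "megablocks",     # or "deepspeed"
    "moe_token_dropping": False,  # new default shown above
    "moe_glu": True,              # gated linear units in the experts
    "moe_lbl_in_fp32": True,      # compute the load-balancing loss in fp32
    "moe_jitter_eps": 0.01,       # None disables routing jitter
}
```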



## NeoXArgsTokenizer

Tokenizer Arguments
36 changes: 36 additions & 0 deletions megatron/learning_rates.py
@@ -34,6 +34,10 @@ def __init__(
decay_style,
last_iter,
min_lr=0.0,
constant_lr=0.0,
constant_iters=None,
cooldown_iters=None,
timescale=None,
use_checkpoint_lr_scheduler=True,
override_lr_scheduler=False,
use_mup=False,
@@ -43,9 +47,13 @@ def __init__(
self.optimizer = optimizer
self.start_lr = start_lr
self.min_lr = min_lr
self.constant_lr = constant_lr
self.warmup_iter = warmup_iter
self.num_iters = last_iter
self.end_iter = total_iters
self.constant_iters = constant_iters
self.cooldown_iters = cooldown_iters
self.timescale = timescale
assert self.end_iter > 0
self.decay_style = decay_style
self.override_lr_scheduler = override_lr_scheduler
@@ -84,6 +92,34 @@ def get_lr(self):
    # exp(-0.693) = 1/2
    end_iter = self.end_iter - self.warmup_iter
    lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter)
elif self.decay_style == "infinite_cosine" or self.decay_style == "infinite_inv_sqrt":
    if num_iters_ <= self.cooldown_iters:
        # Cooldown phase: anneal from start_lr down to constant_lr.
        if self.decay_style == "infinite_cosine":
            lr = self.constant_lr + (
                (self.start_lr - self.constant_lr)
                / 2.0
                * (math.cos(math.pi * num_iters_ / self.cooldown_iters) + 1)
            )
        else:
            # Inverse-sqrt cooldown, normalized so lr == constant_lr at the end of cooldown.
            def inv_f(t):
                return (1 / math.sqrt(1 + (self.timescale * t))) - 1

            lr = self.start_lr + (
                (self.constant_lr - self.start_lr)
                / inv_f(1)
                * inv_f(num_iters_ / self.cooldown_iters)
            )
        # Returned directly: during cooldown the floor is constant_lr, not min_lr.
        return lr
    else:
        num_iters_ = num_iters_ - self.cooldown_iters
        if num_iters_ <= self.constant_iters:
            # Stay constant for constant_iters
            lr = self.constant_lr
        else:
            # Decay exponentially from constant_lr to min_lr over the remaining iters
            # (assumes constant_lr > 0 and min_lr > 0).
            end_iter_ = self.end_iter - self.warmup_iter - self.cooldown_iters - self.constant_iters
            num_iters_ = num_iters_ - self.constant_iters
            exp_factor = -math.log(self.min_lr / self.constant_lr) / end_iter_
            lr = self.constant_lr * math.exp(-exp_factor * num_iters_)
else:
    lr = self.start_lr
return max(lr, self.min_lr)
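
To see the shape of the schedule these branches produce end to end, here is a minimal standalone sketch (not part of the PR): the function name, keyword arguments, and the example hyperparameters are assumptions for illustration. It mirrors the existing linear warmup plus the new cooldown, constant, and exponential phases above.

```python
import math


def infinite_lr(
    step,
    *,
    total_iters,
    warmup_frac,
    cooldown_frac,
    constant_frac,
    start_lr,
    constant_lr,
    min_lr,
    style="infinite_cosine",
    timescale=1.0,
):
    """Four phases: linear warmup -> cooldown to constant_lr -> constant -> exponential decay to min_lr."""
    warmup_iters = warmup_frac * total_iters
    cooldown_iters = cooldown_frac * total_iters
    constant_iters = constant_frac * total_iters

    if step <= warmup_iters:
        # Linear warmup, as in the existing AnnealingLR warmup branch.
        return start_lr * step / max(warmup_iters, 1)

    t = step - warmup_iters
    if t <= cooldown_iters:
        # Cooldown from start_lr down to constant_lr.
        if style == "infinite_cosine":
            return constant_lr + (start_lr - constant_lr) / 2.0 * (
                math.cos(math.pi * t / cooldown_iters) + 1
            )

        # Inverse-sqrt cooldown, normalized so it reaches constant_lr at t == cooldown_iters.
        def inv_f(x):
            return 1 / math.sqrt(1 + timescale * x) - 1

        return start_lr + (constant_lr - start_lr) / inv_f(1) * inv_f(t / cooldown_iters)

    t -= cooldown_iters
    if t <= constant_iters:
        # Constant phase.
        return constant_lr

    # Final exponential decay from constant_lr towards min_lr over the remaining iters.
    t -= constant_iters
    remaining = total_iters - warmup_iters - cooldown_iters - constant_iters
    exp_factor = -math.log(min_lr / constant_lr) / remaining
    return max(constant_lr * math.exp(-exp_factor * t), min_lr)


# Illustrative milestones of a 100k-step run (values are assumptions, not from the PR).
for s in (0, 1_000, 6_000, 11_000, 50_000, 91_000, 100_000):
    lr = infinite_lr(
        s,
        total_iters=100_000,
        warmup_frac=0.01,
        cooldown_frac=0.10,
        constant_frac=0.80,
        start_lr=3.0e-4,
        constant_lr=1.0e-4,
        min_lr=1.0e-5,
    )
    print(f"step {s:>6}: lr = {lr:.2e}")
```

With these assumed settings the learning rate rises to 3e-4 over the first 1,000 steps, anneals to 1e-4 by step 11,000, stays there for 80,000 steps, and decays to 1e-5 over the final 9,000 steps.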
22 changes: 21 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -532,7 +532,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
LR Scheduler Arguments
"""

lr_decay_style: Literal["constant", "linear", "cosine", "exponential"] = "linear"
lr_decay_style: Literal["constant", "linear", "cosine", "exponential", "infinite_cosine", "infinite_inv_sqrt"] = "linear"
"""
Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential', 'infinite_cosine', 'infinite_inv_sqrt'.
"""
@@ -547,11 +547,31 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate):
Minimum value for learning rate. The scheduler clips values below this threshold.
"""

constant_lr: float = 0.0
"""
Constant learning rate when using infinite cosine or infinite inv sqrt decay styles.
"""

warmup: float = 0.01
"""
Percentage of total iterations to warmup on (.01 = 1 percent of all training iters).
"""

cooldown_iters_perc: float = 0.0
"""
Percentage of total iterations to spend cooling down from the peak learning rate to constant_lr (.1 = 10 percent of all training iters).
"""

constant_iters_perc: float = 0.0
"""
Percentage of total iterations to hold the learning rate constant at constant_lr (.1 = 10 percent of all training iters).
"""

timescale: float = 1.0
"""
Timescale for the steepness of the inverse square root cooldown.
"""

override_lr_scheduler: bool = False
"""
Reset the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
6 changes: 6 additions & 0 deletions megatron/training.py
@@ -713,6 +713,8 @@ def get_learning_rate_scheduler(optimizer, neox_args):
num_iters = max(1, num_iters)
init_step = 0
warmup_iter = neox_args.warmup * num_iters
constant_iters = neox_args.constant_iters_perc * num_iters
cooldown_iters = neox_args.cooldown_iters_perc * num_iters
lr_scheduler = AnnealingLR(
optimizer,
start_lr=neox_args.lr,
@@ -721,6 +723,10 @@ def get_learning_rate_scheduler(optimizer, neox_args):
decay_style=neox_args.lr_decay_style,
last_iter=init_step,
min_lr=neox_args.min_lr,
constant_lr=neox_args.constant_lr,
constant_iters=constant_iters,
cooldown_iters=cooldown_iters,
timescale=neox_args.timescale,
use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
override_lr_scheduler=neox_args.override_lr_scheduler,
use_mup=neox_args.use_mup,
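
Plugging in some assumed numbers makes the percentage-to-iteration conversion concrete (train_iters and the fractions below are illustrative, not from the PR); whatever is left after warmup, cooldown, and the constant phase is spent on the final exponential decay:

```python
# Assumed settings, for illustration only.
train_iters = 100_000
warmup, cooldown_iters_perc, constant_iters_perc = 0.01, 0.10, 0.80

warmup_iter = warmup * train_iters                  # 1,000 iters of linear warmup
cooldown_iters = cooldown_iters_perc * train_iters  # 10,000 iters annealing down to constant_lr
constant_iters = constant_iters_perc * train_iters  # 80,000 iters held at constant_lr
remaining = train_iters - warmup_iter - cooldown_iters - constant_iters
print(warmup_iter, cooldown_iters, constant_iters, remaining)
# 1000.0 10000.0 80000.0 9000.0 -> the last 9,000 iters decay exponentially to min_lr
```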