From 5f2a3f2f2b5f5274024dc3dce96b9a2af584a2ad Mon Sep 17 00:00:00 2001 From: Kshitij Date: Mon, 25 Mar 2024 20:44:21 +0530 Subject: [PATCH 1/8] Added infinite lr schedules --- megatron/learning_rates.py | 36 ++++++++++++++++++++++++++++ megatron/neox_arguments/neox_args.py | 22 ++++++++++++++++- megatron/training.py | 6 +++++ 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 9db951aa0..092001d6a 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -34,6 +34,10 @@ def __init__( decay_style, last_iter, min_lr=0.0, + constant_lr=0.0, + constant_iters=None, + cooldown_iters=None, + timescale=None, use_checkpoint_lr_scheduler=True, override_lr_scheduler=False, use_mup=False, @@ -43,9 +47,13 @@ def __init__( self.optimizer = optimizer self.start_lr = start_lr self.min_lr = min_lr + self.constant_lr = constant_lr self.warmup_iter = warmup_iter self.num_iters = last_iter self.end_iter = total_iters + self.constant_iters = constant_iters + self.cooldown_iters = cooldown_iters + self.timescale = timescale assert self.end_iter > 0 self.decay_style = decay_style self.override_lr_scheduler = override_lr_scheduler @@ -84,6 +92,34 @@ def get_lr(self): # exp(-0.693) = 1/2 end_iter = self.end_iter - self.warmup_iter lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter) + elif self.decay_style == "infinite_cosine" or self.decay_style == "infinite_inv_sqrt": + if num_iters_ <= self.cooldown_iter: + if self.decay_style == "infinite_cosine": + lr = self.constant_lr + ( + (self.start_lr-self.constant_lr) + / 2.0 + * (math.cos(math.pi * num_iters_ / self.cooldown_iter) + 1) + ) + else: + def inv_f(t): + return (1/math.sqrt(1+(self.timescale*t))) - 1 + lr = self.start_lr + ( + (self.constant_lr - self.start_lr) + / inv_f(1) + * (inv_f(num_iters_ / self.cooldown_iter)) + ) + return lr + else: + num_iters_ = num_iters_ - self.cooldown_iter + if num_iters_ <= self.constant_iter: + # Stay constant for constant_iters + lr = self.constant_lr + else: + # Go from constant iters to min LR using exponential decay in remaining iters + end_iter_ = self.end_iter - self.warmup_iter - self.constant_iter + num_iters_ = num_iters_ - self.constant_iter + exp_factor = -math.log(self.min_lr/self.constant_lr) / end_iter_ + lr = self.constant_lr * math.exp(-1* exp_factor * num_iters_) else: lr = self.start_lr return max(lr, self.min_lr) diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 16d6456b4..f793cfa36 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -531,7 +531,7 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate): LR Scheduler Arguments """ - lr_decay_style: Literal["constant", "linear", "cosine", "exponential"] = "linear" + lr_decay_style: Literal["constant", "linear", "cosine", "exponential", "infinite_cosine", "infinite_inv_sqrt"] = "linear" """ Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'. """ @@ -546,11 +546,31 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate): Minimum value for learning rate. The scheduler clips values below this threshold. """ + constant_lr: float = 0.0 + """ + Constant learning rate when using infinite cosine or infinite inv sqrt decay styles. + """ + warmup: float = 0.01 """ Percentage of total iterations to warmup on (.01 = 1 percent of all training iters). """ + cooldown_iters_perc: float = 0.0 + """ + Percentage of total iterations to cooldown for. + """ + + constant_iters_perc: float = 0.0 + """ + Percentage of total iterations to keep the learning rate constant for. + """ + + timescale: float = 1.0 + """ + Timescale for the steepness of the inverse square root cooldown. + """ + override_lr_scheduler: bool = False """ Reset the values of the scheduler (learning rate,warmup iterations, minimum learning rate, maximum number of iterations, and decay style from input arguments and ignore values from checkpoints. Note that all the above values will be reset. diff --git a/megatron/training.py b/megatron/training.py index a0e408694..88fda089a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -656,6 +656,8 @@ def get_learning_rate_scheduler(optimizer, neox_args): num_iters = max(1, num_iters) init_step = 0 warmup_iter = neox_args.warmup * num_iters + constant_iters = neox_args.constant_iters_perc * num_iters + cooldown_iters = neox_args.cooldown_iters_perc * num_iters lr_scheduler = AnnealingLR( optimizer, start_lr=neox_args.lr, @@ -664,6 +666,10 @@ def get_learning_rate_scheduler(optimizer, neox_args): decay_style=neox_args.lr_decay_style, last_iter=init_step, min_lr=neox_args.min_lr, + constant_lr=neox_args.constant_lr, + constant_iters=constant_iters, + cooldown_iters=cooldown_iters, + timescale=neox_args.timescale, use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler, override_lr_scheduler=neox_args.override_lr_scheduler, use_mup=neox_args.use_mup, From 499f65fc32cc418b4e07afa3d4dd9ffdc848c44c Mon Sep 17 00:00:00 2001 From: github-actions Date: Mon, 25 Mar 2024 15:15:03 +0000 Subject: [PATCH 2/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 8b301cf8b..408331f6f 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -7,7 +7,7 @@ LR Scheduler Arguments -- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential'] +- **lr_decay_style**: typing.Literal['constant', 'linear', 'cosine', 'exponential', 'infinite_cosine', 'infinite_inv_sqrt'] Default = linear @@ -31,6 +31,14 @@ LR Scheduler Arguments +- **constant_lr**: float + + Default = 0.0 + + Constant learning rate when using infinite cosine or infinite inv sqrt decay styles. + + + - **warmup**: float Default = 0.01 @@ -39,6 +47,30 @@ LR Scheduler Arguments +- **cooldown_iters_perc**: float + + Default = 0.0 + + Percentage of total iterations to cooldown for. + + + +- **constant_iters_perc**: float + + Default = 0.0 + + Percentage of total iterations to keep the learning rate constant for. + + + +- **timescale**: float + + Default = 1.0 + + Timescale for the steepness of the inverse square root cooldown. + + + - **override_lr_scheduler**: bool Default = False @@ -111,7 +143,7 @@ Logging Arguments - **git_hash**: str - Default = f70c54d + Default = 5f2a3f2 current git hash of repository From b87692e8921f11e8a74f54c2f99a5e6860553350 Mon Sep 17 00:00:00 2001 From: github-actions Date: Sun, 21 Apr 2024 22:15:49 +0000 Subject: [PATCH 3/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index cfd272250..c104f5ff4 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -143,7 +143,7 @@ Logging Arguments - **git_hash**: str - Default = 11a5537 + Default = c7ed9ae current git hash of repository From cc9cab2e27649022d49168c366ce0fc2900005c6 Mon Sep 17 00:00:00 2001 From: Kshitij Date: Wed, 24 Apr 2024 17:01:13 -0700 Subject: [PATCH 4/8] fixed bug --- megatron/learning_rates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 092001d6a..421697ede 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -116,7 +116,7 @@ def inv_f(t): lr = self.constant_lr else: # Go from constant iters to min LR using exponential decay in remaining iters - end_iter_ = self.end_iter - self.warmup_iter - self.constant_iter + end_iter_ = self.end_iter - self.warmup_iter - self.cooldown_iter - self.constant_iter num_iters_ = num_iters_ - self.constant_iter exp_factor = -math.log(self.min_lr/self.constant_lr) / end_iter_ lr = self.constant_lr * math.exp(-1* exp_factor * num_iters_) From b6eefceb4d3f6be7faae6563b3bf8766b06d3db4 Mon Sep 17 00:00:00 2001 From: github-actions Date: Thu, 25 Apr 2024 00:01:36 +0000 Subject: [PATCH 5/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index c104f5ff4..169d921dd 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -143,7 +143,7 @@ Logging Arguments - **git_hash**: str - Default = c7ed9ae + Default = cc9cab2 current git hash of repository From 43ccefe6a3e083b6a15e8a3884fa7f17fab0aca9 Mon Sep 17 00:00:00 2001 From: github-actions Date: Mon, 6 May 2024 17:15:59 +0000 Subject: [PATCH 6/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 47 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index 2a3713a85..ad8119d5f 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -143,7 +143,7 @@ Logging Arguments - **git_hash**: str - Default = 6fb840e + Default = 2e769fb current git hash of repository @@ -1233,7 +1233,7 @@ Text Generation arguments -- **num_experts**: int +- **moe_num_experts**: int Default = 1 @@ -1275,7 +1275,7 @@ Text Generation arguments - **moe_token_dropping**: bool - Default = True + Default = False Whether to drop tokens when exceeding capacity @@ -1305,6 +1305,47 @@ Text Generation arguments +- **moe_type**: str + + Default = megablocks + + Either `deepspeed` or `megablocks` + + + +- **moe_glu**: bool + + Default = False + + Use gated linear units in MoE + + + +- **moe_lbl_in_fp32**: bool + + Default = False + + Whether to compute the load balancing loss in fp32. + + + +- **moe_jitter_eps**: float + + Default = None + + Coefficient for MoE routing jitter. Jitter is + not used if set to None + + + +- **enable_expert_tensor_parallelism**: bool + + Default = False + + Enable expert tensor parallelism + + + ## NeoXArgsTokenizer Tokenizer Arguments From aa7c8b4218d700f1949db645fc2078f9f672a84d Mon Sep 17 00:00:00 2001 From: "hatef.4" Date: Tue, 14 May 2024 13:45:31 -0400 Subject: [PATCH 7/8] fixed typos --- megatron/learning_rates.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/megatron/learning_rates.py b/megatron/learning_rates.py index 421697ede..0b9d0c128 100644 --- a/megatron/learning_rates.py +++ b/megatron/learning_rates.py @@ -93,12 +93,12 @@ def get_lr(self): end_iter = self.end_iter - self.warmup_iter lr = self.start_lr * math.exp(-0.693 * num_iters_ / end_iter) elif self.decay_style == "infinite_cosine" or self.decay_style == "infinite_inv_sqrt": - if num_iters_ <= self.cooldown_iter: + if num_iters_ <= self.cooldown_iters: if self.decay_style == "infinite_cosine": lr = self.constant_lr + ( (self.start_lr-self.constant_lr) / 2.0 - * (math.cos(math.pi * num_iters_ / self.cooldown_iter) + 1) + * (math.cos(math.pi * num_iters_ / self.cooldown_iters) + 1) ) else: def inv_f(t): @@ -106,18 +106,18 @@ def inv_f(t): lr = self.start_lr + ( (self.constant_lr - self.start_lr) / inv_f(1) - * (inv_f(num_iters_ / self.cooldown_iter)) + * (inv_f(num_iters_ / self.cooldown_iters)) ) return lr else: - num_iters_ = num_iters_ - self.cooldown_iter - if num_iters_ <= self.constant_iter: + num_iters_ = num_iters_ - self.cooldown_iters + if num_iters_ <= self.constant_iters: # Stay constant for constant_iters lr = self.constant_lr else: # Go from constant iters to min LR using exponential decay in remaining iters - end_iter_ = self.end_iter - self.warmup_iter - self.cooldown_iter - self.constant_iter - num_iters_ = num_iters_ - self.constant_iter + end_iter_ = self.end_iter - self.warmup_iter - self.cooldown_iters - self.constant_iters + num_iters_ = num_iters_ - self.constant_iters exp_factor = -math.log(self.min_lr/self.constant_lr) / end_iter_ lr = self.constant_lr * math.exp(-1* exp_factor * num_iters_) else: From e217ed7eaeb60b1d7b5da1d782cf0f8eddf5e3d9 Mon Sep 17 00:00:00 2001 From: github-actions Date: Tue, 14 May 2024 17:53:28 +0000 Subject: [PATCH 8/8] Update NeoXArgs docs automatically --- configs/neox_arguments.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index ad8119d5f..0848c480e 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -143,7 +143,7 @@ Logging Arguments - **git_hash**: str - Default = 2e769fb + Default = 924d2e5 current git hash of repository