
Fixed AnnealingLR Class and Cosine Decay Schedule #1008

Merged: 2 commits into main from fix_cosine, Aug 7, 2023

Conversation

kshitijkg (Contributor)

No description provided.

@CLAassistant commented Aug 5, 2023

CLA assistant check
All committers have signed the CLA.

@Quentin-Anthony (Member)

Important context from @kshitijkg on Discord:

So I looked into the LR schedule. Here is what I found: the current cosine LR schedule in GPT-NeoX is designed to produce the LR curve we get; this is what was used for training Pythia and all the other models. It is not a bug introduced by us, but it does look like a bug in GPT-NeoX.

- First, the cosine schedule itself does not take min_lr into account; the only place min_lr is currently used is when returning the LR (return max(lr, self.min_lr)), so the LR simply never goes below min_lr.
- Second, there is this weird statement: num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter). It produces the weird artifact at the end of training, because during the last warmup_iter iterations num_iters_ is always clamped to self.end_iter - self.warmup_iter, which is fixed.

Current formula used for the cosine decay:
cur_iter = min(cur_iter, total_iter - warmup_iter)
if cur_iter < warmup_iter: do warmup
cur_iter = cur_iter - warmup_iter
lr = 0.5 * max_lr * (1 + cos(pi * cur_iter / total_iter))
lr = max(lr, min_lr)
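For reference, here is a minimal self-contained Python sketch of the behavior described by this pseudocode (the function name and arguments are illustrative, not the actual AnnealingLR attributes):

```python
import math

def current_cosine_lr(cur_iter, max_lr, min_lr, warmup_iter, total_iter):
    # The clamp discussed above: becomes constant once cur_iter passes total_iter - warmup_iter.
    num_iters = min(cur_iter, total_iter - warmup_iter)
    if warmup_iter > 0 and cur_iter <= warmup_iter:
        return max_lr * num_iters / warmup_iter        # linear warmup
    num_iters -= warmup_iter
    lr = 0.5 * max_lr * (1 + math.cos(math.pi * num_iters / total_iter))
    return max(lr, min_lr)                             # LR is floored at min_lr
```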

What I think we should use:
if cur_iter < warmup_iter: do warmup
cur_iter = cur_iter - warmup_iter
lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * cur_iter / total_iter))
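A matching sketch of the proposed schedule, i.e. standard cosine decay from max_lr down to min_lr (again, the names are illustrative only):

```python
import math

def proposed_cosine_lr(cur_iter, max_lr, min_lr, warmup_iter, total_iter):
    if warmup_iter > 0 and cur_iter <= warmup_iter:
        return max_lr * cur_iter / warmup_iter         # linear warmup
    decay_iter = cur_iter - warmup_iter
    # Anneals smoothly toward min_lr instead of being clipped at it.
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * decay_iter / total_iter))
```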

The flat part we observe is caused by both the max function and the min clamp described below:

4 plots:
A) Current 
B) Current without Max
C) Current without Max and Min
D) Proposed

To keep the legend small:
When I say max, I mean: lr = max(lr, min_lr)
When I say min, I mean: cur_iter = min(cur_iter, total_iter - warmup_iter)

[Figure: the four LR-schedule plots (A–D) described above]

@Quentin-Anthony (Member)

Overall, this is not a bug; it's intended behavior. The schedule you propose:

lr = min_lr + 0.5*(max_lr-min_lr)*(1+cos(pi*cur_iter/total_iter))

is simply not cosine learning rate decay, because the decay has been reduced by min_lr, which will lead to significantly higher LRs near the end of training, as in your figure.

If you have evidence that this schedule performs better and would like it introduced as an alternative to cosine decay, that's fine, but it shouldn't replace cosine learning rate decay.

As for num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter), I agree it's strange. I would expect it to be num_iters_ = self.num_iters - self.warmup_iter, as in https://github.com/NVIDIA/Megatron-LM/blob/0609f27fe8376f17ab65c001d3d8f35cd8175950/megatron/optimizer_param_scheduler.py#L101C27-L101C37
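To make the difference concrete, here is a hypothetical side-by-side of the two decay-step computations (the values and names are illustrative, not taken from the actual scheduler):

```python
warmup_iter, end_iter = 1_000, 100_000  # illustrative values

def decay_steps_current(num_iters):
    # Clamps to a constant once num_iters exceeds end_iter - warmup_iter,
    # which is what produces the flat artifact at the end of training.
    return min(num_iters, end_iter - warmup_iter) - warmup_iter

def decay_steps_expected(num_iters):
    # Megatron-LM-style: simply offset the iteration count by the warmup length.
    return num_iters - warmup_iter

print(decay_steps_current(99_500), decay_steps_expected(99_500))  # 98000 98500
```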

@kshitijkg (Contributor, Author) commented Aug 7, 2023

Hi Quentin! Thank you for the information. I was curious whether that is what is generally used; it might be worth doing an ablation to see what works well with typical datasets like LLaMA and the Pile.

Upon further investigation, I found that other repositories such as Megatron-LM (https://github.com/NVIDIA/Megatron-LM/blob/0609f27fe8376f17ab65c001d3d8f35cd8175950/megatron/optimizer_param_scheduler.py#L77C9-L77C9), MPT (https://github.com/mosaicml/composer/blob/cc35953ef374b9aad17938d4fdc08cfc2d09fc42/composer/optim/scheduler.py#L384), and PyTorch (https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR) use the schedule I proposed above. I think it is the same one used to train Chinchilla and Gopher as well. It's the schedule described in this paper: https://arxiv.org/pdf/1608.03983v5.pdf
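As a quick illustration, the PyTorch scheduler linked above exposes this same min_lr-aware form via eta_min; a minimal, hypothetical usage sketch (the model and optimizer here are placeholders):

```python
import torch

model = torch.nn.Linear(8, 8)                              # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=6e-4)   # lr acts as the max LR
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=1000, eta_min=6e-5                    # eta_min plays the role of min_lr
)

for _ in range(1000):
    optimizer.step()
    scheduler.step()

# After T_max steps the LR has annealed to roughly eta_min rather than being clipped at it.
print(optimizer.param_groups[0]["lr"])
```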

@Quentin-Anthony (Member)

> Upon more investigation I found that other repositories like Megatron-LM, MPT, and PyTorch use the one I proposed above. [...]

Ah you're correct and I was mistaken. I'm happy with this change.

@kshitijkg marked this pull request as ready for review August 7, 2023 16:06
@kshitijkg requested a review from a team as a code owner August 7, 2023 16:06
@Quentin-Anthony (Member)

@StellaAthena and @haileyschoelkopf -- FYI

@Quentin-Anthony merged commit 009018e into main Aug 7, 2023
2 checks passed
@Quentin-Anthony deleted the fix_cosine branch August 7, 2023 16:10
kshitijkg added a commit to CERC-AAI/multimodal that referenced this pull request Aug 7, 2023
Fixed AnnealingLR Class and Cosine Decay Schedule (EleutherAI#1008)