How to preserve Pythia's sampling order with a different batch size #984
Comments
@haileyschoelkopf @uSaiPrashanth maybe you both have some idea about this issue?
From what I have observed, as long as you keep the number of epochs and the sequence length the same, your batch size or number of train iters should not matter (ref: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/data/gpt2_dataset.py#L187).
Could you check if this changes the number of epochs you're training on?
Here's a hack that should get around this: Keep …
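For reference, the epoch count mentioned in the first comment is derived from the requested number of samples, not from the batch size directly. A minimal paraphrase of the Megatron-style index-building logic that gpt-neox's `gpt2_dataset.py` inherits (names simplified; see the link above for the actual code):

```python
def num_epochs(tokens_per_epoch: int, seq_length: int, num_samples: int) -> int:
    """Smallest number of passes over the dataset that yields `num_samples`
    sequences. A sample spans seq_length + 1 tokens (inputs plus shifted
    labels), with one token shared between consecutive samples, hence the -1.
    """
    epochs, total_tokens = 0, 0
    while (total_tokens - 1) // seq_length < num_samples:
        epochs += 1
        total_tokens += tokens_per_epoch
    return epochs
```

Note that the input here is `num_samples` (train iterations × global batch size), so two runs whose products match should land on the same epoch count; whether the downstream shuffle is then also identical is exactly what this issue is probing.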
Describe the bug
I'd like to observe whether there are any substantial effects of using different batch sizes. It makes sense to use the exact same sampling order as was used for Pythia. To do this, the idea is to keep the total number of tokens the same for each batch size and increase or decrease `train-iters` accordingly.
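Concretely, the adjustment is just arithmetic. A minimal sketch, using the published Pythia setup (global batch size 1024, sequence length 2048, 143,000 steps, roughly 300B tokens):

```python
# Token-budget bookkeeping: hold total tokens fixed, solve for train-iters.
TOTAL_TOKENS = 1024 * 2048 * 143_000  # ~300B, the original Pythia budget

def train_iters_for(global_batch_size: int, seq_len: int = 2048) -> int:
    tokens_per_iter = global_batch_size * seq_len
    assert TOTAL_TOKENS % tokens_per_iter == 0, "budget must divide evenly"
    return TOTAL_TOKENS // tokens_per_iter

print(train_iters_for(1024))  # 143000 (original)
print(train_iters_for(512))   # 286000
```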
Double-checking the sampling order with `utils/batch_viewer.py` from Pythia, it seems that changing `train_micro_batch_size_per_gpu` while keeping `train-iters` the same doesn't affect the sampling order. Modifying `train-iters` based on `train_micro_batch_size_per_gpu`, to keep the total number of tokens the same for each run, results in a different ordering. The configurations I compared with identical `train-iters` produced the same ordering.
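Spelled out, the two comparisons above look like this (hypothetical values chosen around the ~300B-token budget; real configs have many more fields):

```python
# Pair 1: same train-iters, different batch size -> same ordering observed.
run_a = {"train_micro_batch_size_per_gpu": 512,  "train-iters": 143_000}
run_b = {"train_micro_batch_size_per_gpu": 1024, "train-iters": 143_000}

# Pair 2: train-iters rescaled so both runs see the same total tokens
# -> different ordering observed.
run_c = {"train_micro_batch_size_per_gpu": 512,  "train-iters": 286_000}
run_d = {"train_micro_batch_size_per_gpu": 1024, "train-iters": 143_000}
```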
This will be an issue if we want to train on the same amount as the original Pythia (300B tokens), because changing `train-iters` changes the ordering, while keeping it fixed and changing `train_micro_batch_size_per_gpu` will not result in the same number of tokens.

To Reproduce
Run `utils/batch_viewer.py` with `utils/dummy_config.yml` adjusted. I only observed the first 2 steps for bs512 and the first step for bs1024.
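A quick way to compare dumps from two such runs, assuming each step's token IDs were saved as a NumPy array (file names and shapes are hypothetical; adapt them to wherever `batch_viewer.py` writes its output):

```python
import numpy as np

# Hypothetical dump paths: one array of token IDs per training step.
bs512_s0 = np.load("dumps/bs512_step0.npy")    # shape (512, seq_len)
bs512_s1 = np.load("dumps/bs512_step1.npy")    # shape (512, seq_len)
bs1024_s0 = np.load("dumps/bs1024_step0.npy")  # shape (1024, seq_len)

# If the sample order were preserved, the first two bs512 steps would tile
# the first bs1024 step: samples 0..511, then samples 512..1023.
same_order = np.array_equal(np.concatenate([bs512_s0, bs512_s1]), bs1024_s0)
print("order preserved:", same_order)
```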
Expected behavior
Adjusting `train-iters` to offset a change in `train_micro_batch_size_per_gpu` (so the total token count stays the same) should reproduce Pythia's original sampling order.
Proposed solution
Not yet sure what the direct solution is; this might be an issue with how the dataset is loaded based on batch size.
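If the shuffled sample sequence itself were fixed, any (batch size, iterations) pair with the same product would walk it in the same order. A toy check of that invariant (not gpt-neox code):

```python
def sample_indices(iteration: int, global_batch_size: int) -> range:
    """Global sample indices consumed at a given training iteration."""
    start = iteration * global_batch_size
    return range(start, start + global_batch_size)

# Two bs=512 steps cover exactly one bs=1024 step, for every k:
k = 7
assert (list(sample_indices(2 * k, 512)) + list(sample_indices(2 * k + 1, 512))
        == list(sample_indices(k, 1024)))
```

Since that invariant holds trivially, the divergence reported above suggests the shuffle/index construction itself depends on the requested iteration count, which matches the pointer to `gpt2_dataset.py` in the comments.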