Negative document indices caused by 64-bit integer stored in a 32-bit integer array. #493
Comments
This seems like a generally good idea, though I'm intrigued by some of your config choices. You're finetuning a 13B-parameter model with a sequence length of 8192? And doing more than 10 epochs on the Pile?
Great, I'll get to work on a PR! Realistically I won't be able to train for 10 epochs; I was just changing parameters in the config and running little tests, and I stumbled across this bug by accident. I am curious whether the model can learn to process long-form text, though, hence the 8192 sequence length. Unrelated: I noticed that model checkpoints are stored in a way that is specific to the 3D-parallelism config. Is it possible to take a checkpoint that used a model parallelism of 4 and update it to a model parallelism of 8? I was thinking it should be possible to write a conversion script that copies all the weights over into the right locations, but I wasn't sure if something like that already existed.
That is functionality we are currently exploring. It is unfortunately non-trivial :/
@pwstegman did you ever solve this issue?
@haileyschoelkopf @ShivanshuPurohit @Quentin-Anthony this was the issue that you independently discovered and then patched, right?
Yeah, this should be fixed by #835
Describe the bug
While training on The Pile, I was getting errors from sparse attention claiming that the sequence length wasn't divisible by the block size, despite using a sequence length of 8192 and a block size of 16. This was caused by negative document indices in the dataset, which produced weird sample lengths (screenshot included in the Screenshots section). The negative document indices were caused by a wraparound at the 32-bit signed integer limit. More info in the Proposed Solution section.
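As a standalone illustration (not code from the repository), here is the wraparound in miniature:

```cpp
#include <cstdint>
#include <iostream>

int main() {
    // One past the 32-bit signed maximum (2,147,483,647).
    int64_t doc_index = 2147483648LL;
    // Narrowing to 32 bits wraps on typical two's-complement platforms,
    // producing exactly the kind of negative index seen in the dataset.
    int32_t stored = static_cast<int32_t>(doc_index);
    std::cout << stored << "\n";  // prints -2147483648
    return 0;
}
```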
To Reproduce
I'm using the docker image `leogao2/gpt-neox:sha-6dc7645`. My training script is:

Configs are included in the Environment section.
Expected behavior
Each sample should be exactly 8193 tokens (the 8192-token sequence plus one extra token, so inputs and shifted targets can be taken from the same sample).
Proposed solution
In short, I traced the issue to this function:
`gpt-neox/megatron/data/helpers.cpp`, line 100 at `98683ae`
It keeps looping until the target number of samples is reached:
`gpt-neox/megatron/data/helpers.cpp`, line 145 at `98683ae`
There are only ~200 million documents, but since the loop covers multiple epochs, and there may be multiple documents per sample, the document index quickly went over 2.1 billion (for instance, 200 million documents over 11 epochs is already ~2.2 billion, past the 32-bit signed maximum of 2,147,483,647). The document index variable itself is a 64-bit signed integer:
`gpt-neox/megatron/data/helpers.cpp`, line 137 at `98683ae`
However, it's stored into a 32-bit signed integer array, and that is where the wraparound to negative values happens:
`gpt-neox/megatron/data/helpers.cpp`, line 122 at `98683ae`
`gpt-neox/megatron/data/helpers.cpp`, line 168 at `98683ae`
To solve this, I propose:
I can submit a PR to take care of both of these if this sounds reasonable.
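For illustration only (not the actual patch, which eventually landed as #835), and assuming the two pieces are (a) widening the sample-index buffer to 64 bits and (b) checking any conversion that must remain 32-bit, the fix could look roughly like this:

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical sketch: with a 64-bit output buffer, the store no longer
// narrows, so the wraparound cannot occur.
void store_sample(int64_t* sample_idx, int64_t sample_index,
                  int64_t doc_idx_index, int64_t doc_offset) {
    sample_idx[2 * sample_index] = doc_idx_index;      // 64-bit slot
    sample_idx[2 * sample_index + 1] = doc_offset;
}

// Where a 32-bit value is unavoidable, an explicit check beats a silent wrap.
int32_t to_int32_checked(int64_t v) {
    if (v > INT32_MAX || v < INT32_MIN)
        throw std::overflow_error("document index exceeds 32-bit range");
    return static_cast<int32_t>(v);
}
```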
Screenshots
Here's one screenshot I took which highlights the core of the issue:
Environment (please complete the following information):
13B.yml
local_setup.yml
sparse.yml
Additional context
None