Add 'on-the-fly' sample packing #1109
Conversation
See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1109
✅ No Failures as of commit cdf5cdf with merge base abe798d.
Hi Joe,
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff            @@
##             main    #1109       +/-   ##
===========================================
+ Coverage   26.67%   66.72%   +40.04%
===========================================
  Files         183      184        +1
  Lines        8337     8586      +249
===========================================
+ Hits         2224     5729     +3505
+ Misses       6113     2857     -3256

View full report in Codecov by Sentry.
Yep, constructing the mask during training reduces memory by about 99% and only slightly slows down processing.
> constructing the mask during training reduces memory by about 99%
Incredible work, after seeing this I don't know what else to say on this awesome PR. Only have one blocking comment.
Also, how difficult would it be to also move the input_pos creation to getitem as well? Probably not as significant of a memory saver as the mask, but might still be worthwhile.
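A minimal sketch of what that could look like, assuming each pack stores its seq_lens (names here are illustrative, not the PR's actual code):

```python
import torch

# Hypothetical helper: build input_pos lazily in __getitem__ from the stored seq_lens.
# Position ids restart at 0 for every sample in the pack.
def build_input_pos(seq_lens: list[int]) -> torch.Tensor:
    return torch.cat([torch.arange(n) for n in seq_lens])

# e.g. seq_lens = [3, 2] -> tensor([0, 1, 2, 0, 1])
```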
@@ -71,7 +73,7 @@ class on packed samples with a ``Sampler`` as part of the dataloader. Currently,
         ds (Dataset): dataset to sample pack. This should return a dictionary with field
             "tokens" and "labels" containing the tokenized and label samples.
         max_seq_len (int): Maximum number of tokens to pack
-        max_packs (Optional[int]): maximum number of packs. Default is None, which will create as many
+        max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many
grammar police 🚓 🚓 🚓
torchtune/datasets/_packed.py
# Handle the last pack if there's leftover and we haven't filled up the max packs
if (
    len(current_pack["tokens"]) > 0
    and self.max_packs is not None
If self.max_packs is None and we have a leftover pack, then this condition won't be true when we would want it to be, hence the self.max_packs is None or
Ahh good catch - let me also make sure this is caught in a test.
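For reference, a hedged sketch of the corrected guard (the pack-list attribute is an assumption; the PR may use a different name):

```python
# Emit the final, partially filled pack when there is leftover data and we are
# either unlimited (max_packs is None) or still under the pack budget.
if len(current_pack["tokens"]) > 0 and (
    self.max_packs is None or len(self.packs) < self.max_packs
):
    self.packs.append(current_pack)
```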
"tokens": torch.tensor(pack["tokens"]), | ||
"labels": torch.tensor(pack["labels"]), | ||
"input_pos": torch.tensor(pack["input_pos"]), | ||
"seq_lens": pack["seq_lens"], |
should we just standardize on torch.tensor here, or would that make the enumerate in getitem more messy?
Ehh, it just takes up more memory and isn't necessary.
I am inclined to agree with @RdoubleA here. Is the memory increase just tensor metadata? If so should be pretty small, right? Advantage is that we can then scrap this whole PACK_TYPE business. Either way not a blocker.
Great question! I think I actually could do this with just the
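For illustration only, a rough sketch of what standardizing on torch.tensor could look like (not necessarily what the PR ended up doing):

```python
import torch

# Convert every field of the pack (lists at this point) to a tensor; the extra
# memory per field is just tensor metadata.
pack = {key: torch.tensor(value) for key, value in pack.items()}

# __getitem__ could still iterate over sequence lengths, e.g.:
# for seq_len in pack["seq_lens"].tolist():
#     ...
```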
@@ -91,8 +111,13 @@ def test_packed_dataset(
            max_packs=max_packs,
            split_across_pack=split_across_pack,
do you want to add in a padding idx here just to test it?
block_attn_masks.append(
    torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
)

# If we're at the last sample and the total seq len is less than the max seq len,
# we need to pad with identity matrix for the remainder
if i == num_samples_in_pack - 1 and total_seq_len < self.max_seq_len:
    block_attn_masks.append(
        torch.eye(
            self.max_seq_len - total_seq_len,
            self.max_seq_len - total_seq_len,
            dtype=torch.bool,
        )
    )
Not a major blocker for this PR, but just fyi it is probably faster to do this in a list comprehension instead of append. See e.g. here. If number of samples per pack is generally small prob not a huge deal, but we will be doing this for every sample so it may add up
Filed #1130 for a follow-up
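A possible shape for that follow-up (an assumed sketch, not the code in #1130): build the per-sample causal blocks in a list comprehension and assemble them with torch.block_diag:

```python
import torch

def build_packed_mask(seq_lens: list[int], max_seq_len: int) -> torch.Tensor:
    # One lower-triangular (causal) block per sample in the pack.
    blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
    total_seq_len = sum(seq_lens)
    # Pad the remainder with an identity block so padding tokens attend only to themselves.
    if total_seq_len < max_seq_len:
        blocks.append(torch.eye(max_seq_len - total_seq_len, dtype=torch.bool))
    # Assemble the blocks into one (max_seq_len, max_seq_len) block-diagonal mask.
    return torch.block_diag(*blocks)
```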
padded_labels = F.pad(
    pack["labels"],
    (0, self.max_seq_len - len(pack["labels"])),
    value=CROSS_ENTROPY_IGNORE_IDX,
)
It's weird that we pass padding_idx but still have CROSS_ENTROPY_IGNORE_IDX as a global constant. (I'm just complaining, don't worry about it here)
Yeah I dislike this as well b/c it means that we don't actually support losses that do NOT use this ignore index OOTB. BUT, we make this assumption throughout the repo so there's no point in deviating here in this PR.
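To make the split concrete, a small self-contained sketch of the padding behavior being discussed (the -100 value and the function shape are assumptions, not the PR's code):

```python
import torch
import torch.nn.functional as F

CROSS_ENTROPY_IGNORE_IDX = -100  # repo-wide ignore index assumed by the losses

def pad_pack(tokens, labels, max_seq_len, padding_idx):
    num_padding = max_seq_len - tokens.numel()
    # Tokens get the configurable padding_idx (e.g. the tokenizer's pad id)...
    padded_tokens = F.pad(tokens, (0, num_padding), value=padding_idx)
    # ...while labels always get the ignore index so padded positions never hit the loss.
    padded_labels = F.pad(labels, (0, num_padding), value=CROSS_ENTROPY_IGNORE_IDX)
    return padded_tokens, padded_labels
```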
Context
As investigated in #1097, it was shown that the offline approach to constructing the mask consumed waaaaaay too much memory. Therefore, this approach constructs tokens, labels, and input_pos offline and then constructs the mask during access (training). For a max_seq_len of 4096 (default for many models), we can expect the memory of a single pack to look like the following offline:
To provide a real-world example, let's use the Web Instruct Dataset from Tiger Labs. It comes in at 3.51 GB with 2.3 million samples. The average sample length (with instruct template applied) is about 100 tokens. This means that 40 samples fit in each pack if we don't split across packs, so we can expect about 57,500 packs. This number times 0.1MB is 5.75GB of additional memory, bringing the total on-disk memory needed to load this dataset (before training) to 9.26GB, well within reasonable bounds.
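A quick back-of-envelope check of those numbers (all inputs taken from the paragraph above):

```python
num_samples = 2_300_000      # samples in the dataset
avg_tokens = 100             # average tokenized sample length
max_seq_len = 4096
pack_size_mb = 0.1           # approx. size of one pack's tokens/labels/input_pos

samples_per_pack = max_seq_len // avg_tokens     # ~40 samples per pack
num_packs = num_samples / samples_per_pack       # ~57,500 packs
extra_gb = num_packs * pack_size_mb / 1000       # ~5.75 GB of additional memory
total_gb = 3.51 + extra_gb                       # ~9.26 GB to load the packed dataset
print(num_packs, extra_gb, total_gb)
```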
Why do we need seq_lens? Technically we could calculate this using the input_pos (see the sketch below), but this would save us negligible memory and increase processing time during training, which is undesirable.

Why are you using this dataset? It's a large dataset, downloaded 33,026 times in the last month. As good a baseline as any.
Why did you update the signature to take in a padding_idx and hardcode CROSS_ENTROPY_IGNORE_IDX? Excellent question. Before, the packed dataset assumed that padding_idx = 0 and always used CROSS_ENTROPY_IGNORE_IDX. The former is NOT an assumption we can make, so it should actually be configurable; the latter IS a reasonable assumption, so we should just hardcode it instead of defaulting the param (which won't get used).
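To illustrate the seq_lens point above: they could in principle be recovered from input_pos (a hedged sketch, not code from the PR), but doing it on every access adds work during training:

```python
import torch

def seq_lens_from_input_pos(input_pos: torch.Tensor) -> list[int]:
    # A new sample starts wherever the position counter resets to 0.
    starts = (input_pos == 0).nonzero(as_tuple=True)[0]
    boundaries = torch.cat([starts, torch.tensor([input_pos.numel()])])
    return (boundaries[1:] - boundaries[:-1]).tolist()

# e.g. input_pos = tensor([0, 1, 2, 0, 1]) -> [3, 2]
```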
Changelog
- Added PACK_TYPE so I don't have to keep typing it
- Added seq_lens to the offline variables to hold information on each seq len in the pack (useful for calculating the mask on the fly)

Test plan
All are passing
Using this gist: https://gist.github.com/joecummings/05586af0a08eef0714c7da3c56ee7365
Only packing 1% of the dataset, which is 23k samples. Using our calculation from above, we expect memory usage with the new implementation to take an additional 0.58GB.
Our calculation looks pretty spot on for how much memory the new implementation should take. And it makes sense that there would be a little more memory used when the mask is constructed during dataloading.
Not surprising that the old packing takes much longer than the new packing.
Why do we need to do this? Well, the above "loading" is not a true measurement of how packing a dataset will affect the training time. For instance, we are now passing in a constructed mask for attention instead of relying on SDPA to construct one for us.
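For context, a minimal sketch of the difference (hypothetical shapes, not the recipe code): with packing we hand an explicit attn_mask to scaled dot-product attention instead of letting it build the causal mask via is_causal:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim), toy sizes

# Unpacked: SDPA builds the causal mask for us.
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Packed: we supply the precomputed block-causal mask (True = may attend).
mask = torch.tril(torch.ones(16, 16, dtype=torch.bool))  # stand-in for the packed mask
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```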
CMD:
YAY, it's just (kinda) as fast as the old implementation and uses waaaaaay less memory.