
Add 'on-the-fly' sample packing #1109

Merged

joecummings merged 8 commits into pytorch:main from pack-mask-on-the-fly on Jun 27, 2024

Conversation

@joecummings (Contributor) commented Jun 21, 2024

Context

As investigated in #1097, the fully offline approach to constructing the mask consumed waaaaaay too much memory. Therefore, this approach constructs tokens, labels, and input_pos offline and then constructs the mask at access time (during training). For a max_seq_len of 4096 (the default for many models), we can expect the offline memory of a single pack to look like the following:

tokens:    88 (fixed Python object overhead) + 8 (size of torch.int64) * 4096 = 32,856
labels:    88 + 8 * 4096 = 32,856
input_pos: 88 + 8 * 4096 = 32,856
seq_lens:  226  <-- varies with the number of samples per pack; average from experiments
----------------------------------------
98,794 bytes ~= 0.1 MB

To provide a real-world example, let's use the WebInstruct dataset from TIGER-Lab. It comes in at 3.51 GB with 2.3 million samples. The average sample length (with the instruct template applied) is about 100 tokens, which means roughly 40 samples fit in each pack if we don't split across packs. We can therefore expect about 57,500 packs. At 0.1 MB each, that is 5.75 GB of additional memory, bringing the total memory needed to load this dataset (before training) to 9.26 GB, well within reasonable bounds.
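
For reference, a quick back-of-the-envelope script reproducing the arithmetic above (a rough sketch; the 88-byte object overhead and the 226-byte average seq_lens size are the estimates quoted above, not exact measurements):

max_seq_len = 4096
tensor_bytes = 88 + 8 * max_seq_len          # tokens, labels, or input_pos: ~32,856 bytes each
seq_lens_bytes = 226                         # average observed size of the seq_lens list
pack_bytes = 3 * tensor_bytes + seq_lens_bytes
print(f"{pack_bytes} bytes per pack")        # ~98,794 bytes ~= 0.1 MB

# WebInstructSub: ~2.3M samples at ~100 tokens each -> ~40 samples per pack
num_packs = 2_300_000 // 40                  # ~57,500 packs
print(f"{num_packs * pack_bytes / 1e9:.2f} GB of packed metadata")   # ~5.7 GB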

Why do we need seq_lens? Technically we could recover this from input_pos, but that would save us negligible memory and increase processing time during training, which is undesirable.

Why are you using this dataset? It's a large dataset, downloaded 33,026 times in the last month. As good a baseline as any.

Why did you update the signature to take in a padding_idx and hardcode CROSS_ENTROPY_IGNORE_IDX? Excellent question. Before, the packed dataset assumed that padding_idx = 0 and always used CROSS_ENTROPY_IGNORE_IDX. The former is NOT an assumption we can make, so it should actually be configurable; the latter IS a reasonable assumption, so we should just hardcode it instead of exposing a parameter that never gets used.
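
Roughly, the updated usage looks something like this (a sketch only; it assumes ds and tokenizer are already constructed, and the import path and tokenizer.pad_id attribute are assumptions based on the rest of the repo rather than this exact diff):

from torchtune.datasets import PackedDataset

packed_ds = PackedDataset(
    ds,                            # any dataset returning {"tokens", "labels"} dicts
    max_seq_len=4096,
    padding_idx=tokenizer.pad_id,  # no longer assumed to be 0
    split_across_pack=False,       # whether samples may be split across pack boundaries
)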

Changelog

  • Add PACK_TYPE so I don't have to keep typing the full type
  • Add seq_lens to the offline variables to hold the length of each sequence in the pack (needed for calculating the mask on the fly)
  • Add a method specifically to convert a pack to tensors
  • Clean up the logic for handling splitting across packs
  • Update the getitem method to construct masks on the fly (see the sketch after this list)
  • Update tests to cover a larger max_packs
  • Modify the chat, instruct, and text completion datasets to pass the tokenizer pad id to the packed dataset
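
For context, a minimal sketch of what the on-the-fly mask construction amounts to (a simplified, hypothetical helper rather than the actual PackedDataset code, assuming torch.block_diag for assembling the per-sample blocks):

import torch

def build_block_causal_mask(seq_lens, max_seq_len):
    # One lower-triangular (causal) block per packed sample, so tokens only
    # attend within their own sample.
    blocks = [torch.tril(torch.ones(n, n, dtype=torch.bool)) for n in seq_lens]
    total = sum(seq_lens)
    if total < max_seq_len:
        # Identity block for the padded remainder of the pack
        blocks.append(torch.eye(max_seq_len - total, dtype=torch.bool))
    return torch.block_diag(*blocks)

# e.g. two samples of lengths 3 and 2 packed into max_seq_len=8 -> an (8, 8) bool mask
mask = build_block_causal_mask([3, 2], 8)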

Test plan

  1. Unit tests

All are passing

(joe-torchtune) [[email protected] ~/projects/joe-torchtune (pack-mask-on-the-fly)]$ pytest tests/torchtune/datasets/test_packed_dataset.py
================================================================ test session starts =================================================================
platform linux -- Python 3.11.9, pytest-8.1.1, pluggy-1.5.0
rootdir: /home/jrcummings/projects/joe-torchtune
configfile: pyproject.toml
plugins: integration-0.2.3, mock-3.14.0, cov-5.0.0, hypothesis-6.103.1
collected 11 items

tests/torchtune/datasets/test_packed_dataset.py ...........                                                                                    [100%]

================================================================= 11 passed in 5.03s =================================================================
  2. Direct memory / speed comparisons with old version

Using this gist: https://gist.github.com/joecummings/05586af0a08eef0714c7da3c56ee7365

Only packing 1% of the dataset, which is 23k samples. Using our calculation from above, we expect the new implementation to use an additional 0.58 GB of memory.

Memory is an estimate based on psutil monitoring of virtual memory used. More things affect this than just the dataset, but I think it gives us a good feel for memory usage. Also, I didn't want to figure out how to reset the psutil measurement within the same script, so between runs I just commented out the code I didn't care about in order to get the memory estimates.

impl                   additional memory used for packing   total additional memory used
old (all offline)      27 GB                                27 GB
new (mask on the fly)  0.62 GB                              1.14 GB

Our calculation looks pretty spot on for how much memory the new implementation should take. And it makes sense that there would be a little more memory used when the mask is constructed during dataloading.

impl                   time for packing
old (all offline)      134 s
new (mask on the fly)  22 s

Not surprising that the old implementation, which builds every mask offline, takes much longer to pack than the new one.

  3. Compare iterations e2e during training

Why do we need to do this? Well, the "loading" comparison above is not a true measurement of how packing a dataset will affect training time. For instance, we are now passing in a pre-constructed attention mask instead of relying on SDPA to construct one for us.

CMD:

tune run lora_finetune_single_device \
    --config llama3/8B_lora_single_device \
    dataset._component_=torchtune.datasets.instruct_dataset \
    dataset.source=TIGER-LAB/WebInstructSub \
    template=torchtune.data.AlpacaInstructTemplate \
    column_map={"instruction":"question","output":"answer"} \
    max_seq_len=4096 \
    packed=True \
    split=train[:1%]
impl                   total time
old (all offline)      39 mins
new (mask on the fly)  48 mins
no packing             1 hr 58 mins

YAY, it's (kinda) just as fast as the old implementation with waaaaaay less memory.


pytorch-bot bot commented Jun 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1109


✅ No Failures

As of commit cdf5cdf with merge base abe798d:
💚 Looks good so far! There are no failures yet. 💚


@facebook-github-bot added the CLA Signed label Jun 21, 2024
@ScottHoang

Hi Joe,
Thank you for your work! I have been trying to do something similar.
I'm just reading through your changes. Is the purpose of 'on-the-fly' packing to reduce overall memory overhead by generating the attn-mask on the fly instead of during _add_pack?

@codecov-commenter

codecov-commenter commented Jun 24, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 66.72%. Comparing base (abe798d) to head (60d19ab).
Report is 5 commits behind head on main.

Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1109       +/-   ##
===========================================
+ Coverage   26.67%   66.72%   +40.04%     
===========================================
  Files         183      184        +1     
  Lines        8337     8586      +249     
===========================================
+ Hits         2224     5729     +3505     
+ Misses       6113     2857     -3256     


@joecummings changed the title from "[WIP] Add 'on-the-fly' packing" to "Add 'on-the-fly' sample packing" Jun 24, 2024
@joecummings marked this pull request as ready for review June 24, 2024 21:27
@joecummings linked an issue Jun 24, 2024 that may be closed by this pull request
@joecummings (Contributor Author)

Hi Joe, Thank you for your work! I have been trying to do something similar. I'm just reading through your changes. Is the purpose of 'on-the-fly' packing to reduce overall memory overhead by generating the attn-mask on the fly instead of during _add_pack?

Yep, constructing the mask during training reduces memory by about 99% and only slightly slows down processing.

@RdoubleA (Contributor) left a comment

constructing the mask during training reduces memory by about 99%

Incredible work. After seeing this I don't know what else to say on this awesome PR. Only have one blocking comment.

Also, how difficult would it be to also move the input_pos creation to getitem as well? Probably not as significant of a memory saver as the mask, but might still be worthwhile

@@ -71,7 +73,7 @@ class on packed samples with a ``Sampler`` as part of the dataloader. Currently,
ds (Dataset): dataset to sample pack. This should return a dictionary with field
"tokens" and "labels" containing the tokenized and label samples.
max_seq_len (int): Maximum number of tokens to pack
max_packs (Optional[int]): maximum number of packs. Default is None, which will create as many
max_packs (Optional[int]): Maximum number of packs. Default is None, which will create as many
Contributor

grammar police 🚓 🚓 🚓

# Handle the last pack if there's leftover and we haven't filled up the max packs
if (
    len(current_pack["tokens"]) > 0
    and self.max_packs is not None
Contributor

If self.max_packs is None and we have a leftover pack, then this condition won't be true when we would want it to be, hence the suggestion of self.max_packs is None or ...

Contributor Author

Ahh good catch - let me also make sure this is caught in a test.
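
For concreteness, the suggestion amounts to something like the following (a hedged sketch; self.packs and _add_pack are assumed names for the pack list and the helper mentioned earlier in this thread):

# Add the leftover pack whenever there is one, unless a max_packs cap is set
# and has already been reached.
if len(current_pack["tokens"]) > 0 and (
    self.max_packs is None or len(self.packs) < self.max_packs
):
    self._add_pack(current_pack)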

"tokens": torch.tensor(pack["tokens"]),
"labels": torch.tensor(pack["labels"]),
"input_pos": torch.tensor(pack["input_pos"]),
"seq_lens": pack["seq_lens"],
Contributor

should we just standardize on torch.tensor here, or would that make the enumerate in getitem more messy?

Contributor Author

Ehh, it just takes up more memory and isn't necessary.

Contributor

I am inclined to agree with @RdoubleA here. Is the memory increase just tensor metadata? If so should be pretty small, right? Advantage is that we can then scrap this whole PACK_TYPE business. Either way not a blocker

@joecummings (Contributor Author)

Also, how difficult would it be to also move the input_pos creation to getitem as well? Probably not as significant of a memory saver as the mask, but might still be worthwhile

Great question! I think I actually could do this with just the seq_lens information, but it would entail generating and concatenating multiple tensors during getitem, which would slow down processing and only save a little bit of memory. I could do some tests to confirm the tradeoff though.
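
For illustration, a hedged sketch of what that could look like (a hypothetical helper, not part of this PR; how positions are assigned to the padded tail is just one possible choice):

import torch

def input_pos_from_seq_lens(seq_lens, max_seq_len):
    # One position range per packed sample, restarting at 0 for each sample
    positions = [torch.arange(n) for n in seq_lens]
    total = sum(seq_lens)
    if total < max_seq_len:
        # Keep counting through the padded tail (an arbitrary choice)
        positions.append(torch.arange(total, max_seq_len))
    return torch.cat(positions)

# e.g. seq_lens=[3, 2], max_seq_len=8 -> tensor([0, 1, 2, 0, 1, 5, 6, 7])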

@@ -91,8 +111,13 @@ def test_packed_dataset(
max_packs=max_packs,
split_across_pack=split_across_pack,
Contributor

do you want to add in a padding idx here just to test it?

Comment on lines +269 to +282
block_attn_masks.append(
    torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
)

# If we're at the last sample and the total seq len is less than the max seq len,
# we need to pad with identity matrix for the remainder
if i == num_samples_in_pack - 1 and total_seq_len < self.max_seq_len:
    block_attn_masks.append(
        torch.eye(
            self.max_seq_len - total_seq_len,
            self.max_seq_len - total_seq_len,
            dtype=torch.bool,
        )
    )
Contributor

Not a major blocker for this PR, but just fyi it is probably faster to do this in a list comprehension instead of append. See e.g. here. If number of samples per pack is generally small prob not a huge deal, but we will be doing this for every sample so it may add up

Contributor Author

Filed #1130 for a follow-up

padded_labels = F.pad(
    pack["labels"],
    (0, self.max_seq_len - len(pack["labels"])),
    value=CROSS_ENTROPY_IGNORE_IDX,
Contributor

It's weird that we pass padding_idx but still have CROSS_ENTROPY_IGNORE_IDX as a global constant. (I'm just complaining, don't worry about it here)

Contributor Author

Yeah I dislike this as well b/c it means that we don't actually support losses that do NOT use this ignore index OOTB. BUT, we make this assumption throughout the repo so there's no point in deviating here in this PR.

@joecummings merged commit 3d1507c into pytorch:main Jun 27, 2024
29 checks passed
@joecummings deleted the pack-mask-on-the-fly branch June 27, 2024 18:18
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Successfully merging this pull request may close: Investigate possible memory leak with sample packing
6 participants