
[FSDP] Move the flattened tensors back to GPU to prevent CPU OOM #124008

Closed

Conversation

@exhyy commented Apr 13, 2024

I encountered a CPU OOM issue when resuming from a checkpoint with FSDP.optim_state_dict_to_load. See https://github.com/huggingface/accelerate/blob/5ca095a34fede7c988af8c193eb0c0d199750845/src/accelerate/utils/fsdp_utils.py#L208 for details.

I noticed that _flatten_tensor_optim_state and _flatten_zero_dim_tensor_optim_state create tensors on CPU, which leads to the OOM. I tried creating the tensors directly on GPU instead, but that caused an OOM on the GPU.

My solution is to move the flattened tensors back to the device FSDP is running on. This works well for me with PyTorch 2.1.1. The current main branch does not appear to have fixed this issue yet.
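For illustration, the idea can be sketched as a small helper that walks a flattened optimizer state dict and moves its tensor values to the target device. This is a minimal sketch, not the PR's actual diff; the function name and structure are hypothetical, and it assumes the state dict is a (possibly nested) dict of tensors and scalars.

```python
import torch


def move_flat_optim_state_to_device(flat_state, device):
    """Hypothetical helper: recursively move tensor values in a
    flattened optimizer state dict to `device` (e.g. the GPU that
    FSDP is running on), so the flattened copies do not pile up
    in host memory."""
    moved = {}
    for key, value in flat_state.items():
        if isinstance(value, torch.Tensor):
            moved[key] = value.to(device)
        elif isinstance(value, dict):
            moved[key] = move_flat_optim_state_to_device(value, device)
        else:
            # Non-tensor entries (e.g. step counters) are kept as-is.
            moved[key] = value
    return moved


# Usage sketch: move a toy state dict to a chosen device.
target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = {"step": 3, "exp_avg": torch.zeros(4),
         "nested": {"exp_avg_sq": torch.ones(2)}}
state = move_flat_optim_state_to_device(state, target)
```

Whether this trades CPU OOM for higher GPU pressure depends on how long the moved copies stay alive, which is the open question in the review below.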

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @rohan-varma

pytorch-bot bot commented Apr 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124008

Note: links to docs will display an error until the docs builds have completed.

✅ No Failures

As of commit 60ba71a with merge base da7db5d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

linux-foundation-easycla bot commented Apr 13, 2024

CLA Signed


The committers listed above are authorized under a signed CLA.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Apr 13, 2024
@zou3519 zou3519 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Apr 15, 2024
@zou3519 zou3519 requested a review from awgu April 15, 2024 15:22
@awgu (Contributor) commented Apr 15, 2024

cc: @fegin since you are more familiar with this code

I am not fully convinced by this solution though 🤔 I may need to understand it more.

@exhyy (Author) commented Apr 20, 2024

Sorry for the late reply.
I am not fully convinced by this solution either. My main concern is that I am unsure whether it leads to an increase in peak GPU memory.
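The peak-memory concern above can be checked empirically. As a minimal sketch (the helper name is hypothetical, and it assumes a CUDA build of PyTorch), one can reset CUDA's peak-memory counter, run the state-dict conversion, and read the peak back:

```python
import torch


def peak_gpu_memory_of(fn, device="cuda"):
    """Hypothetical measurement helper: run `fn` and return the peak
    GPU memory it allocated, in bytes. Returns None when CUDA is
    unavailable, so the sketch degrades gracefully on CPU-only hosts."""
    if not torch.cuda.is_available():
        return None
    torch.cuda.reset_peak_memory_stats(device)
    fn()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)


# Usage sketch: compare peak usage with and without moving the
# flattened tensors back to GPU by wrapping each variant in a lambda.
peak = peak_gpu_memory_of(lambda: torch.zeros(1))
if peak is not None:
    print(f"peak allocated: {peak} bytes")
```

If the peak with the fix stays close to the baseline, the flattened copies are likely freed promptly and the trade-off is acceptable.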


Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 19, 2024
@github-actions github-actions bot closed this Jul 19, 2024
Labels
oncall: distributed (add this issue/PR to distributed oncall triage queue) · open source · release notes: distributed (fsdp) (release notes category) · Stale · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

4 participants