HSDP + set_optimizer_state_dict errors with monolithic checkpointing #128444
Labels: module: distributed_checkpoint, oncall: distributed checkpointing, oncall: distributed, release notes: distributed (checkpoint), triaged
Comments
@pytorchbot label "oncall: distributed"
I think this fixes it... #128446 but not 100% sure
PaliC pushed a commit that referenced this issue on Jun 17, 2024:
Fixes #128444. Rank 0 check should be in the same group as the broadcast. Pull Request resolved: #128446. Approved by: https://github.com/fegin
fegin pushed a commit that referenced this issue on Jun 21, 2024:
Fixes #128444. Rank 0 check should be in the same group as the broadcast. Pull Request resolved: #128446. Approved by: https://github.com/fegin (cherry picked from commit 153362f)
atalman pushed a commit that referenced this issue on Jun 26, 2024:
Fixes #128444. Rank 0 check should be in the same group as the broadcast. Pull Request resolved: #128446. Approved by: https://github.com/fegin (cherry picked from commit 153362f) Co-authored-by: Mihir Patel <[email protected]>
🐛 Describe the bug

Versions: PyTorch 2.3.1 (and nightly)

TLDR: HSDP + DCP monolithic checkpointing breaks.
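A minimal repro along these lines should reach the failing path (the model, the 2x2 device mesh, and the launch command are assumptions; any HSDP layout where shard-group ranks differ from global ranks, loaded with `full_state_dict=True` and `broadcast_from_rank0=True`, should behave the same):

```python
# Sketch of a repro; run with: torchrun --nproc_per_node=4 repro.py (assumed 4-GPU layout)
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank % torch.cuda.device_count())

# 2 replica groups x 2-way sharding: shard-group ranks differ from global
# ranks, which is the condition under which the bug shows up.
mesh = init_device_mesh("cuda", (2, 2))
model = FSDP(
    torch.nn.Linear(8, 8).cuda(),
    device_mesh=mesh,
    sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
optim = torch.optim.Adam(model.parameters())
model(torch.randn(4, 8, device="cuda")).sum().backward()
optim.step()  # populate optimizer state

# get_optimizer_state_dict is collective, so every rank calls it.
get_opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
full_osd = get_optimizer_state_dict(model, optim, options=get_opts)
if rank != 0:
    full_osd = {}  # simulate a monolithic checkpoint loaded on rank 0 only

# Errors on local rank 0 of each non-zero shard group, as described below.
set_opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
set_optimizer_state_dict(model, optim, optim_state_dict=full_osd, options=set_opts)
```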
When using DCP + `set_optimizer_state_dict`, the code eventually goes to `_optim_state_dict_to_load_impl` here, which calls `_flatten_optim_state_dict` and, in turn, `_broadcast_processed_state`, defined here. This broadcast occurs across the specified group, which is `None` -> the global process group (as `set_optimizer_state_dict` has no way of specifying the group). However, when using HSDP, `fsdp_state.rank` is the local rank within the shard group, not the global rank. Thus, local rank 0 of each shard group will error out, as it returns `optim_state`, which is set to `None` in `_optim_state_dict_to_load_impl` here. I think the correct solution is to check the rank in the same group as the broadcast, but I'm not 100% sure.
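For context, here is a minimal sketch of the pattern at fault, paraphrasing `_broadcast_processed_state` (the function body and name are assumptions simplified for illustration; the real code broadcasts tensor metadata rather than the full state):

```python
import torch.distributed as dist

def broadcast_processed_state_sketch(optim_state, group=None):
    # Paraphrase of _broadcast_processed_state. Buggy check (roughly):
    #   is_src = fsdp_state.rank == 0
    # Under HSDP, fsdp_state.rank is the rank *within the shard group*, so
    # local rank 0 of every shard group believes it is the source and later
    # returns its own optim_state -- which is None on all ranks except
    # global rank 0.
    #
    # Fix: take the rank from the same group the broadcast uses, so only
    # the true source rank (global rank 0 when group=None) acts as source.
    is_src = dist.get_rank(group=group) == 0
    objects = [optim_state if is_src else None]
    dist.broadcast_object_list(objects, src=0, group=group)
    return optim_state if is_src else objects[0]
```

This matches the commit message of #128446: the rank 0 check moves into the same group as the broadcast.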
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC @MeetVadakkanchery