
HSDP + set_optimizer_state_dict errors with monolithic checkpointing #128444

Closed · mvpatel2000 opened this issue Jun 11, 2024 · 4 comments
Comments

@mvpatel2000 (Contributor) commented Jun 11, 2024

🐛 Describe the bug

Versions: PyTorch 2.3.1 (and nightly)

TL;DR: HSDP + DCP monolithic checkpointing breaks.
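Rough repro sketch for context (untested; assumes a single node with 4 GPUs launched via torchrun, and my guess at the StateDictOptions that correspond to "monolithic" checkpointing):

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy


def main() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # 4 GPUs -> 2 replica groups x 2 shard groups, i.e. HSDP.
    mesh = init_device_mesh("cuda", (2, 2))
    model = FSDP(
        nn.Linear(8, 8).cuda(),
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        device_mesh=mesh,
    )
    optim = torch.optim.Adam(model.parameters())
    model(torch.randn(4, 8, device="cuda")).sum().backward()
    optim.step()

    # Monolithic (full) optimizer state dict, materialized on global rank 0 only.
    full_osd = get_optimizer_state_dict(
        model, optim, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
    if rank != 0:
        full_osd = {}

    # Loading it back broadcasts from global rank 0, but the rank check inside
    # FSDP uses the shard-group rank, so local rank 0 of each shard group errors.
    set_optimizer_state_dict(
        model,
        optim,
        optim_state_dict=full_osd,
        options=StateDictOptions(
            full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True
        ),
    )


if __name__ == "__main__":
    main()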

When using DCP + set_optimizer_state_dict, the code eventually reaches _optim_state_dict_to_load_impl (here), which calls _flatten_optim_state_dict and, in turn, _broadcast_processed_state, defined here:

def _broadcast_processed_state(
    fsdp_state: _FSDPState,
    optim_state: Dict[str, Any],
    group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
    objects: List[Any] = [None]
    if fsdp_state.rank == 0:
        objects[0] = tree_map_only(
            torch.Tensor,
            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),  # type: ignore[union-attr]
            optim_state,
        )
    dist.broadcast_object_list(objects, src=0, group=group)
    if fsdp_state.rank == 0:
        return optim_state
    else:
        return objects[0]

This broadcast occurs across the specified group, which is None, i.e. the global process group (set_optimizer_state_dict has no way of specifying the group). However, when using HSDP, fsdp_state.rank is the local rank within the shard group, not the global rank. Thus, local rank 0 of each shard group errors out, because it returns optim_state, which was emptied out in _optim_state_dict_to_load_impl here:

        if rank0_only and dist.get_rank(group) > 0:
            optim_state_dict = {}

I think the correct solution is to check the rank within the same group that is used for the broadcast, but I'm not 100% sure.
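For illustration, the change I have in mind is roughly the following (sketch only; it may not match #128446 exactly): key the source-rank check off the same group that the broadcast uses, rather than off fsdp_state.rank.

def _broadcast_processed_state(
    fsdp_state: _FSDPState,
    optim_state: Dict[str, Any],
    group: Optional[dist.ProcessGroup],
) -> Dict[str, Any]:
    objects: List[Any] = [None]
    # Source rank of the broadcast group (global rank 0 when group is None),
    # not fsdp_state.rank, which is the shard-group rank under HSDP.
    is_src = dist.get_rank(group) == 0
    if is_src:
        objects[0] = tree_map_only(
            torch.Tensor,
            lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype),  # type: ignore[union-attr]
            optim_state,
        )
    dist.broadcast_object_list(objects, src=0, group=group)
    if is_src:
        return optim_state
    else:
        return objects[0]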

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k @LucasLLC @MeetVadakkanchery

@mvpatel2000 (Contributor, Author) commented:

@pytorchbot label "oncall: distributed"

@pytorch-bot added the "oncall: distributed" label Jun 11, 2024
@mvpatel2000 (Contributor, Author) commented:

I think #128446 fixes it, but I'm not 100% sure.

@mvpatel2000 (Contributor, Author) commented:

I would need help writing a unit test... not sure what to do or where to add it.

CC: @fegin @LucasLLC @wz337 maybe?
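If it helps, I was picturing something vaguely like the skeleton below (the harness, class name, and file placement, e.g. test/distributed/checkpoint/test_state_dict.py, are all guesses on my part), with the repro sketch from the issue body as the test body:

# Rough skeleton only; names and placement are guesses.
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest


class TestHSDPSetOptimizerStateDict(FSDPTest):
    @property
    def world_size(self) -> int:
        return 4  # 2 replica groups x 2 shard groups

    @skip_if_lt_x_gpu(4)
    def test_set_optimizer_state_dict_hsdp_rank0_only(self):
        # Wrap a small model with ShardingStrategy.HYBRID_SHARD over a (2, 2)
        # device mesh, step the optimizer, gather a full optimizer state dict
        # on global rank 0, and load it back via set_optimizer_state_dict with
        # full_state_dict=True (plus cpu_offload / broadcast_from_rank0).
        ...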

@weifengpy (Contributor) commented:

I am oncall this week. To @fegin: this is HSDP + DSD; both components are very relevant to you. WDYT? I see you are also tagged in the solution PR, #128446.

@weifengpy added the "release notes: distributed (checkpoint)", "oncall: distributed checkpointing", and "triaged" labels Jun 13, 2024
@fegin self-assigned this Jun 13, 2024
PaliC pushed a commit that referenced this issue Jun 17, 2024
Fixes #128444. Rank 0 check should be in the same group as the broadcast

Pull Request resolved: #128446
Approved by: https://github.com/fegin
fegin pushed a commit that referenced this issue Jun 21, 2024
Fixes #128444. Rank 0 check should be in the same group as the broadcast

Pull Request resolved: #128446
Approved by: https://github.com/fegin

(cherry picked from commit 153362f)
atalman pushed a commit that referenced this issue Jun 26, 2024
Fixes #128444. Rank 0 check should be in the same group as the broadcast

Pull Request resolved: #128446
Approved by: https://github.com/fegin

(cherry picked from commit 153362f)

Co-authored-by: Mihir Patel <[email protected]>