
[pipelining] enable inputs for all model stages #128115

Open

wants to merge 1 commit into main
Conversation

@alexlan137 commented Jun 6, 2024

Overview

torch.distributed.pipelining / PiPPy currently does not support inputs for non-zero model stages (see #127136). I don't think the current behavior is ideal, because it fails for a variety of models. For example, models that take 'labels' as target values for computing the loss cannot be pipelined today: PiPPy splits such models so that the placeholder for 'labels' lands in the last submodule (on the last device), and no mechanism currently exists to feed values to non-zero model stages. Here I propose a quick-and-easy change to enable this functionality. To the best of my knowledge, it should not interfere with any existing workflows.

Test Script

I revised the pippy_gpt2.py test case in my fork of the pytorch/PiPPy package to test the pipeline when it accepts inputs at non-zero model stages: alexlan137/PiPPy/examples/huggingface/pippy_gpt2.py
Only two modifications were made to the original pippy_gpt2 script:

  1. I set include_loss_args to True in generate_inputs_for_model to generate the 'labels' input alongside 'input_ids'.
  2. I modified the schedule.step() condition to pass the 'labels' input at the last (args.chunks - 1) stage; a rough sketch of this call is shown after this list.
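
For reference, a minimal sketch of the modified driver loop; the rank condition and variable names are my approximation of the example script, not the verbatim change, and the snippet assumes the distributed setup and schedule construction from pippy_gpt2.py:

# Approximate sketch of the modified step calls in pippy_gpt2.py.
# With this change, the last stage can receive its own direct input ('labels')
# in addition to the activations it receives from the previous stage.
if rank == 0:
    schedule.step(inputs["input_ids"])       # first stage: regular model inputs
elif rank == world_size - 1:
    out = schedule.step(inputs["labels"])    # last stage: pass the new 'labels' input
else:
    schedule.step()                          # intermediate stages: receive-only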

Modifications

Change 1

The first bug encountered when running the modified test script is AssertionError: Expected _RecvInfo but got <class 'torch.distributed.pipelining.PipelineStage._RootArgPlaceholder'>. torch.distributed.pipelining currently only supports inputs at the first stage, so the _retrieve_recv_activations method, which is only called when self.is_first is false, assumes that every input to its layers is stored in recv_infos as a _RecvInfo rather than a _RootArgPlaceholder. That assumption no longer holds once inputs can be added to intermediate model stages.

To filter out _RootArgPlaceholder entries in _retrieve_recv_activations:

recv_infos = tuple(info for info in recv_infos if isinstance(info, _RecvInfo))
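
As a tiny standalone illustration of what this filter does (the classes below are local stand-ins, not the real internal torch.distributed.pipelining types):

# Stand-in classes, defined here only to demonstrate the filtering behavior.
class _RecvInfo:             # describes an activation received from the previous stage
    pass

class _RootArgPlaceholder:   # describes a direct user input to this stage (e.g. 'labels')
    pass

recv_infos = (_RecvInfo(), _RootArgPlaceholder(), _RecvInfo())
# Keep only the entries that actually come from the previous stage.
recv_infos = tuple(info for info in recv_infos if isinstance(info, _RecvInfo))
assert len(recv_infos) == 2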

Change 2

The next bug encountered is TypeError: forward() missing 1 required positional argument: 'labels'. Because the 'labels' input is not passed through recv_infos from the previous stage but is a new placeholder, the forward_one_chunk method needs to append the new args to the args received from the previous model stage. This ordering agrees with the pipe.split_gm graph module: as far as I know, the labels placeholder is always the last arg to the submod.

This is accomplished with:

composite_args = self._retrieve_recv_activations() + args
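
For intuition, here is a standalone sketch of the resulting argument ordering, using plain tensors as stand-ins for the received activation and the direct 'labels' input (the variable names are illustrative, not the actual stage internals):

import torch

hidden_states = torch.randn(4, 16)     # stand-in for the activation received from the previous stage
labels = torch.tensor([1, 0, 3, 2])    # stand-in for the direct input passed via schedule.step()

recv_activations = (hidden_states,)    # what self._retrieve_recv_activations() would return
direct_args = (labels,)                # the extra args given directly to this stage
composite_args = recv_activations + direct_args

# The ordering matches pipe.split_gm: the 'labels' placeholder is the last
# positional argument of the submodule, so it must come after the activations.
assert composite_args[0] is hidden_states and composite_args[1] is labels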

Change 3

The final error occurs in the last stage when it returns self._merge_outputs(self._stage.output_chunks). This merge function, which combines outputs across microbatches, requires every tensor to have dim() >= 1. Because the model outputs the loss as a tensor with dim() == 0, the output cannot be merged as originally designed. To fix this, before merging:

if self._output_merge_spec is None:
    for chunk in self._stage.output_chunks:
        for tensor in chunk:
            if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
                tensor.unsqueeze_(0)

This solution is slightly hacky, but it works and will not affect any existing workflows: without a merge_spec, any output containing dim-0 tensors (scalars) would already fail in merge_chunks, since scalars cannot be concatenated.
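
A standalone example of that concatenation failure and of why the unsqueeze fixes it (assuming merge_chunks concatenates the corresponding output tensors across chunks, as described above):

import torch

loss_chunk_0 = torch.tensor(1.25)   # dim() == 0, as a model's loss output typically is
loss_chunk_1 = torch.tensor(0.97)

# torch.cat refuses zero-dimensional tensors, which is the failure seen in _merge_outputs:
#   RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
try:
    torch.cat([loss_chunk_0, loss_chunk_1])
except RuntimeError as e:
    print(e)

# After unsqueezing each chunk's loss to shape (1,), concatenation works as expected:
merged = torch.cat([loss_chunk_0.unsqueeze(0), loss_chunk_1.unsqueeze(0)])
print(merged.shape)  # torch.Size([2])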

Both the unmodified original and the new test script pass with these three changes.

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang

@pytorch-bot bot added the oncall: distributed and release notes: distributed (pipeline) labels Jun 6, 2024
pytorch-bot bot commented Jun 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128115

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 New Failures

As of commit 692bf52 with merge base 68eb771:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@janeyx99 requested a review from kwen2501 June 6, 2024 17:15
@janeyx99 added the triaged label Jun 6, 2024
@alexlan137 (Author) commented

@kwen2501 Hey! Just wanted to check if you were able to take a look at this?

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions bot added the Stale label Aug 16, 2024