[pipelining] enable inputs for all model stages #128115
Overview
torch.distributed.pipelining / PiPPy currently does not support inputs to non-zero model stages (see #127136). I don't think this limitation is desirable, because it breaks a variety of models. For example, models that take 'labels' as target values for computing the loss cannot be pipelined with the current implementation: PiPPy splits these models such that the placeholder for 'labels' ends up in the last submodule (on the last device), and no mechanism currently exists to feed values to non-zero model stages. Here, I propose a quick-and-easy change to enable this functionality. To the best of my knowledge, it does not interfere with any existing workflows.
Test Script
I revised the pippy_gpt2.py test case in my fork of the pytorch/PiPPy package to test the pipeline when accepting inputs at non-zero model stages: alexlan137/PiPPy/examples/huggingface/pippy_gpt2.py

Only 2 modifications were made from the original pippy_gpt2 script:

1. Set include_loss_args to True in generate_inputs_for_model to generate the 'labels' input alongside the 'input_ids'.
2. Add a condition around the schedule.step() call so that the 'labels' input is supplied at the last (args.chunks - 1) stage (sketched below).
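The diff itself isn't reproduced here; the relevant excerpt looks roughly like the following (generate_inputs_for_model comes from the PiPPy Hugging Face example utilities; the exact argument list and rank bookkeeping are illustrative):

```python
# Modification 1: also generate 'labels' so the loss can be computed.
inputs = generate_inputs_for_model(
    model_class, gpt2, model_name, args.batch_size, args.device,
    include_loss_args=True,
)

# Modification 2: feed each input to the stage that owns its placeholder.
# Only the last stage's submodule takes 'labels'.
if rank == 0:
    schedule.step(inputs["input_ids"])
elif rank == args.chunks - 1:
    out = schedule.step(inputs["labels"])
else:
    schedule.step()
```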
Modifications

Change 1
The first bug encountered when running the modified test script is `AssertionError: Expected _RecvInfo but got <class 'torch.distributed.pipelining.PipelineStage._RootArgPlaceholder'>`. torch.distributed.pipelining currently does not accept inputs anywhere except the first stage, so the _retrieve_recv_activations method, which is only called when self.is_first is false, assumes all inputs to its stage are stored in recv_infos as _RecvInfos rather than _RootArgPlaceholders. This no longer holds once inputs can be added to intermediate model stages. To filter out the _RootArgPlaceholders in _retrieve_recv_activations:
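The PR's code isn't shown above, so here is a minimal sketch of the filtering, assuming recv_infos is the tuple of receive descriptors that _retrieve_recv_activations maps to tensors (the surrounding names in PipelineStage may differ):

```python
# Inside PipelineStage._retrieve_recv_activations (sketch):
# _RootArgPlaceholder entries describe inputs supplied directly to this
# stage via schedule.step(), not activations received from a previous
# stage, so drop them before mapping the receives to tensors.
recv_infos = tuple(
    info for info in recv_infos if isinstance(info, _RecvInfo)
)
```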
Change 2
The next bug encountered is `TypeError: forward() missing 1 required positional argument: 'labels'`. Because the 'labels' input is not passed through recv_infos from a previous stage but is a new placeholder, we need to modify the forward_one_chunk method to append the new args to the args received from the previous model stage. This ordering agrees with the pipe.split_gm graph module: the labels placeholder is always the last arg to the submod (as far as I know). This is accomplished with:
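Again as a sketch rather than the literal diff (composite_args is an illustrative name for the tuple ultimately passed to this stage's submodule):

```python
# Inside PipelineStage.forward_one_chunk (sketch):
if self.is_first:
    # First stage: all inputs come directly from the caller.
    composite_args = args
else:
    # Later stages: start from the activations received from the
    # previous stage, then append any directly supplied inputs (e.g.
    # 'labels'), which pipe.split_gm places last in the submodule's
    # signature.
    composite_args = self._retrieve_recv_activations() + tuple(args)
```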
Change 3
The final error occurs in the last stage when it returns self._merge_outputs(self._stage.output_chunks). This merge function, used to combine outputs across micro-batches, requires all tensors to have dim() >= 1. Because the model outputs the loss as a tensor with dim() == 0, the output cannot be merged as originally designed. To fix this before merging:
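A sketch of the fix, assuming output_chunks is the list of per-chunk outputs handed to the merge; torch.utils._pytree.tree_map is used here for traversal, though the original change may walk the structure differently:

```python
import torch
from torch.utils._pytree import tree_map

def _unsqueeze_scalars(val):
    # Promote 0-dim (scalar) tensors such as the loss to 1-dim so the
    # merge logic can concatenate them across chunks.
    if isinstance(val, torch.Tensor) and val.dim() == 0:
        return val.unsqueeze(0)
    return val

output_chunks = [tree_map(_unsqueeze_scalars, chunk) for chunk in output_chunks]
```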
This solution is slightly hacky, but it works and will not affect any existing workflows: any output without a merge_spec that contains 0-dim (scalar) tensors would already fail in merge_chunks, since scalars cannot be concatenated.
Both the unmodified original and the new test script pass with these three changes.
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang