[RFC][Pipelining] Support separate dW/dInput in Schedule and Stage #128974

Closed
wconstab opened this issue Jun 18, 2024 · 1 comment

wconstab commented Jun 18, 2024

FB vs FBW schedules

Traditionally, PP schedules only schedule 'forward' and 'backward' steps (call these FB schedules). More recently, the Zero Bubble schedules have shown that separating backward into 'dInput' and 'dWeight' portions allows finer-grained scheduling and a reduction of pipeline bubbles (call these FBW schedules), as illustrated below.
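
As a rough illustration (the step ordering below is made up and not taken from any particular schedule), an FB schedule for one stage interleaves only F and B steps, while an FBW schedule splits each B so the W steps can be moved later to fill bubbles:

```python
# Illustrative per-stage step sequences for 4 microbatches (not a real schedule):
fb_schedule  = ["F0", "F1", "F2", "F3", "B0", "B1", "B2", "B3"]

# FBW: each B is split into I (dInput) and W (dWeight). The W steps have no
# downstream dependency within the pipeline, so a zero-bubble schedule is free
# to push them later to fill what would otherwise be bubbles.
fbw_schedule = ["F0", "F1", "F2", "F3", "I0", "I1", "W0", "I2", "W1", "I3", "W2", "W3"]
```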

Custom Backwards vs Autograd

There are two ways to achieve dW / dInput separation.

  1. Write a custom backward for your whole model. In it, execute only the dInput computations during backward. Create lambda functions for any dW computations and put them into a container. Create another function that runs every lambda in the container; call this the 'dW runner' function. Pass the dW runner to the pipeline runtime and let it be called whenever the schedule has a 'W' step.
  2. Allow torch autograd to partition the backward for your model by determining the set of 'W' nodes (the parameters in your model) and delaying their gradient computation. See [autograd] Support GradientEdge as output for torch.autograd.grad #127766 for a prototype of the needed functionality, which would be extended by the pipelining runtime. Rough sketches of both approaches are shown after this list.
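
A minimal sketch of approach 1, assuming a single linear layer and a plain Python list as the container (the names `LinearWithDeferredW` and `make_dw_runner` are made up for illustration, not existing APIs):

```python
import torch

class LinearWithDeferredW(torch.autograd.Function):
    """Forward is a plain matmul; backward computes dInput eagerly and
    defers the dW computation into a closure appended to dw_container."""

    @staticmethod
    def forward(ctx, x, weight, dw_container):
        ctx.save_for_backward(x, weight)
        ctx.dw_container = dw_container
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_input = grad_out @ weight  # dInput: needed immediately by the upstream stage

        def dw():
            # dW: run later, whenever the schedule reaches a 'W' step
            with torch.no_grad():
                g = grad_out.t() @ x
                weight.grad = g if weight.grad is None else weight.grad + g

        ctx.dw_container.append(dw)
        # autograd sees no gradient for weight; the dW runner fills weight.grad later
        return grad_input, None, None


def make_dw_runner(dw_container):
    """The 'dW runner': executes every deferred dW closure collected so far."""
    def run():
        for fn in dw_container:
            fn()
        dw_container.clear()
    return run
```

Approach 2 can be approximated today with two torch.autograd.grad calls (this retains and re-traverses the graph, so it only illustrates the split; the GradientEdge support prototyped in #127766 is what would make it efficient):

```python
import torch

lin = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
loss = lin(x).sum()

# dInput now, keeping the graph alive so the weight gradients can be taken later.
(dinput,) = torch.autograd.grad(loss, inputs=(x,), retain_graph=True)

# dW later, at a 'W' step in the schedule.
dweights = torch.autograd.grad(loss, inputs=tuple(lin.parameters()))
```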

Requirements

  1. Stage backward computation
    a. Must be able to compute 'whole backward' for FB schedules
    b. Must be able to compute separate dW / dI for FBW schedules

  2. Stage configuration
    a. When using a custom backward and a user-provided 'dW runner' function, the stage must accept the dW runner function
    b. If using a model without a custom backward, the stage should use torch.autograd.grad to separate the dW from the dI

    • users should be able to pass in a list of FQNs to include in dW if they do not want to include all weight gradients in dW, to customize the automatic splitting (a hypothetical sketch follows this list)
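
For example (hypothetical constructor and kwarg names, only to show the intent of the FQN filter):

```python
# Only these parameters' gradients are deferred into the 'W' step; all other
# gradients are computed together with dInput during the 'B' step.
stage = PipelineStage(            # hypothetical constructor, for illustration only
    model_chunk,
    stage_index=0,
    num_stages=4,
    dw_param_fqns=[               # hypothetical kwarg implementing the FQN list above
        "layers.0.attention.wq.weight",
        "layers.0.feed_forward.w1.weight",
    ],
)
```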

Open Questions

  • if a stage is configured for FBW compatibility, should it also be usable by an FB schedule?
  • if a stage is configured the simplest way, and we have support for autograd to automatically split FBW, should the stage automatically work with both FB and FBW schedules?
  • Should users pass the custom backward when creating stage objects, or when creating a schedule and then have the schedule pass it to the stage at runtime?
  • how should we change the stage.backward_one_chunk() API so that it can either run the 'whole backward' or 'just dInput' in the right situations?
  • Is it a design goal that a Stage can be constructed in a way that the construction code can be left alone while the schedule to be used is changed (e.g. by a runtime flag)? In other words, should we require that a Stage object constructed one way be compatible with all schedule types? An FBW-compatible stage would be a 'superset' in this case.

Proposal

  1. Stage backward computation
  • stage.backward_one_chunk() gets a new kwarg, full_backward: bool = True
    • it defaults to True, meaning that existing FB schedules do not need to change
    • when passed as False, it would have different behavior depending on how the stage is configured
      • with a custom backward: the model's backward would be run as-is (by construction it computes only dInput), and the user-provided 'dW runner' would be saved for later
      • without a custom backward: the backward would be run in such a way as to defer dW, and a dW runner would be created and saved for later
  • stage.backward_weight_one_chunk() is a new API. Its semantics are that it can only be called after a call to stage.backward_one_chunk(full_backward=False)
  • FBW schedules must call the appropriate sequence of B and W API calls
  • If an FB schedule is used with a stage that has FBW capability, stage.run_backward()'s default behavior would be to run model.backward() and dw_runner() at the same time, making the FBW-compatible stage behave like an FB stage
  2. Stage configuration
  • the pipeline stage constructor would accept an optional kwarg 'dw_runner' along with the model chunk. If provided, the stage would be 'FBW compatible' as described above
  • in the future, we may extend the pipeline stage to support autograd-based dW. In that case, if a dw_runner is provided, autograd-based dw_runner creation would be skipped and the provided dw_runner takes precedence. If no dw_runner is provided, one would be created automatically so the stage would be FBW compatible. A rough usage sketch follows.
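
A rough usage sketch of the proposed surface (only full_backward, backward_weight_one_chunk(), and the dw_runner kwarg come from the proposal above; the dW-runner construction and the remaining constructor arguments are illustrative placeholders):

```python
# Stage construction: passing a dw_runner makes the stage 'FBW compatible'.
dw_calls = []                          # filled in by the model's custom backward
dw_runner = make_dw_runner(dw_calls)   # see the approach-1 sketch earlier

stage = PipelineStage(model_chunk, stage_index, num_stages, device,
                      dw_runner=dw_runner)  # proposed kwarg

# FB schedule (unchanged): one call runs the whole backward for a microbatch.
stage.backward_one_chunk(mb_index)                       # full_backward defaults to True

# FBW schedule: a 'B' step followed, possibly much later, by a 'W' step.
stage.backward_one_chunk(mb_index, full_backward=False)  # dInput only; dW work is stashed
# ... other F/B/W steps for other microbatches ...
stage.backward_weight_one_chunk(mb_index)                # proposed API: run the deferred dW
```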

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k

@wconstab added the oncall: distributed and module: pipelining labels Jun 18, 2024
@wconstab self-assigned this Jun 18, 2024

H-Huang commented Jun 18, 2024

Good writeup! My opinions on the open questions:

if a stage is configured for FBW compatibility, should it also be usable by an FB schedule?

To start, no, since that would be simpler; we can just error. However, once we have native autograd support to split FBW, we can relax this because FB schedules can then just be defined as FBW schedules.

if a stage is configured the simplest way, and we have support for autograd to automatically split FBW, should the stage automatically work with both FB and FBW schedules?

I think if we have autograd support, we should do a refactor to define the FB schedules as FBW schedules (basically replace all B with BW). Then all schedules would be FBW schedules and would also share the same runtime. Hopefully that comes at no cost to performance.

Should users pass the custom backward when creating stage objects, or when creating a schedule and then have the schedule pass it to the stage at runtime?

I think we should just keep the custom backward at the stage level; I'm not clear on why we need to pass it at the schedule level. If it is too difficult for _step_microbatches to share implementations between FB and FBW schedules, then I think we can just have a different implementation of _step_microbatches.


I have a question on why we need to add stage.backward_weight_one_chunk(). How come we can't just update stage.backward_one_chunk() to take "input" or "weight" as an argument? This seems like it would be more in line with the autograd changes once we have them.
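
(For concreteness, a sketch of that alternative; the argument name and values here are made up:)

```python
# Single API with a mode argument instead of a separate backward_weight_one_chunk():
stage.backward_one_chunk(mb_index)                      # FB: full backward
stage.backward_one_chunk(mb_index, backward="input")    # FBW 'B' step: dInput only
stage.backward_one_chunk(mb_index, backward="weight")   # FBW 'W' step: deferred dW
```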

@wz337 added the triaged label Jun 18, 2024
wconstab added a commit that referenced this issue Jun 18, 2024
Fixes #128974

ghstack-source-id: 293da8fc635b2621e15dd32fdf307bd857025f92
Pull Request resolved: #128983