[RFC][Pipelining] Support separate dW/dInput in Schedule and Stage #128974
Good writeup! My opinions on the open questions:
To start, no, since erroring would be simpler. However, once we have native autograd support to split F/B/W, we can relax this, because FB schedules can then just be defined as FBW schedules.
I think if we have autograd support, we should do a refactor to define the FB schedules as FBW schedules (basically replace all B with BW). Then all schedules would be FBW schedules and also share the same runtime? Hopefully that should be at no cost to performance.
I think we should just keep the custom backward at the stage level; I'm not clear why we need to pass it to the schedule level.
FB vs FBW schedules
Traditionally, PP schedules only schedule the 'forward' and 'backward' step (call these FB schedules). More recently, the Zero Bubble schedule suggests that separating Backward into 'dInput' and 'dWeight' portions allows for finer-grained scheduling and reduction of pipeline bubbles (call these FBW schedules).
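The distinction above can be sketched as per-microbatch step streams on one stage. This is illustrative only (the step labels and lists below are not the actual schedule representation); the point is that only dInput is on the critical path to the upstream stage, so dWeight steps can be deferred into otherwise-idle slots.

```python
# Illustrative only: step streams for two microbatches on one stage.
# FB schedule: backward ("B") is a single fused step.
fb_stream = ["F1", "F2", "B1", "B2"]

# FBW (zero-bubble style) schedule: backward is split into dInput ("I")
# and dWeight ("W"); "W" steps can be deferred to fill pipeline bubbles,
# since the upstream stage only needs dInput to start its own backward.
fbw_stream = ["F1", "F2", "I1", "I2", "W1", "W2"]

# Every FB step has a counterpart in the FBW stream: B -> I + W.
assert len(fbw_stream) == len(fb_stream) + fb_stream.count("B1") + fb_stream.count("B2")
```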
Custom Backwards vs Autograd
There are two ways to achieve the dW / dInput separation: a user-provided custom backward, or splitting the backward automatically via autograd.
Requirements
1. Stage backward computation
   a. Must be able to compute the whole backward for FB schedules
   b. Must be able to compute separate dW / dInput for FBW schedules
2. Stage configuration
   a. When the user provides a custom backward and a dW runner function, the stage must accept the dW runner function
   b. For a model without a custom backward, the stage should use `torch.autograd.grad` to separate the dW from the dInput
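Requirement 2b can be sketched with `torch.autograd.grad` directly. The two-phase structure and variable names below are illustrative, not the actual stage code; the key is `retain_graph=True` on the dInput pass so the dWeight pass can run later.

```python
import torch

# Hedged sketch: splitting one backward into dInput and dWeight phases.
w = torch.randn(4, 3, requires_grad=True)   # stage parameter
x = torch.randn(2, 4, requires_grad=True)   # stage input activation
loss = (x @ w).sum()

# dInput phase: gradient w.r.t. the input only; retain_graph=True keeps
# the autograd graph alive so the dWeight phase can run afterwards.
(dx,) = torch.autograd.grad(loss, x, retain_graph=True)

# dWeight phase: can be deferred by the schedule to fill pipeline bubbles.
(dw,) = torch.autograd.grad(loss, w)
```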
Open Questions
Proposal
- `full_backward: bool=True`
- `stage.backward_one_chunk(full_backward=False)`
- `stage.run_backward()`

The default behavior would be to run `model.backward()` and `dw_runner()` at the same time, so the FBW-compatible stage behaves like an FB stage. If a `dw_runner` is provided, autograd-based `dw_runner` creation would be skipped and the provided `dw_runner` takes precedence. If no `dw_runner` is provided, one would be automatically created so the stage would be FBW compatible.
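Under the assumptions above, the proposed stage behavior can be sketched as follows. `full_backward` and `dw_runner` come from this RFC; the `StageSketch` class itself is hypothetical, not the actual `torch.distributed.pipelining` implementation.

```python
import torch

class StageSketch:
    """Hypothetical sketch of the proposed full_backward / dw_runner flow."""

    def __init__(self, dw_runner=None):
        # A user-provided dw_runner takes precedence; otherwise dW is
        # derived via autograd so the stage stays FBW compatible.
        self.dw_runner = dw_runner
        self._pending_dw = None

    def backward_one_chunk(self, loss, inputs, weights, full_backward=True):
        # dInput is always computed here; keep the graph for the dW phase.
        dinputs = torch.autograd.grad(loss, inputs, retain_graph=True)
        if full_backward:
            # FB behavior: run dWeight immediately, in the same step.
            self.run_backward(loss, weights)
        else:
            # FBW behavior: defer dWeight for the schedule to run later.
            self._pending_dw = (loss, weights)
        return dinputs

    def run_backward(self, loss=None, weights=None):
        if self._pending_dw is not None:
            loss, weights = self._pending_dw
            self._pending_dw = None
        if self.dw_runner is not None:
            self.dw_runner()          # custom-backward path takes precedence
        else:
            for p, g in zip(weights, torch.autograd.grad(loss, weights)):
                p.grad = g            # autograd-derived dW path
```

With `full_backward=True` the stage behaves like an FB stage; with `full_backward=False` the schedule decides when `run_backward()` executes the deferred dW step.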
.cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @yf225 @chauhang @d4l3k