[FSDP2] Used multi-grad hook when no inputs require grad #129259
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/129259.

Note: links to docs will display an error until the docs builds have completed. ✅ You can merge normally (2 unrelated failures). As of commit a05c807 with merge base 2820e1d: UNSTABLE. The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable.

This comment was automatically generated by Dr. CI and updates every 15 minutes.
```diff
@@ -556,3 +584,35 @@ def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
     def backward(ctx, *grads: torch.Tensor):
         ctx.param_group.post_backward()
         return (None,) + grads
```
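For readers following along, here is a hedged, self-contained sketch of the `torch.autograd.Function` pattern this diff extends (the class name, `callback` argument, and prints are illustrative, not FSDP2's exact internals): tensors are threaded through the function so that its `backward` runs a post-backward callback.

```python
import torch

class RegisterPostBackward(torch.autograd.Function):
    """Illustrative stand-in for the function above: pass tensors through
    unchanged so that backward runs a callback once gradients flow back
    past this point in the graph."""

    @staticmethod
    def forward(ctx, callback, *inputs: torch.Tensor):
        ctx.callback = callback
        return inputs  # tensors pass through unchanged

    @staticmethod
    def backward(ctx, *grads: torch.Tensor):
        ctx.callback()          # e.g. param_group.post_backward()
        return (None,) + grads  # None matches the non-tensor `callback` arg

x = torch.randn(2, requires_grad=True)
(y,) = RegisterPostBackward.apply(lambda: print("post-backward fired"), x)
y.sum().backward()  # the callback fires before x's .grad is accumulated
```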
For context, the existing multi-grad hook does not work so well with the existing `post_backward` because the last parameter's `.grad` is not assigned yet when the multi-grad hook runs (expected behavior). We actually prefer this multi-post-accumulate-grad hook, where the last parameter's `.grad` has already been assigned.
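To make that ordering concrete, here is a hedged toy sketch using the public APIs `torch.autograd.graph.register_multi_grad_hook` and `Tensor.register_post_accumulate_grad_hook` (the parameters and prints are illustrative only, not FSDP2 code):

```python
import torch
from torch.autograd.graph import register_multi_grad_hook

p1 = torch.randn(3, requires_grad=True)
p2 = torch.randn(3, requires_grad=True)

def multi_grad_hook(grads):
    # Fires once gradients for all watched tensors have been computed,
    # but before the last-reached parameter's .grad is accumulated.
    print("multi-grad hook:", p1.grad is not None, p2.grad is not None)

handle = register_multi_grad_hook((p1, p2), multi_grad_hook)

def post_acc_hook(param):
    # Fires only after .grad has been accumulated onto `param`.
    print("post-acc-grad hook:", param.grad is not None)  # True

h1 = p1.register_post_accumulate_grad_hook(post_acc_hook)
h2 = p2.register_post_accumulate_grad_hook(post_acc_hook)

(p1.sum() + p2.sum()).backward()
# Expected: the multi-grad hook prints at least one False; both
# post-acc-grad hook prints show True.
```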
```python
        ]
        self._multi_grad_hook_handle = _register_multi_post_acc_grad_hook(
            tensors, self._multi_grad_post_backward
        )
        return args, kwargs  # no tensors that require gradients
```
The key idea is that before, we just returned when no module inputs require grad, falling back to the final callback that runs at the end of backward to run this module's post-backward. This multi-post-acc-grad hook allows running the module's post-backward earlier (but still correctly); see the sketch below.
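The multi-post-acc-grad hook helper is internal, so the following is a hedged sketch of the underlying idea only: a one-shot version can be built from per-tensor `register_post_accumulate_grad_hook` calls plus a countdown (a real helper would also need to handle reuse across iterations, which this toy omits).

```python
import torch
from typing import Callable, Sequence
from torch.utils.hooks import RemovableHandle

def multi_post_acc_grad_hook_sketch(
    tensors: Sequence[torch.Tensor], fn: Callable[[], None]
) -> list[RemovableHandle]:
    # One-shot: call `fn` once every tensor's .grad has been accumulated.
    remaining = len(tensors)

    def hook(_param: torch.Tensor) -> None:
        nonlocal remaining
        remaining -= 1
        if remaining == 0:
            fn()  # every .grad is assigned by this point

    return [t.register_post_accumulate_grad_hook(hook) for t in tensors]

params = [torch.randn(2, requires_grad=True) for _ in range(3)]
handles = multi_post_acc_grad_hook_sketch(
    params, lambda: print("post-backward: all grads accumulated")
)
sum(p.sum() for p in params).backward()
```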
This is neat!
@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This is very clean! Kudos!
Stack from ghstack (oldest at bottom):
- `register_multi_post_accumulate_grad_hook` #131949

This PR uses `register_multi_post_accumulate_grad_hook` to run the post-backward logic when the module inputs do not require gradient.

For context, FSDP2 currently relies on a hook on the module inputs that require gradient to run the post-backward logic (like a module full backward hook, except with pytree support). This means that if none of the module inputs require gradient, then the post-backward logic is deferred to the final callback that runs at the end of backward, which may not be timely in some hybrid parallelism cases (e.g. a sparse arch before the FSDP dense arch). To address this case, we use the new multi-post-accumulate-grad hook.

Since whether the module inputs require grad can change from iteration to iteration, we guard this on a flag to make sure that we use the existing logic if the module inputs do require gradient.
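As a toy illustration of the case this PR targets (hedged: the module and prints below are illustrative, not FSDP2 code), consider an embedding layer: its integer indices can never require gradient, so an input-based backward hook has nothing to attach to, yet a post-accumulate-grad hook on its parameters still fires during backward.

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
idx = torch.tensor([1, 2, 3])  # integer indices never require grad
print(idx.requires_grad)       # False -> the input-hook path is unusable

emb.weight.register_post_accumulate_grad_hook(
    lambda p: print("embedding post-backward, .grad assigned:", p.grad is not None)
)
emb(idx).sum().backward()  # hook fires here rather than at a final callback
```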
Common Case

In the common case, the only module whose forward inputs do not require gradient is the root FSDP module, where the root FSDP module is the overall model's root. In this case, the multi-grad hook simply moves the root's post-backward from the final callback to a preceding `AccumulateGrad` node.

New: ![Screenshot 2024-06-21 at 3 17 06 PM](https://github.com/pytorch/pytorch/assets/31054793/e06cc6e4-2bba-488b-b15d-1a55c881e40f)

Old: ![Screenshot 2024-06-21 at 3 16 45 PM](https://github.com/pytorch/pytorch/assets/31054793/22eb1bcc-f128-459f-961c-f4f6ded00aab)
cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang
Differential Revision: D59012616