[FSDP2] Used multi-grad hook when no inputs require grad #129259

Open · wants to merge 14 commits into base: gh/awgu/606/base
Conversation

@awgu (Contributor) commented Jun 21, 2024

Stack from ghstack (oldest at bottom):

This PR uses `register_multi_post_accumulate_grad_hook` to run the post-backward logic when the module inputs do not require gradient.

For context, FSDP2 currently relies on a hook on the module inputs that require gradient to run the post-backward logic (like a module full backward hook, except with pytree support). This means that if none of the module inputs require gradient, then the post-backward logic is deferred to the final callback that runs at the end of backward, which may not be timely in some hybrid parallelism cases (e.g. a sparse arch before the FSDP dense arch). To address this case, we use the new multi-post-accumulate-grad hook.

Since whether the module inputs require grad can change from iteration to iteration, we guard this with a flag to make sure that we use the existing logic whenever the module inputs do require gradient.
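As a rough sketch of the mechanism (not this PR's actual implementation, which uses the internal `_register_multi_post_acc_grad_hook` helper), such a hook can be built on top of the public `Tensor.register_post_accumulate_grad_hook` API; the helper name and bookkeeping below are assumptions for illustration only:

```python
import torch

def multi_post_acc_grad_hook_sketch(tensors, fn):
    # Call fn(tensors) once, after the last tensor's .grad has been
    # accumulated in the current backward. Single-backward sketch only;
    # it assumes every tensor in `tensors` receives a gradient.
    remaining = {"count": len(tensors)}
    handles = []

    def hook(param):
        remaining["count"] -= 1
        if remaining["count"] == 0:
            # By now every .grad in `tensors` has been assigned, so
            # post-backward logic that reads .grad is safe to run here.
            fn(tensors)

    for t in tensors:
        # Public API (PyTorch >= 2.1): fires after .grad accumulation on a leaf.
        handles.append(t.register_post_accumulate_grad_hook(hook))
    return handles
```

A typical use would be `multi_post_acc_grad_hook_sketch(list(module.parameters()), post_backward_fn)`; the real helper in this PR may differ in details such as handle management and reuse across iterations.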

<details>
<summary> Common Case </summary>

In the common case, the only module whose forward inputs do not require gradient is the root FSDP module, where the root FSDP module is the overall model's root. In this case, the multi-grad hook simply moves the root's post-backward from the final callback to a preceding `AccumulateGrad` node.

New:
![Screenshot 2024-06-21 at 3 17 06 PM](https://github.com/pytorch/pytorch/assets/31054793/e06cc6e4-2bba-488b-b15d-1a55c881e40f)

Old:
![Screenshot 2024-06-21 at 3 16 45 PM](https://github.com/pytorch/pytorch/assets/31054793/22eb1bcc-f128-459f-961c-f4f6ded00aab)

</details>

cc @XilunWu @H-Huang @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @penguinwu @tianyu-l @yf225 @chauhang

Differential Revision: D59012616

pytorch-bot (bot) commented Jun 21, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129259

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit a05c807 with merge base 2820e1d:

UNSTABLE - The following jobs failed, but the failures were likely due to flakiness present on trunk and have been marked as unstable.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Jun 21, 2024
@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Jun 21, 2024
```diff
@@ -556,3 +584,35 @@ def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
     def backward(ctx, *grads: torch.Tensor):
         ctx.param_group.post_backward()
         return (None,) + grads
```
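For readers without the full file open, the diff context above belongs to the existing input-hook mechanism: an `autograd.Function` threaded through the module inputs so that its `backward` runs the group's post-backward. A minimal sketch, with the class name assumed for illustration:

```python
import torch

class RegisterPostBackwardSketch(torch.autograd.Function):
    """Illustrative sketch of the existing input-hook path (names assumed)."""

    @staticmethod
    def forward(ctx, param_group, *inputs):
        # Stash the parameter group and pass the inputs through unchanged.
        ctx.param_group = param_group
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        # Runs once gradients flow back through the module inputs,
        # i.e. only if at least one input required grad in forward.
        ctx.param_group.post_backward()
        return (None,) + grads
```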

awgu (Contributor Author) commented:

For context, the existing multi-grad hook does not work so well with the existing post_backward because the last parameter's .grad is not assigned yet when the multi-grad hook runs (expected behavior).

We actually prefer this multi-post-accumulate-grad hook, where the last parameter's .grad has already been assigned.
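A toy illustration of the ordering point above, using the public `torch.autograd.graph.register_multi_grad_hook` (the tensors and the print are hypothetical, just to show where `.grad` may still be unset):

```python
import torch
from torch.autograd.graph import register_multi_grad_hook

w1 = torch.randn(4, requires_grad=True)
w2 = torch.randn(4, requires_grad=True)

def hook(grads):
    # The hook receives the computed gradients, but accumulation into the
    # last parameter's .grad may not have happened yet (expected behavior),
    # so post-backward logic that reads .grad cannot rely on it here.
    print("w2.grad assigned yet?", w2.grad is not None)

handle = register_multi_grad_hook((w1, w2), hook)
(w1.sum() + 2 * w2.sum()).backward()
handle.remove()
```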

```python
]
self._multi_grad_hook_handle = _register_multi_post_acc_grad_hook(
    tensors, self._multi_grad_post_backward
)
return args, kwargs  # no tensors that require gradients
```
awgu (Contributor Author) commented:

The key idea is that before, we would just return when there were no module inputs requiring grad, falling back to the final callback at the end of backward to run this module's post-backward.

This multi-post-accumulate-grad hook allows running the module's post-backward earlier (but still correctly).
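A self-contained toy demonstrating that idea (the model and names are hypothetical; this is not the FSDP2 code path): even when the input does not require grad, a post-accumulate-grad hook on the parameters still fires during backward, with all `.grad` fields assigned, instead of waiting for the end of backward:

```python
import torch
import torch.nn as nn

lin = nn.Linear(4, 4)
fired = []

def post_backward():
    # Stand-in for the module's post-backward (e.g. reduce-scatter in FSDP2).
    fired.append(all(p.grad is not None for p in lin.parameters()))

remaining = [sum(1 for p in lin.parameters() if p.requires_grad)]

def per_param_hook(param):
    remaining[0] -= 1
    if remaining[0] == 0:  # last parameter's .grad just got accumulated
        post_backward()

handles = [p.register_post_accumulate_grad_hook(per_param_hook)
           for p in lin.parameters() if p.requires_grad]

x = torch.randn(2, 4)      # input does NOT require grad
lin(x).sum().backward()    # post_backward runs during backward, not after it
print(fired)               # [True]: every .grad was already assigned
for h in handles:
    h.remove()
```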

A contributor commented:

This is neat!

awgu added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: 0efb902e087dc525b4474da374a6f3bf6cc9c23e
Pull Request resolved: #129259
awgu added a commit that referenced this pull request Jun 25, 2024
ghstack-source-id: c91c7b9863dbbde9839ec526007f5d73495d0f67
Pull Request resolved: #129259
@awgu (Contributor Author) commented Jun 25, 2024

@awgu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@sanketpurandare sanketpurandare self-requested a review June 26, 2024 00:21
awgu added a commit that referenced this pull request Jul 26, 2024
ghstack-source-id: 1d6c8c588f3d5d15b6fe9d018ec976ca249d800f
Pull Request resolved: #129259
awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 1f70eb05e82c2fc4afaff50f68ab9248b813e3ac
Pull Request resolved: #129259
awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 1ca85b983638ea6fe3d8b0c2a15926e5f44b11be
Pull Request resolved: #129259
awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: c19e7ec40459791d907d234fa6aa6fdb3a9182de
Pull Request resolved: #129259
awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 016f50f0b8aeac260b5ff808bb223d890e36409e
Pull Request resolved: #129259
@awgu awgu marked this pull request as ready for review July 29, 2024 16:45

awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 44001c46cc57072f946599e17c4d6af1a642fce8
Pull Request resolved: #129259
@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 29, 2024

awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 376c8b9628baef2dde1c194d7f7183514f6ea35f
Pull Request resolved: #129259

awgu added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 5fb80831bbfb9624fbe184af4b741bcedd102180
Pull Request resolved: #129259
@sanketpurandare (Contributor) left a comment

This is very clean! Kudos!

Labels
ciflow/trunk Trigger trunk jobs on your pull request oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp2) release notes category