sdp::SDPBackend::flash_attention support PrivateUse1 #126392

Closed

Conversation

@1274085042 (Contributor) commented May 16, 2024

pytorch-bot bot commented May 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126392

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9b7f54b with merge base a0dac3d:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@albanD albanD requested review from drisspg and removed request for albanD May 16, 2024 13:25
@1274085042 1274085042 requested a review from drisspg May 21, 2024 02:55
@mikaylagawarecki mikaylagawarecki added the triaged label (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) May 21, 2024
@drisspg drisspg requested a review from jainapurva May 21, 2024 23:59
@drisspg (Contributor) commented May 22, 2024

The current structure of this op looks like:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
|    |
|    |-- if PrivateUse1:
|    |      |-- handle_private_use(...)
|    |-- else:
|          |-- _fused_sdp_choice_stub(...)
|
|-- switch (backend)
     |
     |-- case cudnn_attention:
     |      |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
     |
     |-- case flash_attention:
     |      |-- if CUDA:
     |      |      |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
     |      |-- else (CPU):
     |            |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
     |
     |-- case efficient_attention:
     |      |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
     |
     |-- case math:
     |      |-- return at::_scaled_dot_product_attention_math(...)
     |
     |-- default:
            |-- TORCH_CHECK(false, "No viable backend found.")
            |-- return Tensor()

I spoke with Alban offline about this, and we came to the conclusion that we want this structure:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
|    |
|    |-- if stub_registered(device):
|    |      |-- _fused_sdp_choice_stub(...)
|    |-- else:
|          |-- use math as the choice
|
|-- switch (backend)
     |
     |-- case cudnn_attention:
     |      |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
     |
     |-- case flash_attention:
     |      |-- if CUDA:
     |      |      |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
     |      |-- else (CPU):
     |            |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
     |
     |-- case efficient_attention:
     |      |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
     |
     |-- case overridable:
     |      |-- return at::_scaled_dot_product_attention_overridable(...)
     |
     |-- case math:
     |      |-- return at::_scaled_dot_product_attention_math(...)
     |
     |-- default:
            |-- TORCH_CHECK(false, "No viable backend found.")
            |-- return Tensor()

So what does that mean for this PR? The structure looks pretty good. I made some changes in #126832 that should enable this, so once that lands we can land your updates.

The dispatching logic for the kernels will be:

  • The default choice is math (if a device doesn't register a stub, it gets routed to the math backend).
  • If a choice stub is registered, devices have the option to go to the overridable op that this PR provides. That op should do no preprocessing, but it will be run through validate_sdpa and have attn_mask converted from bool to float (see the sketch below).
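
For concreteness, here is a minimal C++ sketch of the switch half of that flow. It is only an illustration of the structure described above, not landed code: the `overridable` enum value and the signature of `at::_scaled_dot_product_attention_overridable` come from this discussion, and the variable names and abbreviated argument lists are assumptions.

```cpp
// Sketch only: `backend` is whatever a registered _fused_sdp_choice_stub
// returned, or sdp::SDPBackend::math when no stub is registered.
switch (backend) {
  case sdp::SDPBackend::cudnn_attention:
  case sdp::SDPBackend::flash_attention:
  case sdp::SDPBackend::efficient_attention:
    /* unchanged fused paths, as in the diagram above */
    break;
  case sdp::SDPBackend::overridable:
    // New case: stubs that return `overridable` (e.g. PrivateUse1 backends)
    // land here after validate_sdpa and the bool->float attn_mask conversion,
    // with no other preprocessing.
    return std::get<0>(at::_scaled_dot_product_attention_overridable(
        query_, key, value, attn_mask_, dropout_p, is_causal, scale));
  case sdp::SDPBackend::math:
    // Default route for devices without a registered choice stub.
    return std::get<0>(at::_scaled_dot_product_attention_math(
        query_, key, value, attn_mask_, dropout_p, is_causal));
  default:
    TORCH_CHECK(false, "No viable backend for scaled_dot_product_attention was found.");
    return at::Tensor();
}
```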

pytorchmergebot pushed a commit that referenced this pull request May 25, 2024
# Summary

Adds a public method to DispatchStub to check whether a function has been registered for a device. We use this new function to clean up the dispatching logic for SDPA, as well as to make the PrivateUse1 dispatching simpler:
#126392
Pull Request resolved: #126832
Approved by: https://github.com/ezyang, https://github.com/albanD
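
As a rough sketch of the selection half this enables: the summary above does not name the new method, so `is_device_supported` is an assumption here, as are the stub's argument list and the surrounding variable names.

```cpp
// Default to math; only consult the choice stub when the device actually
// registered one, so PrivateUse1 no longer needs a special case at the call site.
sdp::SDPBackend backend = sdp::SDPBackend::math;
if (_fused_sdp_choice_stub.is_device_supported(query_.device().type())) {
  backend = static_cast<sdp::SDPBackend>(_fused_sdp_choice_stub(
      query_.device().type(),
      query_, key, value, attn_mask_, dropout_p, is_causal, scale));
}
```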
@1274085042 (Contributor, Author) commented:

@drisspg could this update be landed?

@drisspg (Contributor) commented May 29, 2024

The PR I referenced above has landed; can you rebase?

@1274085042 (Contributor, Author) commented:

@pytorchmergebot rebase

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented:

Successfully rebased flash_attention_overrideable onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout flash_attention_overrideable && git pull --rebase)

@1274085042 (Contributor, Author) commented:

@drisspg Rebased and fixed some CI issues

@@ -680,10 +684,15 @@ Tensor scaled_dot_product_attention(
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
    } else if (query_.device().type() == DeviceType::PrivateUse1) {
@drisspg (Contributor) commented on this diff:
This doesn't look right to me.

It should now just be one more switch-case entry.

You will need to add the overridable backend:

```cpp
case sdp::SDPBackend::overridable:
  return std::get<0>(at::_scaled_dot_product_attention_overridable(
      ...));
```
PrivateUse1 authors would thus register a dispatch to the stub and have it return the overridable backend; by default, they would be routed to the math backend.
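
A hypothetical sketch of the backend author's side of this: all names here (`my_backend`, `my_sdp_choice`, `my_sdpa_overridable`) are invented for illustration, the choice-stub registration is only described in a comment, and `sdp::SDPBackend::overridable` plus the schema and return type of `_scaled_dot_product_attention_overridable` are assumptions pending this PR, not the op's final definition.

```cpp
#include <ATen/ATen.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>  // assumed include for sdp::SDPBackend
#include <torch/library.h>

namespace my_backend {

// Choice function for this device: report that SDPA should take the
// overridable path instead of a fused CUDA/CPU kernel.
int64_t my_sdp_choice(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const c10::optional<at::Tensor>& attn_mask, double dropout_p,
    bool is_causal, c10::optional<double> scale) {
  return static_cast<int64_t>(sdp::SDPBackend::overridable);
}

// The device-specific attention kernel behind the overridable op.
std::tuple<at::Tensor, at::Tensor> my_sdpa_overridable(
    const at::Tensor& query, const at::Tensor& key, const at::Tensor& value,
    const c10::optional<at::Tensor>& attn_mask, double dropout_p,
    bool is_causal, c10::optional<double> scale) {
  // Placeholder: a real backend launches its fused attention kernel here and
  // returns whatever the final schema requires (output, logsumexp, ...).
  return std::make_tuple(at::empty_like(query), at::empty_like(query));
}

} // namespace my_backend

// my_sdp_choice would additionally be registered to _fused_sdp_choice_stub for
// the PrivateUse1 device type through DispatchStub's registration mechanism
// (omitted here). Without that registration, the device is routed to math.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("_scaled_dot_product_attention_overridable",
         &my_backend::my_sdpa_overridable);
}
```

With the choice stub registered, scaled_dot_product_attention on PrivateUse1 tensors would route through the overridable case; without it, the math backend is used.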

@1274085042 (Contributor, Author) commented:

@drisspg Could you please help review the PR again? Thanks!

@1274085042 (Contributor, Author) commented:

If you have any further questions, feel free to bring them up. @drisspg

@drisspg (Contributor) left a comment.

linux-foundation-easycla bot commented Jun 3, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@1274085042 1274085042 force-pushed the flash_attention_overrideable branch from d28f942 to 93e6eb4 on June 3, 2024 13:15
@1274085042 1274085042 requested a review from drisspg June 3, 2024 13:16
@drisspg (Contributor) left a comment:

One more small comment, but otherwise this is looking really good.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@1274085042 (Contributor, Author) commented:

@pytorchbot rebase -b main

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot (Collaborator) commented:

Tried to rebase and push PR #126392, but it was already up to date.

@1274085042 (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented:

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job.

Failing merge rule: Core Maintainers

@1274085042 (Contributor, Author) commented:

@pytorchbot rebase -b main

@pytorchmergebot (Collaborator) commented:

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here

@pytorchmergebot (Collaborator) commented:

Successfully rebased flash_attention_overrideable onto refs/remotes/origin/main, please pull locally before adding more changes (for example, via git checkout flash_attention_overrideable && git pull --rebase)

@1274085042 (Contributor, Author) commented:

@pytorchbot merge

@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Labels: ciflow/inductor, ciflow/trunk, Merged, open source, topic: not user facing, triaged
Projects: none yet
Development

Successfully merging this pull request may close these issues.

Add logic about PrivateUse1 in sdp::SDPBackend::flash_attention
7 participants