sdp::SDPBackend::flash_attention support PrivateUse1 #126392
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126392

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures as of commit 9b7f54b with merge base a0dac3d. This comment was automatically generated by Dr. CI and updates every 15 minutes.
The current structure of this op looks like:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
| |
| |-- if PrivateUse1:
| | |-- handle_private_use(...)
| |-- else:
| | |-- _fused_sdp_choice_stub(...)
|
|-- switch (backend)
|
|-- case cudnn_attention:
| |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
|
|-- case flash_attention:
| |-- if CUDA:
| | |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
| |-- else (CPU):
| | |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
|
|-- case efficient_attention:
| |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
|
|-- case math:
| |-- return at::_scaled_dot_product_attention_math(...)
|
|-- default:
| |-- TORCH_CHECK(false, "No viable backend found.")
| |-- return Tensor()

I spoke with Alban offline about this, and we came to the conclusion that we want this structure:

|-- Determine backend (CUDA, CPU, HIP, PrivateUse1)
| |-- if stub_registered():
| | |-- _fused_sdp_choice_stub(...)
| |-- else:
| | |-- use math as the choice
|
|-- switch (backend)
|
|-- case cudnn_attention:
| |-- out_lse_softmax = at::_scaled_dot_product_cudnn_attention(...)
|
|-- case flash_attention:
| |-- if CUDA:
| | |-- out_lse_softmax = at::_scaled_dot_product_flash_attention(...)
| |-- else (CPU):
| |-- return at::_scaled_dot_product_flash_attention_for_cpu(...)
|
|-- case efficient_attention:
| |-- out_and_lse = at::_scaled_dot_product_efficient_attention(...)
|
|-- case overridable:
| |-- return at::_scaled_dot_product_attention_overridable(...)
|
|-- case math:
| |-- return at::_scaled_dot_product_attention_math(...)
|
|-- default:
|-- TORCH_CHECK(false, "No viable backend found.")
| |-- return Tensor()

So what does that mean for this PR? The structure looks pretty good. I made some changes in #126832 that should enable this; once that lands, we can land your updates. The dispatching logic for the kernels will be:
# Summary

Adds a public method to DispatchStub to check if a fn has been registered for a device. We use this new function to clean up the dispatching logic for SDPA, as well as to make the PrivateUse1 dispatching simpler: #126392

Pull Request resolved: #126832
Approved by: https://github.com/ezyang, https://github.com/albanD
@drisspg: The PR I referenced above has landed. Can you rebase?

@pytorchmergebot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Successfully rebased c276523 to 761a79d.

@drisspg Rebased and fixed some CI issues.
@@ -680,10 +684,15 @@ Tensor scaled_dot_product_attention(
        auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
            query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
        return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
      }
    } else if (query_.device().type() == DeviceType::PrivateUse1) {
This doesn't look right to me. It should now just be one more switch case entry; you will need to add the overridable backend:

case sdp::SDPBackend::overridable:
  return std::get<0>(at::_scaled_dot_product_attention_overridable(...));

PrivateUse1 authors would thus register a dispatch to the stub and have it return the overridable backend; by default they would be routed to the math backend.
@drisspg Could you please help review the PR again? Thanks!
If you have any further questions, feel free to bring them up. @drisspg
Force-pushed d28f942 to 93e6eb4.
One more small comment but otherwise this is looking really good
Merge failed. Reason: 3 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot rebase -b main

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here.

Tried to rebase and push PR #126392, but it was already up to date.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 3 mandatory check(s) failed. Dig deeper by viewing the failures on hud.
@pytorchbot rebase -b main

@pytorchbot started a rebase job onto refs/remotes/origin/main. Check the current status here.

Successfully rebased e28410c to 9b7f54b.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #124271
cc @cpuhrsch @drisspg @albanD @soulitzer