
Grouped Query Attention #128898

Closed
wants to merge 39 commits into from

Conversation

jainapurva
Contributor

@jainapurva jainapurva commented Jun 17, 2024

Approach: Using the current function declaration

Constraint: Q_Heads % KV_Heads == 0

Major change:

  • Added a new argument enable_gqa: bool to the sdpa function call
  • It gives meaning to the third-to-last (head) dimension
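The constraint above can be sketched as a small validity check (a plain-Python illustration; `gqa_shapes_ok` is a hypothetical helper, not part of the PR):

```python
# Hypothetical sketch of the shape constraint enable_gqa enforces.
# Head counts live in the third-to-last dimension of (batch, heads, seq, dim).
def gqa_shapes_ok(q_heads: int, kv_heads: int) -> bool:
    # Valid when heads are equal (regular attention) or when every KV head
    # serves exactly q_heads // kv_heads query heads (grouped attention).
    return kv_heads > 0 and q_heads % kv_heads == 0

print(gqa_shapes_ok(32, 8))  # True: Llama3-8B style grouping
print(gqa_shapes_ok(32, 5))  # False: 32 is not divisible by 5
```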

Sample use cases this would enable:
LLama3

# LLama3-8B style call to SDPA (batch, seq_len_q, seq_len_kv, and D are
# placeholder shape variables)
import torch
from torch.nn.functional import scaled_dot_product_attention

query = torch.rand(batch, 32, seq_len_q, D)
key = torch.rand(batch, 8, seq_len_kv, D)
value = torch.rand(batch, 8, seq_len_kv, D)

output = scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)

# Output shape: (batch, 32, seq_len_q, D)

Design Choice:

  • Check that Query.size(-3) == Key.size(-3) == Value.size(-3), or that Query.size(-3) % Key.size(-3) == 0
  • When the numbers of heads differ, the function expands the key and value tensors to match the query tensor's head count using repeat_interleave, so the attention computation remains correct and efficient
  • By default the enable_gqa flag is set to False, which ensures that regular sdpa functionality remains unchanged
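The repeat_interleave expansion described above can be sketched in plain numpy (a minimal reference implementation for intuition, not the actual fused kernel; `sdpa_gqa` is a hypothetical name):

```python
import numpy as np

def sdpa_gqa(q, k, v, enable_gqa=False):
    """Naive scaled dot-product attention with optional GQA head expansion.

    Shapes: q is (B, Hq, Sq, D); k and v are (B, Hkv, Skv, D).
    """
    if enable_gqa and q.shape[1] != k.shape[1]:
        assert q.shape[1] % k.shape[1] == 0, "Q_Heads % KV_Heads must be 0"
        group = q.shape[1] // k.shape[1]
        # repeat_interleave along the head dimension: each KV head serves
        # `group` consecutive query heads
        k = np.repeat(k, group, axis=1)
        v = np.repeat(v, group, axis=1)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 32, 4, 16))
k = rng.standard_normal((2, 8, 4, 16))
v = rng.standard_normal((2, 8, 4, 16))
out = sdpa_gqa(q, k, v, enable_gqa=True)
print(out.shape)  # (2, 32, 4, 16): output takes the query head count
```

The result is numerically identical to calling regular attention on K/V that were materialized to 32 heads up front; GQA just avoids storing the expanded tensors.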

Benchmarks:

  • sdpa.py: Gqa benchmark #130634
    For different batch sizes, enable_gqa=True shows a substantial improvement in the run time of sdpa:
| batch_size | q_num_heads | kv_num_heads | q_seq_len | kv_seq_len | embed_dim | forward_time (enable_gqa=True) | forward_time (enable_gqa=False) |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 1 | 32 | 8 | 2048 | 2048 | 2048 | 100.71 | 119.70 |
| 8 | 32 | 8 | 2048 | 2048 | 2048 | 539.78 | 628.83 |
| 16 | 32 | 8 | 2048 | 2048 | 2048 | 1056.81 | 1225.48 |
| 32 | 32 | 8 | 2048 | 2048 | 2048 | 2099.54 | 2440.45 |
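A back-of-envelope sketch of why the gap appears (my interpretation, not from the PR: with enable_gqa=False the 8 KV heads must be materialized to 32 heads before the call, so the kernel reads 4x more K/V data; byte counts below assume fp16 and the benchmark's shapes):

```python
# Bytes occupied by the K and V tensors of shape (batch, kv_heads, seq, head_dim)
def kv_bytes(batch, kv_heads, seq_len, head_dim, bytes_per_el=2):
    return 2 * batch * kv_heads * seq_len * head_dim * bytes_per_el  # K and V

head_dim = 2048 // 32  # embed_dim split across the 32 query heads

# Baseline: K/V expanded to match the 32 query heads before calling sdpa
expanded = kv_bytes(batch=8, kv_heads=32, seq_len=2048, head_dim=head_dim)
# GQA: the original 8-head K/V tensors are used directly
grouped = kv_bytes(batch=8, kv_heads=8, seq_len=2048, head_dim=head_dim)

print(expanded // grouped)  # 4: four times less K/V data to allocate and stream
```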

(Screenshot of benchmark results, 2024-07-25)

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames


pytorch-bot bot commented Jun 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128898

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 95e724e with merge base e6cddc9:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@drisspg drisspg self-requested a review June 18, 2024 23:27
@jainapurva jainapurva marked this pull request as ready for review June 20, 2024 16:20
@jainapurva
Contributor Author

@pytorch-bot rebase

@jainapurva jainapurva marked this pull request as draft June 22, 2024 18:53
@jainapurva jainapurva marked this pull request as ready for review July 1, 2024 00:53
@jainapurva jainapurva requested a review from mruberry as a code owner July 1, 2024 18:04
@jainapurva
Contributor Author

@pytorchbot rebase

@jainapurva jainapurva marked this pull request as draft July 3, 2024 21:35
@jainapurva jainapurva marked this pull request as ready for review July 8, 2024 17:42
@albanD albanD removed their request for review July 8, 2024 23:02
Contributor

@drisspg drisspg left a comment


Overall looks really good, left a few comments

I think it would be helpful to run this script and get some micro-benchmark data: https://github.com/pytorch/pytorch/blob/main/benchmarks/transformer/sdpa.py

@pytorch-bot bot added the oncall: distributed label Jul 31, 2024
@jainapurva
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

This PR updates submodules third_party/fmt

If those updates are intentional, please add "submodule" keyword to PR title/description.

@pytorch pytorch deleted a comment from pytorchmergebot Jul 31, 2024
@pytorch pytorch deleted a comment from pytorchmergebot Jul 31, 2024
@jainapurva
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

Labels: ci-no-td, ciflow/inductor, ciflow/trunk, Merged, module: dynamo, oncall: distributed, release notes: nn, Reverted