Always use high precision for SDPA math backend #128922

Open
wants to merge 5 commits into main from export-D58710805

Conversation

@jianyuh jianyuh (Member) commented Jun 18, 2024

Summary:
feikou observed large numerical gaps when using the math backend on AMD and NVIDIA GPUs. The main cause is that we are not using higher-precision FP32 for the intermediate accumulated/materialized values.

Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, performing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.

Differential Revision: D58710805
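
For illustration, here is a minimal Python sketch of the upcast-compute-downcast pattern described in the summary. It omits masking, dropout, and other details of the real math backend, and the helper name is hypothetical:

```python
import torch

def sdpa_math_reference(q, k, v, scale=None):
    # Upcast low-precision inputs so the matmuls and softmax accumulate in FP32.
    orig_dtype = q.dtype
    upcast = orig_dtype in (torch.float16, torch.bfloat16)
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    scale = q.shape[-1] ** -0.5 if scale is None else scale
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    out = attn @ v
    # Only the final output is cast back; intermediates stay in FP32.
    return out.to(orig_dtype) if upcast else out
```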

pytorch-bot bot commented Jun 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128922

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit decfde5 with merge base 54d4f6b:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D58710805

pytorch-bot bot pushed a commit that referenced this pull request Jul 28, 2024
Summary:
Pull Request resolved: #128922

feikou observed large numerical gaps when using the math backend on AMD and NVIDIA GPUs. The main cause is that we are not using higher-precision FP32 for the intermediate accumulated/materialized values.

Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, performing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16.

Reviewed By: feikou, xw285cornell

Differential Revision: D58710805
@@ -3030,7 +3030,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
         *zip(grads_ref, grads_ref_lp, grads),
         fudge_factors={
             'out': 1.5,
-            'grad_query': 26.0,
+            'grad_query': 68.0,
Contributor commented on the diff:

It's a little strange to me that, if the new math implementation is supposed to more closely match the flash attention kernel, it ends up so much more precise for grad_q relative to FP64 math.

I actually updated this test recently in preparation for your PR, expecting that we would not need a tolerance bump for the flash impl, especially not by this much.

@jianyuh jianyuh (Member Author) commented Jul 29, 2024

See the comment in https://www.internalfb.com/diff/D58710805?dst_version_fbid=437617645906478&transaction_fbid=1233726084662163 . It's a bit frustrating that a simple SDPA change needed almost 30 rebases :)

grads_ref_lp returned by the low-precision math reference is significantly improved by this PR. Before this PR the reference error was, say, grad_q_ref_rtol = 3142; after this PR it is 0.08.

With this comparison logic, the error bound before this PR was 26.0 * 3142 (fudge factor * grad_q_ref_rtol). After this PR the bound is 26.0 * 0.08. Even if I increase the fudge factor from 26.0 to 100.0 (100.0 * 0.08 vs. 26.0 * 3142), the absolute error bound is still far smaller than the original bound before this PR.
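
For concreteness, a small sketch of that bound arithmetic, using the numbers quoted above and the 68.0 fudge factor from the diff (the variable names are illustrative, not the actual test code):

```python
# Max abs error of the low-precision math reference vs. FP64 math (grad_query):
ref_err_before = 3142.0  # before this PR: intermediates rounded to FP16/BF16
ref_err_after = 0.08     # after this PR: intermediates kept in FP32

# The fused impl is allowed fudge_factor * ref_err of absolute deviation.
bound_before = 26.0 * ref_err_before  # ~81,692
bound_after = 68.0 * ref_err_after    # ~5.4, far tighter despite the larger fudge factor
assert bound_after < bound_before
```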

@drisspg drisspg (Contributor) commented Jul 29, 2024

Yeah, it's a very popular piece of code ;)

I agree with the original premise that the 'rtol' was too lenient; that's why I switched the testing to measure only the max abs difference. The updated tests now:

1. Use FP64 computed with the math impl as the gold standard.
2. Get a threshold by calculating the max abs difference between the low-precision math impl and the FP64 math impl.
3. Assert that the max abs difference between the fused impls' outputs/grads and the gold impl is less than the tolerance from step 2, scaled by a per-tensor fudge factor (see the sketch after this list).
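
[A minimal sketch of this comparison scheme, for illustration only; the helper names are hypothetical and this is not the actual test code:]

```python
import torch

def math_ref(q, k, v):
    # Plain (non-fused) SDPA computed in the inputs' dtype.
    attn = torch.softmax((q @ k.transpose(-2, -1)) * q.shape[-1] ** -0.5, dim=-1)
    return attn @ v

def check_fused_vs_math(q, k, v, fused_fn, fudge_factor=1.0):
    # 1. FP64 math result is the gold standard.
    gold = math_ref(q.double(), k.double(), v.double())
    # 2. Threshold: max abs difference between low-precision math and FP64 math.
    atol = (math_ref(q, k, v).double() - gold).abs().max().item()
    # 3. The fused impl's max abs error vs. gold must stay within fudge * threshold.
    err = (fused_fn(q, k, v).double() - gold).abs().max().item()
    assert err <= fudge_factor * atol, f"{err} > {fudge_factor * atol}"
```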

In the current impl (before this PR), there are intermediate casts to the low-precision dtype in the low-precision math impl for FP16 and BF16 inputs. If the fused impls were uniformly more precise than the current math impl (since they don't do these roundings), I would expect the fudge factors to all be 1.0. That, however, was not the case. So when you bump the tolerances now, you are only affecting the absolute error, not the relative error.

There is a great note connecting the errors introduced in FAv2 to the iterative softmax algorithm: https://fb.workplace.com/notes/200492239686121

TBH, these unit tests are not designed to 'prove' the accuracy of the fused kernels; that is proven out in training and inference in end-to-end use cases, since so far we haven't taken a ULP-like framing of these ops (maybe we should).

The main goal is to catch any regressions that might get added and to cover a large swath of the fused kernels' configuration space. If these tolerance changes still ensure that, then I am fine with the change. If you are saying that we made the math impl's grad query 30x closer to the FP64 version, and that we therefore need to bump flash's grad query tolerance to 30x the new threshold, then I think this change makes sense.

@jianyuh jianyuh force-pushed the export-D58710805 branch 7 times, most recently from 7757213 to 36f7577 on July 30, 2024 06:40
@jianyuh jianyuh force-pushed the export-D58710805 branch 6 times, most recently from 743203d to da8e381 on July 31, 2024 23:20
@facebook-github-bot (Contributor)

@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jianyuh jianyuh (Member Author) commented Jul 31, 2024

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge failed

Reason: This PR has internal changes and must be landed via Phabricator! Please try re-importing/re-exporting the PR!

(Raised by workflow job.)

Labels: ciflow/trunk, fb-exported

7 participants