Always use high precision for SDPA math backend #128922
base: main
Conversation
🔗 Helpful links: see artifacts and rendered test results at hud.pytorch.org/pr/128922. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit decfde5 with merge base 54d4f6b. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
This pull request was exported from Phabricator. Differential Revision: D58710805
(The PR was subsequently force-pushed several times: ae61dc2 → f3d5029 → 2c1c544 → fbd5586 → 86096ad → 9330113 → 9cc4514 → aecd432 → 412ac0d → 80caa02, then 0475bfb → 3056faf. The same Phabricator export comment was repeated after each push.)
Summary: Pull Request resolved: #128922. feikou observed large numerical gaps when using the math backend on AMD and NVIDIA GPUs, mainly because we are not using higher-precision FP32 for the intermediate accumulated/materialized values. Since the math backend is expected to be slower anyway, and we expect it to produce the correct reference result, it is worth upcasting FP16/BF16 inputs to FP32, doing the computation in FP32/TF32, and then downcasting the FP32 output back to FP16/BF16. Reviewed By: feikou, xw285cornell. Differential Revision: D58710805
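The upcast/compute/downcast flow described in the summary can be sketched as follows. This is a minimal eager-mode reference, not the actual ATen math-backend code; the function name and shapes are illustrative:

```python
import torch

def sdpa_math_high_precision(q, k, v):
    """Math-backend-style SDPA that accumulates in FP32 (sketch only)."""
    orig_dtype = q.dtype
    if orig_dtype in (torch.float16, torch.bfloat16):
        # Upcast inputs so all intermediates (scores, softmax, matmuls) are FP32.
        q, k, v = q.float(), k.float(), v.float()
    scale = q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    out = attn @ v
    # Downcast only the final output back to the original input dtype.
    return out.to(orig_dtype)

q = k = v = torch.randn(2, 4, 8, 16, dtype=torch.bfloat16)
out = sdpa_math_high_precision(q, k, v)
```

The key design point is that rounding to the low-precision dtype happens once, at the end, rather than after every intermediate matmul and softmax.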
test/test_transformers.py (outdated)

```diff
@@ -3030,7 +3030,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     *zip(grads_ref, grads_ref_lp, grads),
     fudge_factors={
         'out': 1.5,
-        'grad_query': 26.0,
+        'grad_query': 68.0,
```
It's a little strange to me that, if the new math implementation is supposed to match the flash attention kernel more closely, grad_q needs such a large tolerance bump relative to fp64 math.
I actually recently updated this test in preparation for your PR, expecting that we would not need to bump the tolerance for the flash impl, especially not by this much.
Check the comment in https://www.internalfb.com/diff/D58710805?dst_version_fbid=437617645906478&transaction_fbid=1233726084662163. It's a bit frustrating to make a simple change to SDPA with almost 30 rebases :)
The grads_ref_lp returned from the low-precision math reference has improved significantly with this PR. Before this PR, grad_q_ref_rtol was roughly 3142; after this PR it is 0.08.
With this comparison logic, the error bound before this PR was 26.0 * 3142 (grad_q_ref_rtol). After this PR, the bound is 26.0 * 0.08. Even if I increase the ratio from 26.0 to 100.0 (100.0 * 0.08 vs. 26.0 * 3142), the absolute error bound is still far smaller than the original bound before this PR.
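As a quick sanity check on the arithmetic in that comparison (plain Python, with the numbers taken directly from the comment above):

```python
# Absolute error bound = fudge factor * reference tolerance (grad_q_ref_rtol).
bound_before = 26.0 * 3142    # bound before this PR
bound_after = 100.0 * 0.08    # bound after this PR, even with the fudge factor raised to 100.0

# The new absolute bound is orders of magnitude tighter than the old one.
assert bound_after < bound_before
```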
Yeah, it's a very popular piece of code ;)
I agree with the original premise that the rtol was too lenient; that's why I switched the testing to instead only measure against max abs difference. The updated test now:
1.) Uses fp64 as the gold standard, computed with the math impl.
2.) Gets a threshold by calculating the max abs difference between low-precision math and fp64 math.
3.) Asserts that the max abs difference between each fused impl's outs and grads and the gold impl is less than the tolerance from 2.
In the current impl (before this PR), there are intermediate casts to the low-precision dtype in the low-precision math impl on fp16 and bf16 inputs. If the fused impls were uniformly more precise than the current math impl (since they don't do these roundings), I would expect the fudge factors to all be 1.0. That, however, was not the case. So when you bump the tolerances now, you are only affecting the absolute error, not the relative.
There is a great note connecting the errors introduced in FAv2 to the iterative softmax algorithm: https://fb.workplace.com/notes/200492239686121
TBH, these unit tests are not designed to 'prove' the accuracy of the fused kernels; that is proved out in training and in inference in E2E use cases, since so far we haven't been taking a ULP-like framing of these ops (maybe we should).
The main goal is to catch any regressions that might get added and to cover a large swath of the fused kernels' configuration space. If these tolerance changes still ensure that, then I am fine with the change. If you are saying that we made the math impl's grad query 30x closer to the fp64 version, and thus we need to bump flash's grad query to be 30x the new tolerance, then I think this change makes sense.
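The three-step tolerance scheme described above can be sketched roughly as follows. This is illustrative only: `sdpa_math` is a hand-rolled stand-in for the math backend, not the actual test_transformers.py helpers, and the fudge factor is made up. The low-precision math output also stands in for a fused kernel here, since no GPU is assumed:

```python
import torch

def sdpa_math(q, k, v):
    # Plain math-backend SDPA: scores, softmax, weighted sum.
    scale = q.size(-1) ** -0.5
    attn = torch.softmax((q @ k.transpose(-2, -1)) * scale, dim=-1)
    return attn @ v

def max_abs_diff(a, b):
    return (a.double() - b.double()).abs().max().item()

q = torch.randn(2, 4, 32, 16, dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# 1) Gold standard: the math impl computed in fp64.
out_gold = sdpa_math(q.double(), k.double(), v.double())

# 2) Threshold: max abs difference between low-precision math and fp64 math.
out_lp = sdpa_math(q, k, v)
tol = max_abs_diff(out_lp, out_gold)

# 3) Assert the candidate impl's max abs difference vs. gold stays within
#    fudge_factor * tol (out_lp plays the role of the fused kernel output).
fudge = 1.5
assert max_abs_diff(out_lp, out_gold) <= fudge * tol + 1e-9
```

Under this scheme, tightening the low-precision reference (as this PR does) shrinks `tol`, which is why the fudge factors on the fused kernels then need to grow to keep the same absolute error budget.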
(Further force-pushes: 7757213 → 36f7577, then 743203d → da8e381.)
@jianyuh has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR! (Details for Dev Infra team: raised by workflow job.)
(Final force-pushes: da8e381 → db613ee → decfde5.)