Fix add decomposition for complex numbers #129044
Conversation
Dr. CI: As of commit 6a94092 with merge base 920ebcc: UNSTABLE — you can merge normally. One job failed, but the failure was likely due to flakiness present on trunk and has been marked as unstable (1 unrelated failure).
torch/_inductor/decomposition.py
Outdated
@@ -340,7 +340,32 @@ def add(x, y, *, alpha=None):
    if alpha is not None:
        z = alpha * y
    complex_type = torch.promote_types(x.dtype, y.dtype)
    return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
    if x_is_complex_tensor and y_is_complex_tensor:
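The view-as-real trick in the diff above can be exercised standalone. A minimal sketch (operand values and `alpha` are illustrative, and it assumes both operands are complex with matching shapes, so no broadcasting is involved):

```python
import torch

# A complex64 tensor viewed as float32 interleaves real/imag parts along the
# last dimension, so complex addition reduces to plain real addition.
x = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)
y = torch.tensor([5 + 6j, 7 + 8j], dtype=torch.complex64)
alpha = 2.0

z = alpha * y  # apply alpha while still complex
out = (x.view(x.real.dtype) + z.view(y.real.dtype)).view(
    torch.promote_types(x.dtype, y.dtype)
)

# Matches eager-mode torch.add with alpha.
expected = torch.add(x, y, alpha=alpha)
assert torch.equal(out, expected)
```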
I see we bail out in the case when any of the tensors is real, so we can probably delete the branches here. Otherwise, I think that just adding an unsqueeze(-1) to the real tensor would make everything work, in case you want to implement that case.
+1, delete branch or support the case (and add a test for it)
Thanks! I added support for adding a complex tensor and a real tensor.
torch/_inductor/decomposition.py
Outdated
tensor_unsqueezed = tensor.unsqueeze(-1)
zeros = torch.zeros_like(tensor_unsqueezed)
return torch.cat((tensor_unsqueezed, zeros), dim=-1)
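The snippet above packs a real tensor into interleaved (real, imag=0) pairs so it can be treated as complex. A self-contained sketch of that behavior (the function name `convert_to_complex` is hypothetical, chosen here only to wrap the three lines above):

```python
import torch

def convert_to_complex(tensor: torch.Tensor) -> torch.Tensor:
    # (..., n) -> (..., n, 1): make room for the imaginary component
    tensor_unsqueezed = tensor.unsqueeze(-1)
    # zeros for the imaginary parts
    zeros = torch.zeros_like(tensor_unsqueezed)
    # (..., n, 2): interleaved [real, imag] pairs in the last dimension
    return torch.cat((tensor_unsqueezed, zeros), dim=-1)

r = torch.tensor([1.0, 2.0])
packed = convert_to_complex(r)  # shape (2, 2)
as_complex = torch.view_as_complex(packed)
assert torch.equal(
    as_complex, torch.tensor([1 + 0j, 2 + 0j], dtype=torch.complex64)
)
```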
Ah right. We have to do this. This is not going to fuse well, and we might even regress perf.
I can think of three ways of doing this:
- Doing mask = (torch.arange(last_dim) % 2 == 0).view(last_dim // 2, 2), broadcasting the real tensor, and choosing the sum or the original tensor depending on this mask.
- A better way to do this at the moment would be via a lowering with masked loads, to avoid unnecessary compute.
- Long term, we would represent complex numbers as two tensors (real and imaginary part), and all of this would be trivial to do.

All of these are a bit involved. Perhaps it's better to simply fall back to eager in this case and wait until someone tackles the general complex-support problem.
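The first option above (mask out the imaginary slots instead of materializing a zeros tensor and concatenating) can be sketched as follows. This is only an illustration of the idea under assumed shapes — a complex vector plus a real scalar — not the actual Inductor lowering:

```python
import torch

c = torch.tensor([1 + 2j, 3 + 4j], dtype=torch.complex64)  # complex operand
r = torch.tensor(10.0)                                     # real operand

# View the complex tensor as interleaved floats: [re, im, re, im].
c_flat = c.view(torch.float32)
last_dim = c_flat.shape[-1]

# True at the "real" slots of the interleaved layout.
mask = torch.arange(last_dim) % 2 == 0

# Add the real operand only where the mask is True; imaginary slots pass through.
out = torch.where(mask, c_flat + r, c_flat).view(torch.complex64)
assert torch.equal(out, c + r)
```

Compared with the unsqueeze/zeros/cat approach, this avoids allocating and concatenating a zeros tensor, which is closer to something a pointwise fusion could absorb.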
I see. To keep things simple, I'm now letting any add that involves a real tensor fall back to eager.
Hmm, OK. It looks like if we let adding a complex tensor and a real tensor fall back to eager, then dynamic-shapes codegen breaks. For example, the test below (my newly added test case) fails:
python test/inductor/test_torchinductor_codegen_dynamic_shapes.py -k DynamicShapesCodegenCpuTests.test_add_complex9_dynamic_shapes_cpu
If I add the decomposition back, it passes.
Is this expected? I think we still need a decomposition for adding a complex number and a real number?
cc @zou3519
Do these tests fail before your PR? (My guess is yes.) If so, then I'd just skip them.
Yes. They do. I'll remove these tests then.
Could you file some issues for them please?
Force-pushed from 2c84515 to d8618d1 (Compare)
Thank you for the fix!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Fixes #125745
Bug source: When addition requires broadcasting, adding complex numbers is not implemented correctly in torch/_inductor/decomposition.py, because x.view(x.real.dtype) multiplies the last dimension by 2, and broadcasting then fails.
Fix: reshape the complex tensors after the view and before broadcasting.
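A minimal reproduction of the broadcasting failure described above (shapes are illustrative):

```python
import torch

# Shapes (2, 1) and (3,) broadcast fine as complex tensors, but viewing each
# as real doubles the last dimension, producing (2, 2) and (6,), which do not
# broadcast.
x = torch.ones(2, 1, dtype=torch.complex64)
y = torch.ones(3, dtype=torch.complex64)

xr = x.view(torch.float32)  # shape (2, 2) -- last dim doubled
yr = y.view(torch.float32)  # shape (6,)

raised = False
try:
    _ = xr + yr  # (2, 2) + (6,): broadcasting fails
except RuntimeError:
    raised = True
assert raised

# The complex addition itself broadcasts without issue.
assert (x + y).shape == (2, 3)
```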
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang