
fix add decomposition for complex numbers #129044

Closed · wants to merge 9 commits

Conversation

yushangdi
Contributor

@yushangdi yushangdi commented Jun 19, 2024

Fixes #125745

Bug source: when addition requires broadcasting, adding complex numbers is implemented incorrectly in torch/_inductor/decomposition.py: x.view(x.real.dtype) doubles the size of the last dimension, after which the broadcast no longer works.

Fix: reshape the complex tensors after the view and before broadcasting.
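To illustrate the shape problem (a NumPy sketch of the same dtype reinterpretation, not the inductor decomposition itself): viewing a complex tensor through its real dtype doubles the last dimension, which is what breaks broadcasting, and a reshape in the spirit of the fix restores broadcastable leading dimensions by turning each complex element into a trailing (real, imag) pair.

```python
import numpy as np

x = np.ones((2, 3), dtype=np.complex64)

# Reinterpreting complex64 as float32 doubles the last dimension:
xr = x.view(np.float32)
assert xr.shape == (2, 6)  # no longer broadcast-compatible with (2, 3) operands

# The fix in spirit: reshape so each complex element becomes a trailing
# (real, imag) pair, keeping the leading dimensions broadcastable.
xr2 = xr.reshape(*x.shape, 2)
assert xr2.shape == (2, 3, 2)
```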

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang


pytorch-bot bot commented Jun 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129044

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6a94092 with merge base 920ebcc:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@yushangdi yushangdi self-assigned this Jun 20, 2024
@yushangdi yushangdi marked this pull request as ready for review June 20, 2024 18:46
@yushangdi yushangdi marked this pull request as draft June 20, 2024 19:57
@yushangdi yushangdi requested review from ezyang, lezcano, zou3519 and bdhirsh and removed request for ezyang, lezcano, zou3519 and bdhirsh June 20, 2024 19:57
@yushangdi yushangdi marked this pull request as ready for review June 20, 2024 20:33
    @@ -340,7 +340,32 @@ def add(x, y, *, alpha=None):
            if alpha is not None:
                z = alpha * y
            complex_type = torch.promote_types(x.dtype, y.dtype)
    -       return (x.view(x.real.dtype) + z.view(y.real.dtype)).view(complex_type)
    +       if x_is_complex_tensor and y_is_complex_tensor:
Collaborator

I see we bail out in the case when any of the tensors is real, so we can probably delete the branches here. Otherwise, I think that just adding an unsqueeze(-1) to the real tensor would make everything work, in case you want to implement that case.

Contributor

+1, delete branch or support the case (and add a test for it)

Contributor Author

Thanks! I added support for adding a complex tensor and a real tensor.

Comment on lines 365 to 367
tensor_unsqueezed = tensor.unsqueeze(-1)
zeros = torch.zeros_like(tensor_unsqueezed)
return torch.cat((tensor_unsqueezed, zeros), dim=-1)
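For context, a NumPy illustration of what this snippet builds (not the inductor code itself): interleaving a real tensor with zeros along a new last axis reproduces the (real, imag) memory layout of a complex tensor, so the result can be reinterpreted as that tensor promoted to complex.

```python
import numpy as np

t = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# Pair each element with a zero imaginary part: shape (3,) -> (3, 2).
interleaved = np.stack([t, np.zeros_like(t)], axis=-1)

# Reinterpreting those (real, imag) pairs as complex64 recovers t + 0j.
as_complex = interleaved.view(np.complex64).reshape(t.shape)
assert np.array_equal(as_complex, t.astype(np.complex64))
```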
Collaborator

Ah right. We have to do this. This is not going to fuse well, and we might even regress perf.
I can think of three ways of doing this:

  • Doing mask = (torch.arange(last_dim) % 2 == 0).view(last_dim // 2, 2) and broadcasting the real tensor and choosing the sum or this tensor depending on this mask.
  • A better way to do this at the moment would be via lowering with masked loads to avoid unnecessary compute
  • Long term, we would represent complex numbers as two tensors (real and imaginary part) and all this would be just trivial to do

All of these are a bit involved. Perhaps it's better to simply fall back to eager in this case and wait until someone tackles the general complex-support problem.
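A rough NumPy sketch of the first idea (hypothetical, and not the code that was merged): over the interleaved float view of the complex tensor, a mask of even positions selects the real slots, so the real operand is added only there while imaginary parts pass through unchanged.

```python
import numpy as np

x = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)
y = np.array([10.0, 20.0], dtype=np.float32)  # real operand

# Interleaved float view of x: [re, im, re, im, ...]
xr = x.view(np.float32)

# Even positions hold real parts; add y only there.
mask = np.arange(xr.shape[-1]) % 2 == 0
yb = np.repeat(y, 2)  # line y up with the interleaved layout
out = np.where(mask, xr + yb, xr).astype(np.float32).view(np.complex64)

assert np.allclose(out, x + y)  # (a + bi) + y == (a + y) + bi
```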

Contributor Author

I see. To keep things simple, I'm now letting any add that involves a real tensor fall back to eager.

Contributor Author

@yushangdi yushangdi Jun 24, 2024

Hmm, ok. It looks like if we let adding a complex tensor and a real tensor fall back to eager, then dynamic-shapes codegen breaks. For example, the test below fails (my newly added test case):

python test/inductor/test_torchinductor_codegen_dynamic_shapes.py -k DynamicShapesCodegenCpuTests.test_add_complex9_dynamic_shapes_cpu

If I add the decomposition back, it passes.

Is this expected? I think we still need a decomposition for adding a complex tensor and a real tensor?

cc @zou3519

Contributor

Do these tests fail before your PR? (My guess is yes). If so, then I'd just skip them

Contributor Author

Yes. They do. I'll remove these tests then.

Contributor

Could you file some issues for them please?

Collaborator

@lezcano lezcano left a comment

Thank you for the fix!

@lezcano
Collaborator

lezcano commented Jun 25, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 25, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

@yushangdi yushangdi deleted the inductor_complex_add branch June 26, 2024 18:43
Successfully merging this pull request may close these issues.

torch.compile error: Attempting to broadcast a dimension of length 2 at -1