-
Notifications
You must be signed in to change notification settings - Fork 21.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Inductor] Fix the High Order Op layout issue (#128275) #128834
Merged
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Fix the issue: #127995 - In current implementation of creating `FallbackKernel`, the `device` of the `NoneLayout` is set to `None` when `example_output` returns from `cls.process_kernel` is `None`. https://github.com/pytorch/pytorch/blob/921aa194c77f5279b15415eaa213813ddcdb3b29/torch/_inductor/ir.py#L5632-L5649 - If a `ExternalKernel schedulerNode` has None device, the previous buffer will not flush before codegen this `ExternalKernel schedulerNode` which causes the wrong generated code. https://github.com/pytorch/pytorch/blob/ef2b5ed500cba0b8b2bf04e6006a0d64c910f440/torch/_inductor/scheduler.py#L2701-L2709 **Test Plan** ``` python -u -m pytest -s -v test/higher_order_ops/test_with_effects.py -k test_compile_inductor_external_op_return_none ``` Pull Request resolved: #128275 Approved by: https://github.com/eellison
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128834
Note: Links to docs will display an error until the docs builds have been completed. ❗ 1 Active SEVsThere are 1 currently active SEVs. If your PR is affected, please view them below: ❌ 17 New FailuresAs of commit e41397d with merge base b66e3f0 (): NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
atalman
approved these changes
Jun 19, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Fix the issue: #127995
FallbackKernel
, thedevice
of theNoneLayout
is set toNone
whenexample_output
returns fromcls.process_kernel
isNone
.pytorch/torch/_inductor/ir.py
Lines 5632 to 5649 in 921aa19
ExternalKernel schedulerNode
has None device, the previous buffer will not flush before codegen thisExternalKernel schedulerNode
which causes the wrong generated code.pytorch/torch/_inductor/scheduler.py
Lines 2701 to 2709 in ef2b5ed
Test Plan
Pull Request resolved: #128275
Approved by: https://github.com/eellison
Fixes #ISSUE_NUMBER
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang