nondeterminism in torch.compile + custom op #127995
This looks like a regression from 2.3.0. It also appears to reproduce only with backend="inductor" for me (though the failure is non-deterministic).
This is possibly an Inductor mutation issue (cc @Chillee). If I could reproduce this more reliably, I would bisect.
@zou3519 will take a look.
A few data points:
(1) It fails on CPU, but not on CUDA.
(2) The generated code looks like this:
Where the "bug" appears to be that we are running a
(3) I diffed the GraphLowering when running the repro on CPU vs. CUDA, and one difference I see:
CPU:
CUDA:
One difference is that the
I can reliably reproduce this on my Mac. So far it looks like something between 4/7 and 5/1 caused it, so it's likely not @Chillee's recent Inductor mutation changes.
Bisection says a commit between 04/08 and 04/10 caused the regression. #122945 looks suspicious so far.
Testing with commits 22712ba and d3ad84c:
The generated code is without the last
I can now reproduce this random failure on CPU. There are 3 SchedulerNodes in total.
buf0's origins are:
buf0's origins are:
I think I found why the behavior is non-deterministic:
I guess we could add
@leslie-fang-intel, any updates from your side? This seems like a regression, so we should really fix it for PyTorch 2.4, but I'm not very familiar with Inductor internals.
Sure, I will investigate the behavior with PyTorch 2.3 further.
cc: @eellison
Some findings from this issue:
Generated code in PyTorch 2.3:
Generated code in PyTorch 2.4:
The cause seems to be related to buf0 (see pytorch/torch/_inductor/scheduler.py, lines 2701 to 2709 in ef2b5ed).
Fix the issue: #127995
- In the current implementation, the device of the `NoneLayout` is `None` when the `example_output` returned from `cls.process_kernel` is `None`. The test case reported in the issue is one where the `ExternalKernel` returns `None`. https://github.com/pytorch/pytorch/blob/921aa194c77f5279b15415eaa213813ddcdb3b29/torch/_inductor/ir.py#L5632-L5649
- If an `ExternalKernel` scheduler node has a `None` device, the previously buffered kernels are not flushed before this node is code-generated, which results in incorrect generated code. https://github.com/pytorch/pytorch/blob/ef2b5ed500cba0b8b2bf04e6006a0d64c910f440/torch/_inductor/scheduler.py#L2701-L2709
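To make the flush problem concrete, here is a toy Python model. It is not Inductor code: `ToyNode`, `toy_codegen`, and the buffering behaviour are hypothetical simplifications. It assumes that kernels are buffered until the next flush and that the flush before an extern call is guarded by a check on the node's device, so a node whose device is `None` skips the flush entirely.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for a scheduler node (not the Inductor class)."""
    name: str
    device: Optional[str]  # None models a NoneLayout whose device is None
    is_extern: bool = False


def toy_codegen(nodes: List[ToyNode]) -> List[str]:
    """Emit nodes in order, buffering fused kernels until a flush."""
    emitted: List[str] = []
    pending: List[str] = []  # kernels that have not been written out yet
    current_device: Optional[str] = None

    def flush() -> None:
        emitted.extend(pending)
        pending.clear()

    for node in nodes:
        device = node.device
        # Assumed guard: a None device skips this block, including the
        # flush that should happen before an extern call.
        if device:
            if device != current_device or node.is_extern:
                flush()
            current_device = device
        if node.is_extern:
            emitted.append(f"extern {node.name}")  # written out immediately
        else:
            pending.append(f"kernel {node.name}")  # buffered until a flush
    flush()
    return emitted


# Extern op with a None device: it is emitted before the kernel producing buf0.
print(toy_codegen([ToyNode("buf0", "cpu"), ToyNode("my_op", None, True), ToyNode("buf2", "cpu")]))
# ['extern my_op', 'kernel buf0', 'kernel buf2']

# Same graph with a real device on the extern op: the order is preserved.
print(toy_codegen([ToyNode("buf0", "cpu"), ToyNode("my_op", "cpu", True), ToyNode("buf2", "cpu")]))
# ['kernel buf0', 'extern my_op', 'kernel buf2']
```

In this model, giving the extern node a concrete device is enough to restore the flush, which matches the direction described in the commit message.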
Fix the issue: #127995
- In the current implementation of creating a `FallbackKernel`, the `device` of the `NoneLayout` is set to `None` when the `example_output` returned from `cls.process_kernel` is `None`. https://github.com/pytorch/pytorch/blob/921aa194c77f5279b15415eaa213813ddcdb3b29/torch/_inductor/ir.py#L5632-L5649
- If an `ExternalKernel` scheduler node has a `None` device, the previously buffered kernels are not flushed before this node is code-generated, which results in incorrect generated code. https://github.com/pytorch/pytorch/blob/ef2b5ed500cba0b8b2bf04e6006a0d64c910f440/torch/_inductor/scheduler.py#L2701-L2709

**Test Plan**
```
python -u -m pytest -s -v test/higher_order_ops/test_with_effects.py -k test_compile_inductor_external_op_return_none
```
Pull Request resolved: #128275
Approved by: https://github.com/eellison
Co-authored-by: leslie-fang-intel
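The commit message describes the problem (a `NoneLayout` with device `None`) rather than the exact code change. One plausible way to avoid a `None` device, sketched here purely as an assumption, is to fall back to the device of the op's tensor inputs when the op returns `None`. The helper name `choose_none_layout_device` and its string-based devices are hypothetical and do not reflect Inductor's actual API.

```python
from typing import Optional, Sequence


def choose_none_layout_device(
    example_output_device: Optional[str],
    tensor_arg_devices: Sequence[Optional[str]],
) -> Optional[str]:
    """Hypothetical helper: pick a device for an op's output layout."""
    # If the op produced an output, its device is authoritative.
    if example_output_device is not None:
        return example_output_device
    # The op returned None: use the first known input device instead of
    # leaving the layout device as None.
    for device in tensor_arg_devices:
        if device is not None:
            return device
    # No tensor inputs with a known device: nothing better to report.
    return None


print(choose_none_layout_device(None, ["cpu"]))   # cpu
print(choose_none_layout_device("cuda", ["cpu"]))  # cuda
```

The design choice here is to prefer the output's device and only consult inputs when there is no output, so ops that do return tensors are unaffected.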
Confirmed that this works on the 2.4.0 final RC.
Doesn't look good. The following sometimes succeeds and sometimes fails for me:
cc @ezyang @gchanan @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire