
nondeterminism in torch.compile + custom op #127995

Closed
Tracked by #130151
zou3519 opened this issue Jun 5, 2024 · 15 comments
Labels
high priority module: inductor module: reinplacing inductor reinplacing, re-inplacing, auto-functionalization, auto functionalization, custom op oncall: cpu inductor CPU Inductor issues for Intel team to triage oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
Milestone

Comments

@zou3519 (Contributor) commented Jun 5, 2024

Doesn't look good. The following sometimes succeeds and sometimes fails for me:

import torch
import numpy as np

lib = torch.library.Library("mylib", "FRAGMENT") 
lib.define("numpy_sin(Tensor input, Tensor(a!) output) -> ()")

def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
    assert input.device == output.device
    assert input.device.type == "cpu"
    input_np = input.numpy()
    output_np = output.numpy()
    np.sin(input_np, out=output_np)

lib.impl("numpy_sin", numpy_sin, "CPU")

numpy_sin = torch.ops.mylib.numpy_sin

@torch.compile(fullgraph=True)
def f(x):
    out = torch.empty(3)
    numpy_sin(x, out)
    return out

x = torch.randn(3)
y = f(x)
print(torch.__version__)
assert torch.allclose(y, x.sin())

cc @ezyang @gchanan @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire

@zou3519 (Contributor, Author) commented Jun 5, 2024

Looks like a regression from 2.3.0. Also appears to only repro with backend="inductor" for me (though it is non-deterministic)

@zou3519 zou3519 added module: inductor oncall: cpu inductor CPU Inductor issues for Intel team to triage labels Jun 5, 2024
@zou3519 (Contributor, Author) commented Jun 5, 2024

This is possibly an inductor mutation issue (cc @Chillee). If I could reproduce this more reliably then I would bisect

@Chillee (Contributor) commented Jun 5, 2024

@zou3519 will take a look

@bdhirsh (Contributor) commented Jun 5, 2024

a few datapoints:

(1) it fails on cpu, but not cuda

(2) the generated code looks like this:

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
    # Source Nodes: [], Original ATen: []
    buf1 = torch.ops.mylib.numpy_sin.default(arg0_1, buf0)
    del arg0_1
    cpp_fused_empty_0(buf0)
    return (buf0, )

Where the "bug" appears to be that we are running a cpp_fused_empty_0 that fills the tensor with zeros, after the custom op has run.
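The clobbering can be sketched in plain Python (the `numpy_sin_like` helper and the hand-rolled zero-fill loop are stand-ins for the custom op and the mis-scheduled `cpp_fused_empty_0` kernel, not Inductor's actual codegen):

```python
import math

def numpy_sin_like(inp, out):
    # stand-in for the custom op: writes sin(inp) in-place into `out`
    for i, v in enumerate(inp):
        out[i] = math.sin(v)

x = [1.0, 2.0, 3.0]
buf0 = [0.0] * 3

numpy_sin_like(x, buf0)        # the custom op fills the output buffer
# the mis-scheduled fill kernel then zero-fills the same buffer *after* the op:
for i in range(len(buf0)):
    buf0[i] = 0.0

print(buf0)  # [0.0, 0.0, 0.0] -- sin(x) has been clobbered
```

Run in the correct order (fill first, custom op second), the buffer would hold sin(x); run in the generated order, the caller gets zeros, which is exactly what the failing `allclose` assert observes.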

(3) I diff'd the GraphLowering when running the repro on cpu vs cuda, and one difference I see:

cpu:

[StorageBox(
  ComputedBuffer(name='buf0', layout=FixedLayout('cpu', torch.float32, size=[3], stride=[1]), data=Pointwise(
    'cpu',
    torch.float32,
    def inner_fn(index):
        i0 = index
        tmp0 = ops.constant(0.0, torch.float32)
        return tmp0
    ,
    ranges=[3],
    origin_node=full_default,
    origins={full_default}
  ))
)]

cuda:

[StorageBox(
  ComputedBuffer(name='buf0', layout=FixedLayout('cuda', torch.float32, size=[3], stride=[1]), data=Pointwise(
    'cuda',
    torch.float32,
    def inner_fn(index):
        i0 = index
        tmp0 = ops.constant(0, torch.float32)
        return tmp0
    ,
    ranges=[0],
    origin_node=empty,
    origins={empty}
  ))
)]

One difference is that the empty_strided lowering desugars into an aten.full(0) for cpu but not cuda. I'm also not sure why the full() lowering gets scheduled to run after the custom op.
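A hypothetical model (not Inductor's real IR) of why that difference matters for scheduling: lowering `empty` only allocates, while lowering `full` also emits a pointwise fill kernel that the scheduler must place somewhere:

```python
# Toy model: a "lowering" returns the buffer plus the list of compute
# kernels it emits. `empty` emits none (a no-op kernel node); `full`
# emits a pointwise fill that becomes a real SchedulerNode.
def lower_empty(n):
    buf = [None] * n
    kernels = []                              # allocation only
    return buf, kernels

def lower_full(n, value):
    buf = [None] * n
    kernels = [lambda b: [value] * len(b)]    # pointwise fill kernel
    return buf, kernels

print(len(lower_empty(3)[1]), len(lower_full(3, 0.0)[1]))  # 0 1
```

Under this model, desugaring `empty_strided` into `full(0)` on cpu is what creates the extra kernel that can end up scheduled after the custom op.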

@zou3519 (Contributor, Author) commented Jun 5, 2024

I can reliably repro this on my mac. So far it looks like something between 4/7 and 5/1 caused it. So it's likely not @Chillee's recent Inductor mutation changes.

@zou3519 (Contributor, Author) commented Jun 5, 2024

Bisection says a commit between 04/08 and 04/10 caused the regression.

#122945 looks suspicious so far

@leslie-fang-intel (Collaborator) commented Jun 6, 2024

Testing with commits 22712ba and d3ad84c, I didn't reproduce this issue on my local system. Probably it is, as @zou3519 mentions in the summary:

it's nondeterminism as sometimes succeeds and sometimes fails.

The generated code on my local system is missing the last cpp_fused_empty_0(buf0). I guess I need to find a way to reproduce this issue reliably.

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
    # Source Nodes: [], Original ATen: []
    buf1 = torch.ops.mylib.numpy_sin.default(arg0_1, buf0)
    del arg0_1
    return (buf0, )

@leslie-fang-intel (Collaborator) commented Jun 6, 2024

I can reproduce this random failure on CPU now.

There are 3 SchedulerNodes in total:

  • When this case passes on CPU
---- in realize self.data.name is: buf0
ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.float32, size=[3], stride=[1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      i0 = index
      tmp0 = ops.constant(0, torch.float32)
      return tmp0
  ,
  ranges=[3],
  origin_node=None,
  origins={empty}
))
---- codegen node node is: NopKernelSchedulerNode(name='buf0')-----
---- codegen node node is: ExternKernelSchedulerNode(name='buf1')-----
---- codegen node node is: NopKernelSchedulerNode(name='buf2')-----
2.4.0.dev20240603+cpu

buf0's origins is {empty}, and it generates a NopKernelSchedulerNode.

  • When this case fails on CPU
---- in realize self.data.name is: buf0
ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.float32, size=[3], stride=[1]), data=Pointwise(
  'cpu',
  torch.float32,
  def inner_fn(index):
      i0 = index
      tmp0 = ops.constant(0.0, torch.float32)
      return tmp0
  ,
  ranges=[3],
  origin_node=None,
  origins={full_default}
))
---- codegen node node is: SchedulerNode(name='buf0')-----
---- codegen node node is: ExternKernelSchedulerNode(name='buf1')-----
---- codegen node node is: NopKernelSchedulerNode(name='buf2')-----
2.4.0.dev20240603+cpu
Traceback (most recent call last):
  File "/ws1/zailiw/cases/pth_git_issues/127995.py", line 27, in <module>
    assert torch.allclose(y, x.sin())
AssertionError

buf0's origins is {full_default}, and it generates a SchedulerNode.

======================
Update

I guess I found why the behavior is non-deterministic:

  • The check and bool((t == t.flatten()[0]).all())
    tests whether the tensor returned from torch.ops.aten.empty.memory_format holds a single uniform value. The check is nondeterministic: it passes whenever torch.ops.aten.empty.memory_format happens to return a tensor whose (uninitialized) values are all the same, e.g. all zeros.
  • If the check passes, the UniformValueConstantFolder pass replaces torch.ops.aten.empty.memory_format with full, causing this test case to fail.

I guess we could add torch.ops.aten.empty.memory_format to the is_impure op list of constant folding.
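A plain-Python model of that uniform-value check (`looks_uniform` is a made-up name for the `(t == t.flatten()[0]).all()` expression quoted above, operating on a flat list of floats):

```python
def looks_uniform(t):
    # models bool((t == t.flatten()[0]).all()) on a flat list of floats
    return all(v == t[0] for v in t)

# aten.empty returns uninitialized memory, so both outcomes are possible:
lucky = [0.0, 0.0, 0.0]          # memory happened to be all zeros -> folded
unlucky = [0.0, 1.5e-38, 0.0]    # leftover garbage -> not folded

print(looks_uniform(lucky), looks_uniform(unlucky))  # True False
```

Since the contents of uninitialized memory vary run to run, the same program folds the `empty` into a `full` on some runs and not on others, which matches the observed flakiness.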

@zou3519 (Contributor, Author) commented Jun 6, 2024

This bisected to #122347 (cc @angelayi) but since it is non-deterministic that might not be right; I don't think that PR is related

@zou3519 (Contributor, Author) commented Jun 7, 2024

@leslie-fang-intel any updates from your side? This seems like a regression, so we should really fix it for PyTorch 2.4. But I'm not very familiar with Inductor internals

@zou3519 zou3519 added the module: reinplacing inductor reinplacing, re-inplacing, auto-functionalization, auto functionalization, custom op label Jun 7, 2024
@leslie-fang-intel (Collaborator) commented Jun 7, 2024

@leslie-fang-intel any updates from your side? This seems like a regression, so we should really fix it for PyTorch 2.4. But I'm not very familiar with Inductor internals

Sure, I will further investigate the behavior with PyTorch 2.3.

@Chillee (Contributor) commented Jun 8, 2024

cc: @eellison

@leslie-fang-intel (Collaborator) commented Jun 8, 2024

Some findings from this issue:

  • Issue 1: torch.ops.aten.empty.memory_format gets replaced with full (which is nondeterministic), as described in nondeterminism in torch.compile + custom op #127995 (comment)

  • Issue 2: with PyTorch 2.3, even if we replace torch.ops.aten.empty.memory_format with full, the test still passes. Here is the generated code in PyTorch 2.3 and PyTorch 2.4:

Generated code in PyTorch 2.3

  def call(args):
      arg0_1, = args
      args.clear()
      assert_size_stride(arg0_1, (3, ), (1, ))
      buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
      cpp_fused_empty_0(buf0)
      # Source Nodes: [], Original ATen: []
      buf1 = torch.ops.mylib.numpy_sin.default(arg0_1, buf0)
      del arg0_1
      return (buf0, )

Generated code in PyTorch 2.4

def call(args):
    arg0_1, = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    buf0 = empty_strided_cpu((3, ), (1, ), torch.float32)
    # Source Nodes: [], Original ATen: []
    buf1 = torch.ops.mylib.numpy_sin.default(arg0_1, buf0)
    del arg0_1
    cpp_fused_empty_0(buf0)
    return (buf0, )

The cause seems to be that in PyTorch 2.4 the cpp_fused_empty_0(buf0) kernel for buf0 is not flushed before we codegen the higher-order op:

if not isinstance(node, NopKernelSchedulerNode) and (
    device := node.get_device()
):
    if (
        device != self.current_device
        or node.is_extern()
        or node.is_template()
    ):
        self.flush()

  • get_device of the higher-order op's extern scheduler node returns None in PyTorch 2.4 but cpu in PyTorch 2.3

    • @zou3519 I think I need some help here; I am not sure why the device in the layout of the higher-order op changes. PyTorch 2.3 returns a layout of MultiOutputLayout(device=device(type='cpu')), while PyTorch 2.4 returns a layout of <torch._inductor.ir.NoneLayout object at 0x7ff7fd6bdae0>.
  • After [effects] Add inductor support for tokens #122347, we will not flush if the device is None.
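Putting the two observations together, the flush decision can be modeled as a sketch (`should_flush` is a made-up name condensing the scheduler snippet quoted above; the walrus expression `device := node.get_device()` is falsy when the device is None, so the whole flush branch is skipped for such nodes):

```python
def should_flush(node_device, current_device, is_extern=False,
                 is_template=False, is_nop=False):
    # mirrors: `not isinstance(node, NopKernelSchedulerNode) and
    # (device := node.get_device())` -- a None device short-circuits
    # the condition, so the flush branch is never reached
    if is_nop or node_device is None:
        return False
    return node_device != current_device or is_extern or is_template

# PyTorch 2.3: the extern node for the custom op reports "cpu" -> flush
print(should_flush("cpu", "cpu", is_extern=True))  # True
# PyTorch 2.4: NoneLayout makes get_device() return None -> no flush
print(should_flush(None, "cpu", is_extern=True))   # False
```

Under this model, the pending cpp_fused_empty_0 kernel stays buffered past the extern kernel in 2.4 and is emitted after it, which matches the wrong ordering in the generated code.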

leslie-fang-intel added a commit that referenced this issue Jun 8, 2024
Fix the issue: #127995

- In the current implementation, the device of the `NoneLayout` is `None` when the `example_output` returned from `cls.process_kernel` is `None`. The test reported in the issue is the case where the `ExternKernel` returns None. https://github.com/pytorch/pytorch/blob/921aa194c77f5279b15415eaa213813ddcdb3b29/torch/_inductor/ir.py#L5632-L5649
- If an `ExternKernel` SchedulerNode has a None device, the previous buffers are not flushed before codegen of this `ExternKernel` SchedulerNode, which causes the wrong generated code.
https://github.com/pytorch/pytorch/blob/ef2b5ed500cba0b8b2bf04e6006a0d64c910f440/torch/_inductor/scheduler.py#L2701-L2709

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang
@soulitzer soulitzer added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Jun 11, 2024
pytorchmergebot pushed a commit that referenced this issue Jun 15, 2024
Fix the issue: #127995
Pull Request resolved: #128275
Approved by: https://github.com/eellison
zou3519 pushed a commit that referenced this issue Jun 17, 2024
Fix the issue: #127995
Pull Request resolved: #128275
Approved by: https://github.com/eellison
atalman pushed a commit that referenced this issue Jun 19, 2024
Fix the issue: #127995
Co-authored-by: leslie-fang-intel <[email protected]>
@atalman (Contributor) commented Jul 9, 2024

Confirmed: works on the 2.4.0 final RC.
