torch.compile crash - Aborted exit code 134 #125804
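For context (an addition of mine, not stated in the thread): exit code 134 means the process was killed by SIGABRT, because shells report a process killed by fatal signal N as exit status 128 + N, and SIGABRT is signal 6 on Linux. A minimal check:

```python
import signal

# Shells encode "killed by signal N" as exit status 128 + N.
# SIGABRT (raised by abort(), e.g. from C++ std::terminate) is
# signal 6 on Linux, so an "Aborted" process exits with 128 + 6 = 134.
ABORT_EXIT_CODE = 128 + int(signal.SIGABRT)
print(ABORT_EXIT_CODE)  # 134
```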
Comments
If I try to run the step in https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py#L258-L265 manually, the crash happens when the process exits.

This can be reproduced on devgpu with CUDA 12.1 and Python 3.11.

---

I am not sure if the

---

When running the script, I'm getting

---

hi @xmfan

Run the following python script: pytorch/builder@main/test/smoke_test/smoke_test.py
@anijain2305 we also need TARGET_OS="linux"

---

I can't repro the issue on CUDA 12.1, Python 3.10.14 with 2f53747. How do I run this on CI? @atalman

---

After isolating the script to only try torch.compile:

```diff
diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py
index 3e95c79..8e93cb6 100644
--- a/test/smoke_test/smoke_test.py
+++ b/test/smoke_test/smoke_test.py
@@ -167,6 +167,7 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
         (target_os == "linux" and torch.cuda.is_available()) or
         target_os == "macos-arm64"):
         smoke_test_compile()
+        return
 
     if torch.cuda.is_available():
         if torch.version.cuda != gpu_arch_ver:
@@ -314,14 +315,14 @@ def main() -> None:
     print(f"torch: {torch.__version__}")
     check_version(options.package)
-    smoke_test_conv2d()
-    test_linalg()
-    test_numpy()
-    if is_cuda_system:
-        test_linalg("cuda")
-
-    if options.package == "all":
-        smoke_test_modules()
+    # smoke_test_conv2d()
+    # test_linalg()
+    # test_numpy()
+    # if is_cuda_system:
+    #     test_linalg("cuda")
+
+    # if options.package == "all":
+    #     smoke_test_modules()
     smoke_test_cuda(options.package, options.runtime_error_check)
```

I do not see the IMA.
I can repro some segfault, but it comes from test_cuda_runtime_errors_captured, and there is no core dump if I comment that out.

---

I can't see the "Run workflow" button on the builder repo (permissions?). I wonder if this has to do with A100 vs other hardware.

---

Stack trace: https://gist.github.com/xmfan/d2dcddda2f042df35832992753e3df34

cc @eellison @BoyuanFeng could you guys help take a look
Mini repro:

```python
import torch
from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28, 1)

    def forward(self, x):
        output = self.fc1(x)
        return output

x = torch.rand(28, 28, device="cuda")
model = Net().to(device="cuda")
x_pt2 = torch.compile(model, mode="max-autotune")(x)

try:
    torch._assert_async(torch.tensor(0, device="cuda"))
except:
    print("ignoring exception")
# check for `Aborted (core dumped)` on process exit
```
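One way to check for the abort programmatically (a sketch I am adding, not part of the original repro) is to run the script in a subprocess and inspect the return code: on POSIX systems, Python's `subprocess` reports death by signal N as return code -N, so a SIGABRT shows up as -6.

```python
import subprocess
import sys

def died_from_sigabrt(code: str) -> bool:
    """Run `code` in a fresh interpreter and report whether it was
    killed by SIGABRT (the `Aborted (core dumped)` case)."""
    proc = subprocess.run([sys.executable, "-c", code])
    # subprocess encodes death by signal N as returncode -N; SIGABRT is 6.
    return proc.returncode == -6

# Illustration with os.abort(), which raises SIGABRT directly:
print(died_from_sigabrt("import os; os.abort()"))  # True
print(died_from_sigabrt("print('clean exit')"))    # False
```

For the actual check, one would pass the mini repro script instead of the `os.abort()` one-liner.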
@xmfan this repros as just the `torch._assert_async` call; `_assert_async` is checking for a non-zero tensor.

---

That's intentional for the smoke test; the issue isn't the exception itself, but the abort on process exit. Updating the repro to ignore the exception.
Note: I had to use a nightly build to repro the issue consistently: 2.4.0.dev20240521+cu121
Confirmed that reverting that PR fixes the problem; something to do with the new destructor logic, I haven't looked into it closely. @eee4017 can you look into it?
I am willing to investigate this issue; however, my recent commitments have limited my availability. I will be able to dedicate time to this over the weekend, which may delay the process. If the issue is urgent, perhaps you could begin looking into it beforehand?
This weekend is fine. The release branch cut is M3.1 (6/10/24). I am taking off the later part of this week and won't be able to look before then.
Hello @eellison, the bug appears to be caused by the CUDA stream status check located in the destructor of the Graph here; I added that check. The error it triggers is in the stack trace:
@eee4017 yeah, let's change back to the previous warning for the destructor.
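The shape of that fix can be sketched in pure Python (a hypothetical analogue I am adding, not the actual C++ CUDAGraph code): in C++, an exception escaping a destructor calls std::terminate and aborts the process, so a status check that can fail during teardown, e.g. after a sticky device-side assert, must be downgraded to a warning.

```python
import warnings

class GraphLike:
    """Toy analogue of a CUDAGraph-style object whose teardown runs a
    status check. Names and behavior here are illustrative only."""

    def _check_stream_status(self):
        # Simulate the check failing after a sticky device-side error.
        raise RuntimeError("CUDA error: device-side assert triggered")

    def __del__(self):
        # The fix: never let the check's failure escape the destructor;
        # downgrade it to a warning so teardown cannot abort the process.
        try:
            self._check_stream_status()
        except RuntimeError as e:
            warnings.warn(f"ignoring error during graph teardown: {e}")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    g = GraphLike()
    del g  # refcount hits zero, so __del__ runs immediately in CPython
print(len(caught))  # 1
```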
Remove cuda check in the CUDAGraph destructor (#127382)
Fixes #125804
Pull Request resolved: #127382
Approved by: https://github.com/eqy, https://github.com/eellison
(cherry picked from commit d3e8b8b)
Co-authored-by: Frank Lin <[email protected]>
Confirmed fix on the 2.4 final RC: https://github.com/pytorch/builder/actions/runs/9841936763/job/27169815737#step:12:5409
---

🐛 Describe the bug

Mini repro:

Stack trace: https://gist.github.com/xmfan/d2dcddda2f042df35832992753e3df34

Original description:

Repro:
1. Install torch with CUDA 11.8 or 12.1 for Python 3.8-3.11
2. Set global vars:
3. Run the following python script: https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py

Failure:

Failure workflow:
https://github.com/pytorch/builder/actions/runs/9007921255/job/24748864568#step:11:4337

If I comment out this line, no failure is observed:

`x_pt2 = torch.compile(model, mode="max-autotune")(x)`

https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py#L265C5-L265C57

Started happening on: 2.4.0.dev20240327
This nightly commit: 384cbf2

Versions

nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mcarilli @eellison @peterbell10 @bdhirsh @anijain2305 @chauhang @jansel