
torch.compile crash - Aborted exit code 134 #125804

Closed
Tracked by #130151
atalman opened this issue May 8, 2024 · 21 comments
Labels: high priority; module: cuda graphs (ability to capture and then replay streams of CUDA kernels); oncall: pt2; triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

@atalman (Contributor) commented May 8, 2024

🐛 Describe the bug

Minirepro:

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28, 1)

    def forward(self, x):
        output = self.fc1(x)
        return output


x = torch.rand(28, 28, device="cuda")
model = Net().to(device="cuda")
x_pt2 = torch.compile(model, mode="max-autotune")(x)
try:
    torch._assert_async(torch.tensor(0, device="cuda"))
except:
    print("ignoring exception")
    
# check for `Aborted (core dumped)` on process exit
(segfault) [16:07:08] ~/builder/test/smoke_test (main) > python minirepro.py
/home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:133: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
../aten/src/ATen/native/cuda/TensorCompare.cu:106: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
Aborted (core dumped)

stack trace: https://gist.github.com/xmfan/d2dcddda2f042df35832992753e3df34

#0  0x00007ffff7c8b94c in __pthread_kill_implementation () from /lib64/libc.so.6
#1  0x00007ffff7c3e646 in raise () from /lib64/libc.so.6
#2  0x00007ffff7c287f3 in abort () from /lib64/libc.so.6
#3  0x00007ffff66b135a in __cxxabiv1::__terminate (handler=<optimized out>) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_terminate.cc:48
#4  0x00007ffff66b03b9 in __cxa_call_terminate (ue_header=0x14225d90) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_call.cc:54
#5  0x00007ffff66b0ae7 in __cxxabiv1::__gxx_personality_v0 (version=<optimized out>, actions=6, exception_class=5138137972254386944, ue_header=0x14225d90, context=<optimized out>) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libstdc++-v3/libsupc++/eh_personality.cc:685
#6  0x00007ffff74f51e4 in _Unwind_RaiseException_Phase2 (exc=0x14225d90, context=0x7fffffffbdb0, frames_p=0x7fffffffbcb8) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libgcc/unwind.inc:64
#7  0x00007ffff74f5c1e in _Unwind_Resume (exc=0x14225d90) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libgcc/unwind.inc:241
#8  0x00007fffa90015fb in at::CUDAGeneratorState::unregister_graph(at::cuda::CUDAGraph*) [clone .cold] () from /home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so
#9  0x00007fffa91e1e3c in at::cuda::CUDAGraph::~CUDAGraph() () from /home/xmfan/.conda/envs/segfault/lib/python3.10/site-packages/torch/lib/libtorch_cuda.so

Original description:

Repro:
Install torch with CUDA 11.8 or 12.1 for Python 3.8-3.11.
Set global vars:

MATRIX_GPU_ARCH_VERSION=12.1
MATRIX_GPU_ARCH_TYPE=cuda

Run the following Python script: https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py

python smoke_test.py --package torchonly

Failure:

True
Testing smoke_test_compile with mode 'max-autotune' for torch.float32
torch cuda: 12.1
torch cudnn: 8902
cuDNN enabled? True
torch nccl version: (2, 20, 5)
Testing test_cuda_runtime_errors_captured
../aten/src/ATen/native/cuda/TensorCompare.cu:106: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
Caught CUDA exception with success: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Aborted (core dumped) 

Failure workflow:
https://github.com/pytorch/builder/actions/runs/9007921255/job/24748864568#step:11:4337

If I comment out this line, no failure is observed:
x_pt2 = torch.compile(model, mode="max-autotune")(x)
https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py#L265C5-L265C57

Started happening on:
2.4.0.dev20240327

This nightly commit:
384cbf2

Versions

nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @mcarilli @eellison @peterbell10 @bdhirsh @anijain2305 @chauhang @jansel

@atalman changed the title from "torch.compile segfault" to "torch.compile crash on exception" May 8, 2024
@atalman changed the title from "torch.compile crash on exception" to "torch.compile crash - error code 134" May 8, 2024
@atalman changed the title from "torch.compile crash - error code 134" to "torch.compile crash - Aborted exit code 134" May 8, 2024
@huydhn (Contributor) commented May 8, 2024

If I try to run the step in https://github.com/pytorch/builder/blob/main/test/smoke_test/smoke_test.py#L258-L265 manually, the crash happens when the process exits:

Exception ignored in: <function ExactWeakKeyDictionary.__setitem__.<locals>.<lambda> at 0x7fc5c565bba0>
Traceback (most recent call last):
  File "/data/users/huydo/conda/py3.11/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 474, in <lambda>
    self.refs[idx] = weakref.ref(key, lambda ref: self._remove_id(idx))
                                                  ^^^^^^^^^^^^^^^^^^^^
  File "/data/users/huydo/conda/py3.11/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 752, in _remove_id
    hook()
  File "/data/users/huydo/conda/py3.11/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 736, in _call_
    del self.scope[self.name]
        ~~~~~~~~~~^^^^^^^^^^^
KeyError: '__builtins_dict___0'

This can be reproduced on devgpu with CUDA 12.1 and python 3.11.
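
For context on that traceback (a minimal standalone sketch, not dynamo's actual code; names are illustrative): exceptions raised inside weakref callbacks are routed to sys.unraisablehook and printed as "Exception ignored in ...", so this KeyError is noisy but by itself should not abort the process.

import weakref

# stand-ins for dynamo's scope dict and tracked key (illustrative names only)
scope = {"__builtins_dict___0": object()}

class Key:
    pass

key = Key()
# like ExactWeakKeyDictionary, register a callback that cleans up `scope`
# when `key` is garbage collected
ref = weakref.ref(key, lambda r: scope.pop("__builtins_dict___0"))

scope.clear()   # the entry is already gone by the time the callback runs
del key         # the callback fires and raises KeyError
# Python prints "Exception ignored in: <function <lambda> ...>" and continues:
print("process still alive")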

@anijain2305 (Contributor) commented:

I am not sure if the __builtins_dict___0 error is the right direction. I tried this on CUDA 12.0 (by changing the env variable) and Python 3.10; it does not call smoke_test_compile and still fails with the IMA error.

@xmfan (Member) commented May 21, 2024

When running the script, I'm getting RuntimeError: Wrong CUDA version. Loaded: 12.1 Expected: None. I have 12.1 installed:

> nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Mon_Apr__3_17:16:06_PDT_2023
Cuda compilation tools, release 12.1, V12.1.105
Build cuda_12.1.r12.1/compiler.32688072_0

@atalman (Contributor, author) commented May 21, 2024

Hi @xmfan,
Set global vars:

MATRIX_GPU_ARCH_VERSION=12.1
MATRIX_GPU_ARCH_TYPE=cuda

Run the following Python script: pytorch/builder@main/test/smoke_test/smoke_test.py

python smoke_test.py --package torchonly

@xmfan (Member) commented May 21, 2024

@anijain2305 we also need TARGET_OS="linux"

I can't repro the issue on 12.1, Python 3.10.14, with 2f53747. How do I run this on CI? @atalman

After isolating the script to only try torch.compile:

diff --git a/test/smoke_test/smoke_test.py b/test/smoke_test/smoke_test.py
index 3e95c79..8e93cb6 100644
--- a/test/smoke_test/smoke_test.py
+++ b/test/smoke_test/smoke_test.py
@@ -167,6 +167,7 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
         (target_os == "linux" and torch.cuda.is_available()) or
         target_os == "macos-arm64"):
         smoke_test_compile()
+        return
 
     if torch.cuda.is_available():
         if torch.version.cuda != gpu_arch_ver:
@@ -314,14 +315,14 @@ def main() -> None:
     print(f"torch: {torch.__version__}")
 
     check_version(options.package)
-    smoke_test_conv2d()
-    test_linalg()
-    test_numpy()
-    if is_cuda_system:
-        test_linalg("cuda")
-
-    if options.package == "all":
-        smoke_test_modules()
+    # smoke_test_conv2d()
+    # test_linalg()
+    # test_numpy()
+    # if is_cuda_system:
+    #     test_linalg("cuda")
+
+    # if options.package == "all":
+    #     smoke_test_modules()
 
     smoke_test_cuda(options.package, options.runtime_error_check)

I do not see the IMA

(benchmarks) [13:24:06] ~/builder/test/smoke_test (main) > python smoke_test.py --package torchonly
torch: 2.4.0a0+git2f53747
Skip version check for channel None as stable version is None
Testing smoke_test_compile for torch.float16
False
Testing smoke_test_compile for torch.float32
True
Testing smoke_test_compile for torch.float64
True
Testing smoke_test_compile with mode 'max-autotune' for torch.float32
(benchmarks) [13:24:34] ~/builder/test/smoke_test (main) > 

@xmfan (Member) commented May 21, 2024

I can repro a crash, but it comes from test_cuda_runtime_errors_captured:

(segfault) [14:24:57] ~/builder/test/smoke_test (main) > python smoke_test.py --package torchonly
torch: 2.4.0.dev20240521+cu121
Skip version check for channel None as stable version is None
Testing smoke_test_conv2d
Testing smoke_test_conv2d with cuda
Testing smoke_test_conv2d with cuda for torch.float16
Testing smoke_test_conv2d with cuda for torch.float32
Testing smoke_test_conv2d with cuda for torch.float64
Testing smoke_test_linalg on cpu
Testing smoke_test_linalg on cuda
Testing smoke_test_linalg with cuda for torch.float32
Testing smoke_test_linalg with cuda for torch.float64
Testing smoke_test_compile for torch.float16
False
Testing smoke_test_compile for torch.float32
True
Testing smoke_test_compile for torch.float64
True
Testing smoke_test_compile with mode 'max-autotune' for torch.float32
torch cuda: 12.1
torch cudnn: 8902
cuDNN enabled? True
torch nccl version: (2, 20, 5)
Testing test_cuda_runtime_errors_captured
../aten/src/ATen/native/cuda/TensorCompare.cu:106: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `input[0] != 0` failed.
Caught CUDA exception with success: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Aborted (core dumped)

There is no core dump if I comment that out.

@xmfan (Member) commented May 21, 2024

I can't see the "Run workflow" button on the builder repo (permissions?). I wonder if this has to do with A100 vs. other hardware.

@xmfan added the "module: cuda graphs" label May 21, 2024
@xmfan (Member) commented May 21, 2024

Stack trace: https://gist.github.com/xmfan/d2dcddda2f042df35832992753e3df34
This happens during the CUDAGraph destructor, and can also be repro'd by changing max-autotune to reduce-overhead.

cc @eellison @BoyuanFeng, could you help take a look?

@xmfan (Member) commented May 21, 2024

Mini repro:

import torch
from torch import nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28, 1)

    def forward(self, x):
        output = self.fc1(x)
        return output


x = torch.rand(28, 28, device="cuda")
model = Net().to(device="cuda")
x_pt2 = torch.compile(model, mode="max-autotune")(x)
try:
    torch._assert_async(torch.tensor(0, device="cuda"))
except:
    print("ignoring exception")
    
# check for `Aborted (core dumped)` on process exit

@eellison (Contributor) commented:

@xmfan this repros as just:

import torch
torch._assert_async(torch.tensor(0, device="cuda"))

_assert_async checks that the tensor is non-zero.
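
A note on the mechanics (a minimal sketch, not from the thread; it requires a CUDA GPU): the assert runs inside a kernel, so the failure is reported asynchronously at a later CUDA call, and once it fires the context holds a sticky error that every subsequent CUDA call, including those made from C++ destructors at interpreter shutdown, will report again.

import torch

t = torch.tensor(0, device="cuda")
try:
    torch._assert_async(t)       # enqueues a kernel asserting input[0] != 0
    torch.cuda.synchronize()     # the device-side assert is reported here
except RuntimeError as e:
    print("caught:", e)          # "CUDA error: device-side assert triggered"

# From this point on the CUDA context is poisoned: later CUDA calls keep
# raising the same error. Running with CUDA_LAUNCH_BLOCKING=1 makes the error
# surface at the offending launch instead of at a later call.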

@xmfan (Member) commented May 22, 2024

That's intentional for the smoke test. The issue isn't the exception, but the Aborted (core dumped) on process exit, which is not present if we run the script with other torch.compile modes.

I'm updating the repro to ignore the exception.

@eellison (Contributor) commented:

Root cause is likely #114068. @eee4017, are you able to submit a fix? (I will take a look if I don't hear back.)

@xmfan (Member) commented May 22, 2024

Note: I had to use nightly to repro the issue consistently: 2.4.0.dev20240521+cu121

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 --force-reinstall

@eee4017 (Contributor) commented May 22, 2024

> Root cause is likely #114068. @eee4017, are you able to submit a fix? (I will take a look if I don't hear back.)

Is this issue related to supporting the capture of generator state exchange within torch.compile, or does the problem stem from a corrupted structure within CUDAGraph?

@eellison (Contributor) commented:

Confirmed that reverting that PR fixes the problem. Something to do with the new destructor logic; I haven't looked into it closely. @eee4017, can you look into it?

@eee4017 (Contributor) commented May 22, 2024

I am willing to investigate this issue; however, my recent commitments have limited my availability. I will be able to dedicate time to this matter over the weekend, which may delay the process. If the issue is urgent, perhaps you could begin looking into it beforehand?

@mlazos added the "triaged" label May 22, 2024
@eellison (Contributor) commented:

This weekend is fine. The release cut is M3.1: Release branch cut (6/10/24). I am taking off the later part of this week and won't be able to look before then.

@eee4017 (Contributor) commented May 25, 2024

Hello @eellison,

The bug appears to be caused by the CUDA stream status check located in the destructor of the Graph here. That stream status check was added by me. The error triggered by _assert_async_cuda_kernel is caught by the C10_CUDA_CHECK in currentStreamCaptureStatusMayInitCtx. In the previous implementation of the destructor we used only C10_CUDA_CHECK_WARN, so it did not trigger an abort. Should the destructor avoid triggering a CUDA abort? Could we just remove this particular check from the destructor to prevent potential issues?

Stack trace:

terminate called after throwing an instance of 'c10::Error'
what():  CUDA error: device-side assert triggered  

...
#8  0x00007ffff78aec1e in _Unwind_Resume (exc=0x7e2fe0) at /opt/conda/conda-bld/gcc-compiler_1654084175708/work/gcc/libgcc/unwind.inc:241
#9  0x00007fffce630a14 in at::CUDAGeneratorState::unregister_graph (this=0x59ddd30, graph=0x8122c70) at pytorch/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp:156
#10 0x00007fffce65074b in at::cuda::CUDAGraph::~CUDAGraph (this=0x8122c70, __in_chrg=<optimized out>) at pytorch/aten/src/ATen/cuda/CUDAGraph.cpp:326
#11 0x00007fffee2e9d46 in std::_Sp_counted_ptr<at::cuda::CUDAGraph*, (__gnu_cxx::_Lock_policy)2>::_M_dispose (this=0x82eb8f0) at /usr/include/c++/11/bits/shared_ptr_base.h:348
#12 0x00007fffed249eac in std::_Sp_counted_base<(__gnu_cxx::_Lock_policy)2>::_M_release (this=0x82eb8f0) at /usr/include/c++/11/bits/shared_ptr_base.h:168
#13 0x00007fffed2436a1 in std::__shared_count<(__gnu_cxx::_Lock_policy)2>::~__shared_count (this=0x7fff727d30d0, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:705
#14 0x00007fffee2e5dc4 in std::__shared_ptr<at::cuda::CUDAGraph, (__gnu_cxx::_Lock_policy)2>::~__shared_ptr (this=0x7fff727d30c8, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr_base.h:1154
#15 0x00007fffee2e5de4 in std::shared_ptr<at::cuda::CUDAGraph>::~shared_ptr (this=0x7fff727d30c8, __in_chrg=<optimized out>) at /usr/include/c++/11/bits/shared_ptr.h:122
#16 0x00007fffee2e5e39 in pybind11::class_<at::cuda::CUDAGraph, std::shared_ptr<at::cuda::CUDAGraph> >::dealloc (v_h=...) at pytorch/third_party/pybind11/include/pybind11/pybind11.h:1939
#17 0x00007fffed23d36f in pybind11::detail::clear_instance (self=0x7fff727d30b0) at pytorch/third_party/pybind11/include/pybind11/detail/class.h:421
#18 0x00007fffed23d4a9 in pybind11::detail::pybind11_object_dealloc (self=0x7fff727d30b0) at pytorch/third_party/pybind11/include/pybind11/detail/class.h:454

@eellison (Contributor) commented:

@eee4017 yes, let's change back to the previous warning for the destructor.

@atalman (Contributor, author) commented Jul 9, 2024