`torch.compile` with `reduce-overhead`: very long compile time + GPU memory continuously grows #128424
Comments
This is probably the same as #119640. reduce-overhead is not magic fairy dust: it works by doing CUDA graphs, and CUDA graphs do not work with dynamic shapes, so we CUDA-graph each dynamic shape individually. This can end up using a lot of memory. To reduce memory usage, you will need to do some padding to multiples. Or you can rearchitect your prefill/decode so that it is CUDA graph friendly, as was done in gpt-fast.
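A minimal sketch of the padding-to-multiples idea mentioned above, using a hypothetical helper `pad_to_multiple` (not from this thread; a real decoder would also need an attention mask so the pad positions are ignored):

```python
import torch

def pad_to_multiple(input_ids: torch.Tensor, multiple: int = 128,
                    pad_id: int = 0) -> torch.Tensor:
    """Pad the sequence dimension up to the next multiple of `multiple`,
    so the compiled model sees at most ceil(max_len / multiple) distinct
    shapes instead of one shape per decoding step."""
    seq_len = input_ids.shape[-1]
    padded_len = -(-seq_len // multiple) * multiple  # ceiling division
    if padded_len == seq_len:
        return input_ids
    pad = input_ids.new_full(
        (*input_ids.shape[:-1], padded_len - seq_len), pad_id)
    return torch.cat([input_ids, pad], dim=-1)
```

With `multiple=128` and `max_len=4096`, this caps the number of captured CUDA graphs at 32 rather than 4096.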
@ezyang Besides my support for more love for padding multiples in nested-tensor constructors (e.g. #65156) and more in-placing/out-variants (e.g. for torch.cat), it could also be cool to have more introspection into the Inductor compiler cache / CUDA graph cache. E.g. if there were a way to list all cached shape/dtype specializations from Python, it would be easier to diagnose/confirm this sort of problem (e.g. the list would grow over time), plus maybe some higher-level metric on memory fragmentation, or more examples of memory allocator stats. E.g. could one enable coarser memory allocator segment sizes without recompiling torch? (This could go along with fully-fledged support for customized/reconfigured memory allocators.)
Thanks @ezyang! OK, I guess we have to use the workarounds you mentioned, like padding. I agree with what @vadimkantorov mentioned about a way to investigate this cache stuff. (Probably it's already possible with …)
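There is no per-CUDA-graph listing exposed today as far as I know, but here is a sketch of the introspection that does already exist (recompile logging plus allocator statistics), as a partial answer to the request above:

```python
import torch

# Log each recompilation and the guard that triggered it.
torch._logging.set_logs(recompiles=True)

# ... run the compiled generate loop here ...

# Allocator-level counters: a steadily growing reserved/segment count
# across iterations hints that shape specializations are accumulating.
stats = torch.cuda.memory_stats()
print("allocated bytes:", stats["allocated_bytes.all.current"])
print("reserved bytes: ", stats["reserved_bytes.all.current"])
print("segments:       ", stats["segment.all.current"])

# Per-segment detail, useful for spotting fragmentation.
for seg in torch.cuda.memory_snapshot():
    print(seg["device"], seg["segment_type"], seg["total_size"])
```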
One other thing: CUDA has a driver-level issue where CUDA graphs take a lot of memory on device (64 KB per kernel). That is fixed on CUDA 12.4 and driver 550+.
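For a sense of scale under that pre-12.4 behavior: a decode graph with, say, ~200 kernels (an assumed figure, not measured in this thread) would cost roughly 200 × 64 KB ≈ 12.5 MB of driver-side memory per captured graph, so a thousand captured sequence lengths would add on the order of 12 GB before the fix.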
Confirmed that keeping all tensors (not just those in the arguments of the top-level …) … Feel free to close this issue if you think we could.
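The comment above is truncated, but keeping tensor addresses stable across calls relates to an existing knob; a hedged sketch, assuming `torch._dynamo.mark_static_address` is available in your build:

```python
import torch

# A persistent buffer reused across decode steps (e.g. a KV cache).
kv_cache = torch.zeros(1, 16, 4096, 64, device="cuda")

# Promise the compiler this tensor's address never changes, so CUDA
# graph capture can bake the address in instead of copying the input
# on every replay.
torch._dynamo.mark_static_address(kv_cache)
```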
Is there any way we can save the cache if the input size is exactly the same? Recompiling (even with the cache) at first run is very slow (40s+) for GPT-fast.
The problem mentioned in this issue is not related to … I personally tried …
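On saving work across runs: the captured CUDA graphs themselves cannot be serialized, but recent PyTorch can persist Inductor's compiled artifacts on disk, which shortens the cold-start compile. A sketch, assuming the `TORCHINDUCTOR_FX_GRAPH_CACHE` / `TORCHINDUCTOR_CACHE_DIR` environment variables are supported by your version:

```python
import os

# Must be set before torch is imported so Inductor picks them up.
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp/inductor_cache"

import torch

model = torch.nn.Linear(16, 16).cuda()
compiled = torch.compile(model, mode="reduce-overhead")
with torch.no_grad():
    compiled(torch.randn(8, 16, device="cuda"))
# A second process pointed at the same cache dir skips most of the
# Triton/Inductor compilation, though CUDA graph capture still reruns.
```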
🐛 Describe the bug
For the code snippet, see the end:

- `model.py`: very simple
- `run.py`: a bit more complex, but only in order to measure the memory and timing

In short: `torch.compile` with `reduce-overhead` takes a long time and the GPU memory usage continues to grow (on the second call to each shape seen). Although the code snippet is a dummy, the same situation happens when I check with `Llama` or `Gemma` models from `HuggingFace`.

This can be seen clearly from the output sections below. Here are a few explanations to facilitate the understanding (of the code snippet and the outputs):

- `generate` (up to `max_len` steps)
- … (`forward`)
- … (`generate`) sees all possible input shapes: time taken and memory usage don't vary much
- runs with `max_len=1024`, `max_len=2048`, `max_len=4096`
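Since the actual `model.py`/`run.py` are collapsed at the end of this report, here is a minimal stand-in sketch of the pattern being measured (a toy module and driver loop of my own, not the author's scripts): each iteration feeds the compiled model every sequence length up to `max_len`, then prints time and reserved memory.

```python
import time
import torch

class TinyModel(torch.nn.Module):
    def __init__(self, hidden: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.proj(x)

model = TinyModel().cuda().eval()
compiled = torch.compile(model, mode="reduce-overhead")

max_len = 1024
for iteration in range(3):
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        # Every step presents a new sequence length, as generate() does.
        for seq_len in range(1, max_len + 1):
            compiled(torch.randn(1, seq_len, 64, device="cuda"))
    torch.cuda.synchronize()
    print(f"iteration {iteration}: {time.time() - start:.1f}s, "
          f"reserved {torch.cuda.memory_reserved() / 2**20:.0f} MB")
```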
outputs (with information from intermediate steps below):

| max_len | timing (s) | Used GPU memory increase |
| ------- | ---------- | ------------------------ |
| 1024    | 115.606751 | 150.0 MB                 |
| 2048    | 232.851245 | 302.0 MB                 |
| 4096    | 565.084438 | 606.0 MB                 |
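Reading the numbers above: both the timing and the memory increase scale roughly linearly with `max_len` (memory goes 150 MB → 302 MB → 606 MB as `max_len` doubles), which is consistent with one CUDA graph being captured per distinct sequence length, at roughly 0.15 MB per shape for this toy model.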
Question: the slowness and the memory accumulation in the 2nd iteration (especially with the super-small model used here) make `torch.compile` impractical in such a use case, which seems to be a common one.

outputs (brief): 1024 steps (per iteration); 2048 steps (per iteration)

outputs (with information from intermediate steps): 2048 steps
`model.py`: modeling code

`run.py`: a script to run
Versions
Collecting environment information...
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: Could not collect
CMake version: version 3.20.3
Libc version: glibc-2.31
Python version: 3.8.10 (default, Nov 22 2023, 10:22:35) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.10.0-30-cloud-amd64-x86_64-with-glibc2.29
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 550.54.15
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
Address sizes: 46 bits physical, 48 bits virtual
CPU(s): 8
On-line CPU(s) list: 0-7
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 63
Model name: Intel(R) Xeon(R) CPU @ 2.30GHz
Stepping: 0
CPU MHz: 2299.998
BogoMIPS: 4599.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 128 KiB
L1i cache: 128 KiB
L2 cache: 1 MiB
L3 cache: 45 MiB
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Mitigation; PTE Inversion
Vulnerability Mds: Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Meltdown: Mitigation; PTI
Vulnerability Mmio stale data: Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; IBRS, IBPB conditional, STIBP conditional, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm invpcid_single pti ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid xsaveopt arat md_clear arch_capabilities
Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.3.0
[pip3] mypy-extensions==1.0.0
[pip3] natten==0.15.1+torch220cu121
[pip3] numpy==1.24.3
[pip3] onnx==1.16.1
[pip3] onnxconverter-common==1.13.0
[pip3] onnxruntime==1.18.0
[pip3] onnxruntime-tools==1.7.0
[pip3] tf2onnx==1.16.1
[pip3] torch==2.3.0+cu121
[pip3] torchaudio==2.3.0+cu121
[pip3] torchvision==0.18.0+cu121
[pip3] triton==2.3.0
[conda] Could not collect
cc @mcarilli @ezyang @eellison @peterbell10 @bdhirsh @anijain2305 @chauhang