
RuntimeError: invalid dtype for bias when use compile + autocast #124901

Open
yiliu30 opened this issue Apr 25, 2024 · 1 comment
Labels
module: amp (automated mixed precision) autocast oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@yiliu30
Contributor

yiliu30 commented Apr 25, 2024

🐛 Describe the bug

When I tried using torch.compile together with autocast to run inference on a Llama decoder block, I encountered RuntimeError: invalid dtype for bias - should match query's dtype.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.amp import autocast

model_name = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda"
amp_dtype = torch.float16
block = model.model.layers[0]
block = block.to(device).to(amp_dtype)

input_ids = torch.randn(8, 2048, 4096).to(amp_dtype).to(device)
input_others = {
    "attention_mask": torch.randn(8, 1, 4096, 4096).to(amp_dtype).to(device),
    "position_ids": torch.arange(2048).unsqueeze(0).to(torch.int64).to(device),
    "cache_position": torch.arange(2048).to(torch.int64).to(device),
}

opt_block = torch.compile(block)
out_without_amp = opt_block.forward(input_ids, **input_others)
print(f"out_without_amp[0].shape: {out_without_amp[0].shape}")

with autocast(device_type=device, dtype=amp_dtype):
    out = opt_block.forward(input_ids, **input_others)
    print(f"out[0].shape: {out[0].shape}")
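The error message suggests that somewhere under the hood, scaled dot-product attention is being called with an attention mask ("bias") whose dtype no longer matches the query's dtype. A minimal sketch of that constraint and the obvious workaround (this is an assumption about the root cause, not a confirmed diagnosis of the compile + autocast interaction):

```python
import torch
import torch.nn.functional as F

# Hypothetical illustration: SDPA expects attn_mask to match the query's
# dtype. In the repro above the block runs in float16 while autocast may
# re-derive some tensors in a different dtype; here we use float32 vs
# float64 purely to show the cast, since the constraint is the same.
q = torch.randn(1, 2, 8, 16, dtype=torch.float32)
k = torch.randn(1, 2, 8, 16, dtype=torch.float32)
v = torch.randn(1, 2, 8, 16, dtype=torch.float32)
mask = torch.zeros(1, 1, 8, 8, dtype=torch.float64)  # mismatched dtype

# Casting the mask to the query's dtype avoids the dtype-mismatch error.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.to(q.dtype))
print(out.shape, out.dtype)
```

If the same pattern applies here, casting the attention_mask in input_others to the autocast dtype before calling the compiled block may serve as a temporary workaround.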

Versions

$ python collect_env.py 
Collecting environment information...
PyTorch version: 2.4.0.dev20240318+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 10.5.0-1ubuntu1~20.04) 10.5.0
Clang version: Could not collect
CMake version: version 3.29.0
Libc version: glibc-2.31

Python version: 3.9.18 (main, Sep 11 2023, 13:41:44)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-169-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 12.3.52
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: 
GPU 0: NVIDIA A100-SXM4-80GB
GPU 1: NVIDIA A100-SXM4-80GB
GPU 2: NVIDIA A100-SXM4-80GB
GPU 3: NVIDIA A100-SXM4-80GB
GPU 4: NVIDIA A100-SXM4-80GB
GPU 5: NVIDIA A100-SXM4-80GB
GPU 6: NVIDIA A100-SXM4-80GB
GPU 7: NVIDIA A100-SXM4-80GB

...
Versions of relevant libraries:
[pip3] functorch==1.14.0a0+b71aa0b
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+989adb9a29
[pip3] torch==2.4.0.dev20240318+cu121
[pip3] torchaudio==2.2.1
[pip3] torchvision==0.18.0.dev20240318+cu121
[pip3] triton==2.2.0
[conda] blas                      1.0                         mkl  
[conda] ffmpeg                    4.3                  hf484d3e_0    pytorch
[conda] libjpeg-turbo             2.0.0                h9bf148f_0    pytorch
[conda] mkl                       2023.1.0         h213fc3f_46344  
[conda] mkl-service               2.4.0            py39h5eee18b_1  
[conda] mkl_fft                   1.3.8            py39h5eee18b_0  
[conda] mkl_random                1.2.4            py39hdb19cb5_0  
[conda] numpy                     1.26.4           py39h5f9d8c6_0  
[conda] numpy-base                1.26.4           py39hb5e798b_0  
[conda] pytorch-cuda              12.1                 ha16c6d3_5    pytorch
[conda] pytorch-mutex             1.0                        cuda    pytorch
[conda] pytorch-triton            3.0.0+989adb9a29          pypi_0    pypi
[conda] torch                     2.1.2                    pypi_0    pypi
[conda] torchaudio                2.2.1                py39_cu121    pytorch
[conda] torchtriton               2.2.0                      py39    pytorch
[conda] torchvision               0.18.0.dev20240318+cu121          pypi_0    pypi

cc @mcarilli @ptrblck @leslie-fang-intel @jgong5 @ezyang @msaroufim @bdhirsh @anijain2305 @chauhang

@jbschlosser jbschlosser added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: amp (automated mixed precision) autocast labels Apr 26, 2024
@marcociccone

marcociccone commented Jun 23, 2024

Is there any fix for this issue? I get the exact same error when combining autocast and compile. Thanks.


4 participants