
Sparse attention map::at triton error #472

Closed · pwstegman opened this issue Nov 28, 2021 · 4 comments
Labels: bug (Something isn't working)

Comments

@pwstegman (Contributor) commented Nov 28, 2021

Describe the bug
When training with sparse attention, Triton throws IndexError: map::at.

This is the full traceback:
Traceback (most recent call last):
  File "train.py", line 27, in 
    pretrain(neox_args=neox_args)
  File "/home/gpt-neox/megatron/training.py", line 103, in pretrain
    iteration = train(
  File "/home/gpt-neox/megatron/training.py", line 552, in train
    loss_dict, skipped_iter = train_step(
  File "/home/gpt-neox/megatron/training.py", line 459, in train_step
    reduced_loss = train_step_pipe(
  File "/home/gpt-neox/megatron/training.py", line 508, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 305, in train_batch
    self._exec_schedule(sched)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1308, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 681, in _exec_forward_pass
    outputs = super().forward(inputs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/engine.py", line 987, in forward
    loss = self.module(*inputs, **kwargs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/pipe/module.py", line 350, in forward
    x = self.activation_checkpoint_func(
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 692, in checkpoint
    CheckpointFunction.apply(function, all_outputs, *args)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/activation_checkpointing/checkpointing.py", line 491, in forward
    outputs = run_function(*inputs_cuda)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/runtime/pipe/module.py", line 328, in exec_func
    inputs = layer(inputs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gpt-neox/megatron/model/transformer.py", line 679, in forward
    return super().forward(hidden_states, attention_mask), attention_mask
  File "/home/gpt-neox/megatron/model/transformer.py", line 639, in forward
    attention_output, attention_bias = self.attention(
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/gpt-neox/megatron/model/transformer.py", line 499, in forward
    context_layer = self.sparse_attention(
  File "/home/gpt-neox/megatron/model/transformer.py", line 421, in sparse_attention
    return self.sparse_attn(
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/ops/sparse_attention/sparse_self_attention.py", line 162, in forward
    attn_output_weights = sparse_dot_sdd_nt(query, key)
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 720, in __call__
    c = _sparse_matmul.apply(a,
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 537, in forward
    c = _sparse_matmul.fn[mode](a,
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/deepspeed/ops/sparse_attention/matmul.py", line 204, in _sdd_matmul
    current = kernel(
  File "/root/.virtualenvs/gptneox/lib/python3.8/site-packages/triton/kernel.py", line 116, in __call__
    kernel = self.fn.autotune(params, grid, self.stream)
IndexError: map::at
Killing subprocess 243

To Reproduce

I cloned GPT-NeoX into my home directory.

Then I started and attached to a docker container with CUDA 10.2:

docker run --gpus all -ti -d --name gptneox -v ~/:/home pytorch/pytorch:1.8.1-cuda10.2-cudnn7-devel
docker attach gptneox

The container had Python 3.6 installed, so I installed Python 3.8. I also installed libopenmpi-dev, as it's required to install all the GPT-NeoX dependencies:

apt update
apt install python3.8 python3.8-dev python3.8-venv libopenmpi-dev

I created and activated a new Python 3.8 virtual env and then pip installed PyTorch 1.8 and the latest version of Apex (commit aa756cec4359aff3df1d9abb68dc6e6e92920e0c):

python3.8 -m venv ~/.virtualenvs/gptneox
. ~/.virtualenvs/gptneox/bin/activate
python -m pip install --upgrade pip wheel
pip install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Then I installed the GPT-NeoX dependencies:

pip install -r requirements/requirements.txt
pip install -r requirements/requirements-sparseattention.txt

I downloaded a dataset:

python prepare_data.py --vocab-file /home/data/gpt2-vocab.json --merge-file /home/data/gpt2-merges.txt enron

Then I started training without sparse attention, which worked fine:

python ./deepy.py train.py -d configs small.yml local_setup.yml

However, once I added the sparse attention config, it threw the error mentioned above. This is the command that caused the error:

python ./deepy.py train.py -d configs small.yml sparse.yml local_setup.yml

Expected behavior

Training should run when sparse attention is enabled.

Proposed solution

I'm not sure how to fix this. I did run into this same issue when trying Microsoft's DeepSpeed examples, so this may be an issue inherited from DeepSpeed rather than something introduced by DeeperSpeed.
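
For what it's worth, the failure can probably be reproduced outside of GPT-NeoX by driving DeepSpeed's sparse attention module directly, since that's where the traceback bottoms out. Here's a rough, untested sketch of what I mean; the shapes and config values are arbitrary, and the exact constructor arguments may differ between DeepSpeed versions:

import torch
from deepspeed.ops.sparse_attention import SparseSelfAttention, FixedSparsityConfig

# Arbitrary shapes; the sequence length must be a multiple of the sparsity block size.
batch, heads, seq_len, head_dim = 1, 8, 256, 64

attn = SparseSelfAttention(
    sparsity_config=FixedSparsityConfig(num_heads=heads, block=16)
)

# The sparse kernels expect fp16 CUDA tensors of shape (batch, heads, seq_len, head_dim).
q = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.half, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# On an affected setup, this should hit the same sparse_dot_sdd_nt -> Triton
# autotune path from the traceback and raise IndexError: map::at.
out = attn(q, k, v)

If that small script fails the same way, it would confirm the problem lives in DeepSpeed/Triton rather than anything in GPT-NeoX's training loop.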

Screenshots

N/A

Environment (please complete the following information):

  • GPUs: 1x Tesla V100-SXM2-16GB
  • Configs: unmodified small.yml, sparse.yml, local_setup.yml
  • NVCC 10.2:
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2019 NVIDIA Corporation
Built on Wed_Oct_23_19:24:38_PDT_2019
Cuda compilation tools, release 10.2, V10.2.89
  • Nvidia driver 495.44:
Sun Nov 28 20:46:07 2021
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 495.44       CUDA Version: 11.5     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    37W / 300W |      0MiB / 16160MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
  • gcc 7.5.0:
gcc (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
  • pip freeze:
aiohttp==3.8.1
aiosignal==1.2.0
apex==0.1
appdirs==1.4.4
async-timeout==4.0.1
attrs==21.2.0
best-download==0.0.7
black==20.8b1
certifi==2021.10.8
chardet==4.0.0
charset-normalizer==2.0.8
click==8.0.3
colorama==0.4.4
configparser==5.1.0
cxxfilt==0.3.0
Cython==0.29.24
DataProperty==0.54.2
datasets==1.16.1
deepspeed @ git+git:https://github.com/EleutherAI/DeeperSpeed.git@eb7f5cff36678625d23db8a8fe78b4a93e5d2c75
dill==0.3.4
docker-pycreds==0.4.0
dyNET38==2.1
einops==0.3.0
filelock==3.4.0
frozenlist==1.2.0
fsspec==2021.11.1
ftfy==6.0.1
gitdb==4.0.9
GitPython==3.1.24
huggingface-hub==0.1.2
idna==3.3
iniconfig==1.1.1
jieba==0.42.1
joblib==1.1.0
jsonlines==2.0.0
lm-dataformat==0.0.19
lm-eval @ git+https://github.com/EleutherAI/lm-evaluation-harness.git@dc937d4b70af819c5695e09d94e59e4cdb1e40ad
mbstrdecoder==1.1.0
mock==4.0.3
mpi4py==3.0.3
msgfy==0.2.0
multidict==5.2.0
multiprocess==0.70.12.2
mypy-extensions==0.4.3
nagisa==0.2.7
ninja==1.10.2.3
numexpr==2.7.2
numpy==1.20.2
openai==0.6.4
packaging==21.3
pandas==1.3.4
pathspec==0.9.0
pathtools==0.1.2
pathvalidate==2.5.0
Pillow==8.4.0
pkg_resources==0.0.0
pluggy==0.13.1
portalocker==2.3.2
promise==2.3
protobuf==3.19.1
psutil==5.8.0
py==1.11.0
pyarrow==6.0.1
pybind11==2.6.2
pycountry==20.7.3
pyparsing==3.0.6
pytablewriter==0.58.0
pytest==6.2.3
python-dateutil==2.8.2
pytz==2021.3
PyYAML==6.0
regex==2021.11.10
rehash==1.0.0
requests==2.26.0
sacrebleu==1.5.0
sacremoses==0.0.46
scikit-learn==1.0.1
scipy==1.7.3
sentry-sdk==1.5.0
shortuuid==1.0.8
six==1.16.0
smmap==5.0.0
sqlitedict==1.6.0
subprocess32==3.5.4
tabledata==1.3.0
tcolorpy==0.1.1
tensorboardX==1.8
threadpoolctl==3.0.0
tokenizers==0.10.2
toml==0.10.2
torch==1.8.0
torchaudio==0.8.0
torchvision==0.9.0
tqdm==4.62.3
tqdm-multiprocess==0.0.11
transformers==4.5.0
triton==0.4.2
typed-ast==1.5.0
typepy==1.3.0
typing_extensions==4.0.0
ujson==4.3.0
urllib3==1.26.7
wandb==0.10.28
wcwidth==0.2.5
xxhash==2.0.2
yarl==1.7.2
zstandard==0.15.2

Additional context

None

@pwstegman added the bug label on Nov 28, 2021
@StellaAthena (Member) commented Nov 29, 2021

Thanks for bringing this to our attention! If you're hitting the same issue with Microsoft's upstream DeepSpeed repo, I recommend opening an issue there as well and linking the two. Personally, I have a lot going on right now, but I can look into this in a week or so.

@pwstegman (Contributor, Author) commented

Thank you! I've opened an issue on the DeepSpeed repo: microsoft/DeepSpeed#1595. I'll also look into it to see if I can find anything.

@StellaAthena (Member) commented

map::at appears to be a C++ thing: std::map::at throws std::out_of_range when the requested key is missing, which Triton's bindings surface in Python as IndexError. That would mean the autotuner is failing to look up a kernel configuration for this particular GPU/driver combination.

@pwstegman (Contributor, Author) commented

Downgrading my Nvidia driver fixed the issue!

sudo apt install nvidia-driver-440

This actually installed Nvidia driver 460.91.03 rather than 440 (likely because nvidia-driver-440 is now a transitional package that pulls in a newer branch), but it works. I was previously on driver version 495.44.
