torch.gather can be slow on AMD with duplicated index #128631

Closed
Yuzhen11 opened this issue Jun 13, 2024 · 1 comment
Labels
module: rocm (AMD GPU support for PyTorch)
module: scatter & gather ops
needs research (We need to decide whether or not this merits inclusion, based on research)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Yuzhen11
Contributor

Yuzhen11 commented Jun 13, 2024

🐛 Describe the bug

We found that torch.gather can be very slow on AMD GPUs compared with NVIDIA GPUs when the index tensor contains many duplicated values.

To repro:

import torch
import triton

device = "cuda"
input = torch.randn([100], dtype=torch.float, device=device, requires_grad=True)
indices = torch.randint(low=0, high=1, size=(74544,), dtype=torch.int64, device=device)

def gather():
    output = torch.gather(input, dim=0, index=indices)
    output.sum().backward()

triton.testing.do_bench(lambda: gather())

The slowness comes from the backward kernel: _scatter_gather_elementwise_kernel
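
For context on why duplicated indices matter: with high=1 every index is 0, so the backward pass has to accumulate all 74544 gradient values into a single input element, and every atomic update in the scatter kernel targets the same address. A minimal sketch of the equivalent accumulation, expressed with scatter_add_ (this mirrors what the backward computes; it is not the kernel itself):

import torch

device = "cuda"  # assumes a GPU build; same setup as the repro above
input = torch.randn([100], dtype=torch.float, device=device, requires_grad=True)
# high=1 means every index is 0, i.e. maximal duplication
indices = torch.randint(low=0, high=1, size=(74544,), dtype=torch.int64, device=device)

output = torch.gather(input, dim=0, index=indices)
output.sum().backward()

# The backward of gather accumulates grad_output at the gathered positions.
# Here grad_output is all ones, so input.grad[0] ends up as 74544 and the
# rest stays zero; every one of those additions hits the same element.
grad_ref = torch.zeros_like(input).scatter_add_(
    0, indices, torch.ones(indices.shape, dtype=input.dtype, device=device)
)
torch.testing.assert_close(input.grad, grad_ref)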

On AMD MI300X:

I0612 223553.302 test_gather.py:41] Link to trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/testttt_20240612_223534_3731825843.json.gz&bucket=pyper_traces (Meta internal only)
I0612 223608.688 test_gather.py:42] Time (ms): 1887.4288330078125

Trace: [profiler trace screenshot attached in the original issue]

On H100:

I0612 224254.306 test_gather.py:41] Link to trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/testttt_20240612_224254_3319354165.json.gz&bucket=pyper_traces (Meta internal only)
I0612 224254.408 test_gather.py:42] Time (ms): 0.26988503336906433
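
As a sanity check on the contention hypothesis (a sketch, not a measurement from this report): if the atomic accumulation on duplicated indices is the trigger, the same benchmark with indices spread over the whole input should be much faster on the affected backend.

import torch
import triton

device = "cuda"
input = torch.randn([100], dtype=torch.float, device=device, requires_grad=True)

def bench(indices):
    def gather():
        output = torch.gather(input, dim=0, index=indices)
        output.sum().backward()
    return triton.testing.do_bench(gather)  # returns time in ms

# Maximal duplication: every index is 0, so all atomics hit one element.
dup = torch.randint(low=0, high=1, size=(74544,), dtype=torch.int64, device=device)
# Spread indices: collisions are limited to roughly 74544 / 100 per element.
spread = torch.randint(low=0, high=100, size=(74544,), dtype=torch.int64, device=device)

print("duplicated indices (ms):", bench(dup))
print("spread indices (ms):", bench(spread))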

Versions

Latest PyTorch trunk.

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang

Yuzhen11 assigned and unassigned Yuzhen11 on Jun 13, 2024
soulitzer added the module: rocm, triaged, and module: scatter & gather ops labels on Jun 14, 2024
hongxiayang added the needs research label on Jun 14, 2024
@jerrymannil
Contributor

The issue seems to be fixed by using atomicAdd instead of atomicAddNoRet here.

I was getting a deprecation warning for atomicAddNoRet, so I tried atomicAdd instead.
They are implemented with different compiler intrinsics:

atomicAddNoRet: https://github.com/ROCm/clr/blob/4ec3a977b21da0025609fad0564edaa98436e077/hipamd/include/hip/amd_detail/amd_hip_atomic.h#L283

atomicAdd: https://github.com/ROCm/clr/blob/4ec3a977b21da0025609fad0564edaa98436e077/hipamd/include/hip/amd_detail/amd_hip_atomic.h#L1425

jerrymannil added a commit to jerrymannil/pytorch that referenced this issue on Jun 14, 2024:
Fixes pytorch#128631.
The current implementation is very specific to MI100.
This causes performance degradation for other GPUs.