torch.gather can be slow on AMD with duplicated index #128631
Labels
module: rocm
AMD GPU support for Pytorch
module: scatter & gather ops
needs research
We need to decide whether or not this merits inclusion, based on research world
triaged
This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🐛 Describe the bug
We found that torch.gather can be very slow on AMD GPUs compared with NVIDIA GPUs when the index tensor contains duplicated indices.
To repro:
The slowdown comes from the backward pass, specifically the _scatter_gather_elementwise_kernel kernel.
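The original test_gather.py script is not included here. The sketch below is a rough, hypothetical stand-in: the function names, tensor sizes, and the `n_unique` knob are assumptions, not the reporter's script. It shows why the backward pass is the interesting part. The gradient of `torch.gather` is accumulated with a scatter-add into a zero tensor shaped like the source, so every duplicated index adds into the same gradient element. With heavy duplication, many GPU threads likely update the same address at once, which is a plausible (unconfirmed) reason the scatter kernel is sensitive to duplication.

```python
import time

import torch


def gather_backward_demo() -> torch.Tensor:
    # Every output element gathers src[0], i.e. the index is fully duplicated.
    src = torch.zeros(4, requires_grad=True)
    idx = torch.zeros(8, dtype=torch.long)
    torch.gather(src, 0, idx).sum().backward()
    # All 8 unit gradients are accumulated into src.grad[0].
    return src.grad


def backward_matches_scatter_add(n_src: int = 16, n_idx: int = 64, n_unique: int = 4) -> bool:
    # Check that gather's gradient equals a scatter-add of grad_out into zeros.
    src = torch.randn(n_src, requires_grad=True)
    idx = torch.randint(0, n_unique, (n_idx,))  # only n_unique distinct targets
    grad_out = torch.randn(n_idx)
    torch.gather(src, 0, idx).backward(grad_out)
    expected = torch.zeros(n_src).scatter_add_(0, idx, grad_out)
    return torch.allclose(src.grad, expected)


def time_gather_backward(n_src: int = 1024, n_idx: int = 1 << 22, n_unique: int = 1,
                         device: str = "cpu", iters: int = 10) -> float:
    # Hypothetical micro-benchmark: mean wall time of gather's backward pass in ms.
    src = torch.randn(n_src, device=device, requires_grad=True)
    idx = torch.randint(0, n_unique, (n_idx,), device=device)  # n_unique=1: all duplicates
    grad_out = torch.ones(n_idx, device=device)
    out = torch.gather(src, 0, idx)
    out.backward(grad_out, retain_graph=True)  # warm-up run
    if device.startswith("cuda"):
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        # retain_graph lets us rerun only the backward pass on the same graph.
        out.backward(grad_out, retain_graph=True)
    if device.startswith("cuda"):
        torch.cuda.synchronize()  # wait for queued GPU kernels before stopping the clock
    return (time.perf_counter() - start) * 1000.0 / iters


if __name__ == "__main__":
    print(gather_backward_demo())  # tensor([8., 0., 0., 0.])
    print(backward_matches_scatter_add())  # True
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    print("all duplicated:", time_gather_backward(n_unique=1, device=dev))
    print("spread out:    ", time_gather_backward(n_unique=1024, device=dev))
```

On a ROCm build, `device="cuda"` targets the AMD GPU, because ROCm builds of PyTorch expose HIP devices through the `torch.cuda` namespace. Comparing `n_unique=1` against `n_unique=n_src` is one way to isolate the effect of duplication on backward time.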
On AMD MI300X:
I0612 223553.302 test_gather.py:41] Link to trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/testttt_20240612_223534_3731825843.json.gz&bucket=pyper_traces (Meta internal only)
I0612 223608.688 test_gather.py:42] Time (ms): 1887.4288330078125
Trace:
On NVIDIA H100:
I0612 224254.306 test_gather.py:41] Link to trace: https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/test/testttt_20240612_224254_3319354165.json.gz&bucket=pyper_traces (Meta internal only)
I0612 224254.408 test_gather.py:42] Time (ms): 0.26988503336906433
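Taking the two reported times at face value (both presumably come from the same script, though the exact workload is not shown), the MI300X run is roughly 7,000x slower than the H100 run:

```python
# Times reported in the logs above, in milliseconds.
mi300x_ms = 1887.4288330078125
h100_ms = 0.26988503336906433

slowdown = mi300x_ms / h100_ms
print(f"MI300X / H100 slowdown: {slowdown:.0f}x")  # MI300X / H100 slowdown: 6993x
```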
Versions
Latest PyTorch trunk at the time of reporting (the logs above are dated 2024-06-12).
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang