[inductor] Add lowering and codegen for aten.sort #128458
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128458
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (3 Unrelated Failures) As of commit d7643bd with merge base c888ee3:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk.
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: dccc6d29e991f3c340847808fde00a970f8be5c8 Pull Request resolved: #128458
ghstack-source-id: e73a00fed67aa344036a8152693a36ef71b6dda4 Pull Request resolved: #128458
ghstack-source-id: 859705512a7341dedf55477606723453d5f21ec0 Pull Request resolved: #128458
sort_numel = sizevars.simplify(sympy_product(sort_ranges))

# Heuristic, smallest rblock where triton usually outperforms aten.sort
max_rblock = 256
This only uses the lowering for rnumel <= 256, because above that it seems to be significantly slower than eager. It also isn't bandwidth bound, so I'm not convinced fusion would save us here.
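A minimal sketch of the size cutoff being described (the helper names here are illustrative, not the actual inductor code): lower aten.sort to the persistent triton kernel only when the padded row length fits in the largest profitable rblock, and otherwise fall back to eager.

```python
# Hypothetical sketch of the heuristic; MAX_RBLOCK and use_triton_sort
# are illustrative names, not inductor's real API.
MAX_RBLOCK = 256  # smallest rblock where triton usually outperforms aten.sort

def next_power_of_2(n):
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def use_triton_sort(sort_numel):
    # the persistent kernel pads rnumel up to a power of two
    return next_power_of_2(sort_numel) <= MAX_RBLOCK
```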
can you leave this as a comment in the code?
Even if this kernel isn't bandwidth bound, the subsequent fusion might be.
@triton.jit
def sort_with_index(
These functions are adapted from the triton library's sort, but with stable sorting added and an rmask to support sorting on non-power-of-two sizes.
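A pure-Python sketch of that idea (an illustration, not the actual triton kernel): a bitonic network is unstable, but sorting (value, index) pairs lexicographically makes equal values keep their original order, and padding masked lanes with +inf plays the role of the rmask for non-power-of-two sizes.

```python
import math

def stable_bitonic_sort(xs):
    """Sort xs ascending, returning (values, indices), stably."""
    n = len(xs)
    m = 1
    while m < n:
        m *= 2
    # Masked (out-of-range) lanes are padded with (+inf, index) so they
    # compare greater than every real element and end up past position n.
    pairs = [(float(x), i) for i, x in enumerate(xs)]
    pairs += [(math.inf, i) for i in range(n, m)]

    def compare_and_swap(i, l):
        # Lexicographic comparison: value first, then original index.
        # The index tie-break is what makes the unstable bitonic network
        # behave like a stable sort.
        if pairs[i] > pairs[l]:
            pairs[i], pairs[l] = pairs[l], pairs[i]

    k = 2
    while k <= m:            # merge bitonic runs of length k
        j = k // 2
        while j >= 1:        # compare elements j apart
            for i in range(m):
                l = i ^ j
                if l > i:
                    if i & k == 0:
                        compare_and_swap(i, l)   # ascending half
                    else:
                        compare_and_swap(l, i)   # descending half
            j //= 2
        k *= 2
    return [v for v, _ in pairs[:n]], [idx for _, idx in pairs[:n]]
```

Equal values (the two 3s and the two 1s in [3, 1, 3, 2, 1]) come out in input order, which is the stability guarantee the index tie-break buys.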
# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
has_sort = any(_node_has_sort(node) for node in node_schedule)
This is a bit of a work-around. For outer reductions we don't normally use RBLOCK >= 64, but for sort we really need the larger RBLOCK, since we can't loop over the reduction: there is no non-persistent version of the sort. So I just add an exception if I find ops.sort in the kernel.
Mind adding some benchmarks to the OP?
The implementation LGTM
return bool(sort_nodes)

# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
In what ways are coalesced loads related to whether a kernel is persistent or not?
For outer reductions we can use e.g. XBLOCK=32, RBLOCK=32, which requires a non-persistent reduction for rnumel > 32 but allows the loads to be coalesced in the x dimension.
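A toy model of the trade-off described (illustrative only; these helper names are not inductor's):

```python
# Row-major (rnumel, xnumel) tensor: element (r, x) lives at r * xnumel + x.
# An x-tile of XBLOCK adjacent columns makes the XBLOCK lanes of one load
# touch consecutive addresses (coalesced). A persistent kernel that assigns
# a single x per program instead strides by xnumel between its r-lanes.
def x_tile_addresses(r, x0, xblock, xnumel):
    return [r * xnumel + x0 + lane for lane in range(xblock)]

def r_lane_addresses(x, r0, rblock, xnumel):
    return [(r0 + lane) * xnumel + x for lane in range(rblock)]
```

x_tile_addresses(5, 0, 4, 128) gives four consecutive addresses, while r_lane_addresses strides by xnumel between lanes, which is the non-coalesced pattern the comment accepts for sort.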
sort_numel = sizevars.simplify(sympy_product(sort_ranges))

# Heuristic, smallest rblock where triton usually outperforms aten.sort
max_rblock = 256
can you leave this as a comment in the code?
torch/_inductor/ops_handler.py (Outdated)
@@ -744,6 +756,10 @@ def frexp(x) -> Tuple[None, None]:
def scan(dtypes, combine_fn, values) -> Tuple[None, ...]:
    return tuple(None for i in range(len(values)))

@staticmethod
def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
    return tuple(None for i in range(len(values)))
Nit: it follows the pattern above, but we can simply do (None,) * len(values).
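The two spellings build the same tuple; the multiplication form just avoids the generator and the unused loop variable:

```python
# Both expressions produce a tuple of len(values) Nones.
values = ["q", "k", "v"]
via_generator = tuple(None for i in range(len(values)))
via_multiply = (None,) * len(values)
assert via_generator == via_multiply == (None, None, None)
```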
ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
We use this trick all over the place because triton's reduce does not work for non-commutative operators. Otherwise, we could do a reduce with lambda a, b: b. If you think it'd be beneficial (you say this kernel is not bandwidth bound?), I could send a PR to triton fixing this.
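A pure-Python model of the masked-sum trick being discussed: a reduce with lambda a, b: b ("take last") is non-commutative, so to select one lane you instead multiply by a one-hot mask and sum, since addition is commutative and only the masked lane contributes.

```python
# Select the element at `lane` using only a commutative reduction (sum),
# mimicking what the masked tl.sum in the snippet above achieves.
def select_lane(values, lane):
    one_hot = [1 if i == lane else 0 for i in range(len(values))]
    return sum(v * m for v, m in zip(values, one_hot))
```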
Feel free to give it a shot
ghstack-source-id: 559ed4e005ee962dbad2b0824f27d4935afe2b19 Pull Request resolved: #128458
descending=descending,
)
if values is None:
    return sort_fallback(x, stable=stable, dim=dim, descending=descending)
Should also fall back if using halide codegen.
sort_numel = sizevars.simplify(sympy_product(sort_ranges))

# Heuristic, smallest rblock where triton usually outperforms aten.sort
max_rblock = 256
Even if this kernel isn't bandwidth bound, the subsequent fusion might be.
# Heuristic, smallest rblock where triton usually outperforms aten.sort
# It also isn't bandwidth bound so fusion is unlikely to help.
max_rblock = 256
Previously you had suggested 1024 as the max. Would you mind posting the numbers for max_rblock > 256, maybe with and without cheap pointwise fusion?
Sure, here are some speedup numbers for torch.sort(a * a, dim=-1)[0] + 2 at different shapes, so we have fusions before and after the sort, as well as being able to eliminate the index buffer to give the triton kernel its best shot:
1024, 255: 2.1x
1024, 256: 6.0x
1024, 257: 0.87x
1024, 511: 1.0x
1024, 512: 2.2x
1024, 513: 0.35x
1024, 1023: 0.36x
1024, 1024: 0.93x
You can see that it's marginally but noticeably worse between 257-511, and only the exact value of 512 gains any speedup (where the mask is removed). Compared to 256
self.common(fn, (inp, False))
self.common(fn, (inp, True))

def test_sort_stable(self):
Maybe add some tests with rblock constrained <= 256 but still dynamic?
@triton.jit
def sort_with_index(
Do we optimize out the index global write when it's not needed? And if so, it still might be possible that triton isn't smart enough to remove all the index-related intermediaries.
The write would be removed by inductor, and triton can probably DCE the final steps of the index computation, but I do expect we could get a bigger win by having custom codegen for sort without index. Though that wouldn't be able to support stable sorting, since I'm using the indices to get a stable sort from the bitonic sort algorithm, which is naturally unstable.
Actually benchmarking it I see without indices is almost 2x faster so I guess triton is able to remove the index reductions entirely. That's pretty cool actually.
tl.static_assert(
    _dim == len(x.shape) - 1, "only minor dimension is currently supported"
)
# iteratively run bitonic merge-sort steps
Could we pad the masked values with inf/-inf here, depending on descending, instead of doing all the mask work in compare_and_swap_with_index? Maybe that would give better perf for masked inputs?
That has the issue that inf might appear in the input, and so we might sort the padding into the output. That said, it would work if either:
- we don't care about the indices, or
- we're doing a stable sort, so the in-sequence infs would get sorted earlier than the padding.
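A toy illustration of that escape hatch (illustrative only; the helper name is made up): pad masked lanes with +inf, sort stably, then truncate. Because a stable sort keeps equal keys in insertion order, a genuine inf in the input stays ahead of the padding appended after it, so truncating to the original length never returns a padding lane. An unstable sort could interleave the padding among the real infs, which matters once indices are also returned.

```python
import math

def sort_with_inf_padding(xs, padded_len):
    # Append +inf padding up to padded_len, then sort and truncate.
    padded = list(xs) + [math.inf] * (padded_len - len(xs))
    # Python's sorted() is stable: real infs (inserted first) keep their
    # place ahead of padding infs, so the slice below is safe.
    return sorted(padded)[: len(xs)]
```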
Yeah, not blocking for this to land, but maybe file an issue after you land? From the original issue, 5/6 of the sorts either did not use indices or used stable sort: https://gist.github.com/eellison/36f0c3e6025360315dd67932461ff11b
For the index case, we could use a negative number for the padded index and use that as part of the tie-breaker.
Opened #129507
[ghstack-poisoned]
ghstack-source-id: 1809983d39ae9039fc3f83b4b6738728c47861c4 Pull Request resolved: #128458
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
ghstack-source-id: 1809983d39ae9039fc3f83b4b6738728c47861c4 Pull Request resolved: pytorch#128458
Stack from ghstack (oldest at bottom):
Closes #125633
Benchmarks:
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang