
[inductor] Add lowering and codegen for aten.sort #128458

Closed
wants to merge 11 commits

Conversation

peterbell10 (Collaborator) commented Jun 11, 2024

Stack from ghstack (oldest at bottom):

Closes #125633

Benchmarks:

| Shape | dim | stable | compiled | eager | speedup |
|---|---|---|---|---|---|
| (256, 4096) | 0 | False | 0.73 ms | 1.26 ms | 1.7 |
| (256, 4096) | 0 | True | 0.75 ms | 1.27 ms | 1.7 |
| (4096, 256) | 1 | False | 0.20 ms | 0.73 ms | 3.7 |
| (4096, 256) | 1 | True | 0.21 ms | 0.73 ms | 3.5 |
| (255, 4096) | 0 | False | 1.05 ms | 1.48 ms | 1.4 |
| (255, 4096) | 0 | True | 1.03 ms | 1.47 ms | 1.4 |
| (4096, 255) | 1 | False | 0.52 ms | 0.98 ms | 1.9 |
| (4096, 255) | 1 | True | 0.54 ms | 1.00 ms | 1.9 |

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

[ghstack-poisoned]

pytorch-bot bot commented Jun 11, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/128458

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit d7643bd with merge base c888ee3:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

peterbell10 added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: dccc6d29e991f3c340847808fde00a970f8be5c8
Pull Request resolved: #128458
peterbell10 added a commit that referenced this pull request Jun 13, 2024
ghstack-source-id: e73a00fed67aa344036a8152693a36ef71b6dda4
Pull Request resolved: #128458
peterbell10 added a commit that referenced this pull request Jun 14, 2024
ghstack-source-id: 859705512a7341dedf55477606723453d5f21ec0
Pull Request resolved: #128458
@peterbell10 peterbell10 marked this pull request as ready for review June 21, 2024 00:49
sort_numel = sizevars.simplify(sympy_product(sort_ranges))

# Heuristic, smallest rblock where triton usually outperforms aten.sort
max_rblock = 256
peterbell10 (author) commented:
This only uses the lowering for rnumel <= 256 because above that it seems to be significantly slower than eager. It also isn't bandwidth bound so I'm not convinced fusion would save us here.
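The cutoff described here amounts to a simple size check. A minimal sketch, with hypothetical helper names (the real decision lives in inductor's lowering code):

```python
# Hypothetical sketch of the size cutoff described above: use the Triton
# sort lowering only when the sorted dimension fits within the largest
# rblock where it beats eager aten.sort; otherwise fall back.
MAX_RBLOCK = 256  # heuristic from the benchmarks in this PR


def use_triton_sort(sort_numel: int) -> bool:
    # Above this size the persistent Triton sort is significantly slower
    # than eager, and it is not bandwidth bound, so fusion is unlikely
    # to recover the loss.
    return sort_numel <= MAX_RBLOCK
```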

A collaborator replied:

can you leave this as a comment in the code?

A contributor replied:

Even if this kernel isn't bandwidth bound the subsequent fusion might be.



@triton.jit
def sort_with_index(
peterbell10 (author) commented:

These functions are adapted from the triton library function but with added stable sort and the addition of an rmask to support sorting on non-powers of two.
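A rough pure-Python analogue of the two adaptations (hypothetical function, scalar rather than block-wise; the real code is a Triton kernel): lanes outside the rmask never participate, and the original index acts as a tie-breaker so the naturally unstable bitonic network produces a stable result.

```python
def compare_and_swap_with_index(vals, idxs, i, j, rmask, descending=False):
    # Swap the pair (i, j) when out of order; lanes outside rmask
    # (padding past the real, non-power-of-two length) are left untouched.
    if not (rmask[i] and rmask[j]):
        return
    if vals[i] == vals[j]:
        # Tie-break on the original index: equal values keep their input
        # order, which is what makes the bitonic sort stable.
        out_of_order = idxs[i] > idxs[j]
    elif descending:
        out_of_order = vals[i] < vals[j]
    else:
        out_of_order = vals[i] > vals[j]
    if out_of_order:
        vals[i], vals[j] = vals[j], vals[i]
        idxs[i], idxs[j] = idxs[j], idxs[i]
```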


# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
has_sort = any(_node_has_sort(node) for node in node_schedule)
peterbell10 (author) commented:

This is a bit of a work-around. For outer reductions we don't normally use RBLOCK >= 64, but for sort, since we can't loop over the reduction dimension, we really need the larger RBLOCK because we don't have a non-persistent version of the sort. So I just add an exception when I find ops.sort in the kernel.
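The exception could be sketched as follows (hypothetical helper names; the real logic sits in inductor's persistent-reduction config selection):

```python
def next_power_of_2(n: int) -> int:
    return 1 << max(n - 1, 0).bit_length()


def pick_rblock(rnumel: int, is_outer_reduction: bool, has_sort: bool,
                outer_cap: int = 64) -> int:
    # Sort has no non-persistent form: the whole reduction extent must
    # fit in a single RBLOCK, so the usual outer-reduction cap is waived.
    if has_sort:
        return next_power_of_2(rnumel)
    if is_outer_reduction:
        return min(outer_cap, next_power_of_2(rnumel))
    return next_power_of_2(rnumel)
```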

lezcano (Collaborator) left a comment:

Mind adding some benchmarks to the OP?

The implementation LGTM

return bool(sort_nodes)

# ops.sort only works with persistent reduction, and is not bandwidth bound anyway
# so taking the hit of non-coalesced loads is okay
A collaborator commented:

In what ways are coalesced loads related to whether a kernel is persistent or not?

peterbell10 (author) replied:

For outer reductions we can use e.g. XBLOCK=32, RBLOCK=32 which requires a non-persistent reduction for rnumel > 32 but allows the loads to be coalesced in the x dimension.
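To illustrate, here is a hypothetical addressing sketch assuming a row-major (rnumel, xnumel) layout: at a fixed r, 32 adjacent threads covering consecutive x values touch contiguous addresses, which is what makes the load coalesced; a persistent reduction over the whole r extent would instead stride by xnumel between elements.

```python
def tile_addresses(r, x_start, xblock, xnumel):
    # Row-major (rnumel, xnumel) tensor: element (r, x) lives at
    # address r * xnumel + x.
    return [r * xnumel + x for x in range(x_start, x_start + xblock)]


# Consecutive x at a fixed r -> consecutive addresses -> coalesced load.
addrs = tile_addresses(r=3, x_start=0, xblock=32, xnumel=4096)
assert addrs == list(range(3 * 4096, 3 * 4096 + 32))
```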

torch/_inductor/codegen/triton.py (resolved review threads)

torch/_inductor/ir.py (resolved review thread)
@@ -744,6 +756,10 @@ def frexp(x) -> Tuple[None, None]:
def scan(dtypes, combine_fn, values) -> Tuple[None, ...]:
return tuple(None for i in range(len(values)))

@staticmethod
def sort(dtypes, values, stable, descending) -> Tuple[None, ...]:
return tuple(None for i in range(len(values)))
A collaborator commented:
nit: it follows the pattern above, but we can simply do (None,) * len(values)

Comment on lines +405 to +406
ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1)[:, None, :], shape)
iright = tl.broadcast_to(tl.sum(iy * right_mask, 1)[:, None, :], shape)
A collaborator commented:

We use this trick all over the place because triton's reduce does not work for non-commutative operators. Otherwise, we could do a reduce with lambda a, b: b. If you think it'd be beneficial (you say this kernel is not bandwidth bound?), I could send a PR to triton fixing this.

peterbell10 (author) replied:

Feel free to give it a shot
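The trick under discussion, emulating a non-commutative "take the selected element" reduction with a masked sum, can be shown in miniature (plain-Python sketch of what the `tl.sum(iy * mask, 1)` lines compute):

```python
def select_by_mask(values, one_hot_mask):
    # With a one-hot mask, summing value * mask reduces to "pick the
    # flagged element": a commutative reduction (sum) stands in for a
    # non-commutative select, since a tree reduction may combine
    # elements in any order.
    return sum(v * m for v, m in zip(values, one_hot_mask))
```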

torch/_inductor/runtime/triton_helpers.py (resolved review threads)
peterbell10 added a commit that referenced this pull request Jun 21, 2024
ghstack-source-id: 559ed4e005ee962dbad2b0824f27d4935afe2b19
Pull Request resolved: #128458
descending=descending,
)
if values is None:
return sort_fallback(x, stable=stable, dim=dim, descending=descending)
A contributor commented:

should also fall back if using Halide codegen



# Heuristic, smallest rblock where triton usually outperforms aten.sort
# It also isn't bandwidth bound so fusion is unlikely to help.
max_rblock = 256
A contributor commented:
Previously you had suggested 1024 as max. Would you mind posting the numbers for max_rblock > 256, maybe with and without cheap pointwise fusion

peterbell10 (author) replied Jun 24, 2024:

Sure, here are some speedup numbers for

torch.sort(a * a, dim=-1)[0] + 2

at different shapes. So we have fusions before and after the sort, as well as being able to eliminate the index buffer to give the triton kernel its best shot:

| Shape | speedup |
|---|---|
| 1024, 255 | 2.1x |
| 1024, 256 | 6.0x |
| 1024, 257 | 0.87x |
| 1024, 511 | 1.0x |
| 1024, 512 | 2.2x |
| 1024, 513 | 0.35x |
| 1024, 1023 | 0.36x |
| 1024, 1024 | 0.93x |

You can see that it's marginally but noticeably worse between 257 and 511, and only the exact value of 512 (where the mask is removed) gains any speedup compared to the 256 cutoff.
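The ratios above are eager time over compiled time; a measurement along these lines would produce them (hypothetical harness, since the PR's actual benchmark script is not shown):

```python
import timeit


def speedup(compiled_fn, eager_fn, number=100):
    # Ratio > 1 means the compiled (Triton) sort wins; each callable is
    # timed over `number` repetitions with the same warm caches.
    t_compiled = timeit.timeit(compiled_fn, number=number)
    t_eager = timeit.timeit(eager_fn, number=number)
    return t_eager / t_compiled
```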

self.common(fn, (inp, False))
self.common(fn, (inp, True))

def test_sort_stable(self):
A contributor commented:

maybe add some tests with rblock constrained to <= 256 but still dynamic?



@triton.jit
def sort_with_index(
A contributor commented:

do we optimize out the index global write when it's not needed? And if so, it still might be possible that triton isn't smart enough to remove all the index-related intermediaries

peterbell10 (author) replied:

The write would be removed by inductor, and triton can probably DCE the final steps of the index computation, but I do expect we could get a bigger win from custom codegen for sort without indices. Though that wouldn't be able to support stable sorting, since I'm using the indices to get a stable sort out of the bitonic sort algorithm, which is naturally unstable.

peterbell10 (author) replied:

Actually benchmarking it I see without indices is almost 2x faster so I guess triton is able to remove the index reductions entirely. That's pretty cool actually.

tl.static_assert(
_dim == len(x.shape) - 1, "only minor dimension is currently supported"
)
# iteratively run bitonic merge-sort steps
A contributor commented:

Could we pad the masked values with inf/-inf here, depending on descending, instead of doing all the mask work in compare_and_swap_with_index?

Maybe that would give better perf for masked inputs?

peterbell10 (author) replied:

That has the issue that inf might appear in the input and so we might sort the padding into the output. That said, it would work if either:

  1. we don't care about the indices
  2. we're doing a stable sort and so the in-sequence infs would get sorted earlier than the padding.
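Point 2 can be seen with Python's built-in stable sort (illustrative only, not the Triton code): a genuine inf in the input keeps its place ahead of the inf padding because stability preserves the original order of equal keys.

```python
import math

data = [3.0, math.inf, 1.0]   # a real inf appears in the input
padded = data + [math.inf]    # pad to a power of two with +inf sentinels

# Python's sort is stable, so equal keys keep their original order:
order = sorted(range(len(padded)), key=lambda i: padded[i])

# The genuine inf (index 1) sorts before the padding slot (index 3),
# so the padding never displaces a real element.
assert order == [2, 0, 1, 3]
```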

A contributor replied:

Yeah, not blocking this land, but maybe file an issue after you land?

From the original issue, 5/6 of the sorts either did not use indices or used a stable sort: https://gist.github.com/eellison/36f0c3e6025360315dd67932461ff11b

eellison (Contributor) replied Jun 25, 2024:

For index case could use negative number for padded index and use that as part of tie breaker

peterbell10 (author) replied:

Opened #129507

peterbell10 added a commit that referenced this pull request Jun 25, 2024
ghstack-source-id: 1809983d39ae9039fc3f83b4b6738728c47861c4
Pull Request resolved: #128458
peterbell10 (author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jun 25, 2024
pytorchmergebot commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


peterbell10 added a commit to peterbell10/pytorch that referenced this pull request Jun 26, 2024
ghstack-source-id: 1809983d39ae9039fc3f83b4b6738728c47861c4
Pull Request resolved: pytorch#128458
@github-actions github-actions bot deleted the gh/peterbell10/742/head branch July 26, 2024 01:56