Add simple fused Triton kernel benchmark for jagged_mean operator #2355

Closed · wants to merge 4 commits

Conversation

jananisriram (Contributor)

Summary:
Add a Triton kernel benchmark implementing a simple fused `mean` for the `jagged_mean` operator. The Triton kernels perform a `mean` along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension. They load blocks of the values tensor along its last dimension `M`, reduce each block of variable length along its first dimension `*`, and store each of the `B` reductions in an output tensor of shape `(B, M)`. The first kernel, `sum_then_buffer`, performs a `sum` on each block of input and then accumulates the result into a buffer. The second kernel, `buffer_then_sum`, is a faster implementation that accumulates blocks into a buffer and then performs a single `sum` over the accumulated buffer.
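
For illustration only, here is a minimal sketch of what the `sum_then_buffer` variant could look like. The pointer layout (a flattened values tensor plus a `(B + 1)`-length offsets tensor delimiting the ragged segments), the block sizes, and the kernel name are assumptions; the actual kernels ship with this PR in the TritonBench source.

```python
import triton
import triton.language as tl


@triton.jit
def jagged_mean_sum_then_buffer(  # hypothetical name; see the PR for the real kernels
    values_ptr,   # flattened (total_rows, M) values of the nested tensor
    offsets_ptr,  # (B + 1,) row offsets delimiting each ragged segment
    output_ptr,   # (B, M) output
    M,
    BLOCK_SIZE_RAGGED: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    pid_b = pid // tl.cdiv(M, BLOCK_SIZE_M)  # which ragged segment
    pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)   # which block of columns in M

    offsets_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    mask_m = offsets_m < M

    ragged_start = tl.load(offsets_ptr + pid_b)
    ragged_end = tl.load(offsets_ptr + pid_b + 1)

    buffer = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)

    # sum_then_buffer: reduce each block along the ragged dimension, then
    # accumulate the per-block sums into a 1-D buffer. buffer_then_sum would
    # instead add raw blocks into a 2-D buffer and reduce once after the loop.
    for block_start in range(ragged_start, ragged_end, BLOCK_SIZE_RAGGED):
        offsets_ragged = block_start + tl.arange(0, BLOCK_SIZE_RAGGED)
        mask_ragged = offsets_ragged < ragged_end
        idxs = offsets_ragged[:, None] * M + offsets_m[None, :]
        mask = mask_ragged[:, None] & mask_m[None, :]
        block = tl.load(values_ptr + idxs, mask=mask, other=0.0)
        buffer += tl.sum(block, axis=0)

    mean = buffer / (ragged_end - ragged_start).to(tl.float32)
    tl.store(output_ptr + pid_b * M + offsets_m, mean, mask=mask_m)
    # launched with grid = (B * triton.cdiv(M, BLOCK_SIZE_M),)
```

Per the summary above, the `buffer_then_sum` layout was measured to be the faster of the two.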

This diff is particularly useful in emulating the loop in Inductor-generated (torch.compile) kernels and serves as a benchmark proxy for Inductor kernels.

Use the command-line argument `sum_then_buffer`, which defaults to `0` (since `buffer_then_sum` is faster, as shown below), to select which Triton kernel to benchmark.

These Triton kernels are benchmarked against two PyTorch implementations, one of which uses `torch.mean`, and the other `torch.div`, `torch.sum`, and `shape`.

This diff follows the general framework found in the `jagged_sum` operator (D58549297, D59034792).

Reviewed By: davidberard98

Differential Revision: D59146627

@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D59146627

Summary:
Add to TritonBench a `jagged_mean` reduction operator for nested tensors using the PyTorch `torch.mean` and `unbind` functions. This diff implements a basic benchmark for reducing along the ragged dimension of 3-dimensional jagged tensors. For a 3-dimensional tensor of shape `(B, *, M)`, where `*` is the ragged dimension, this benchmark uses PyTorch's `mean` operator to reduce `B` `(*, M)` 2-dimensional tensors to a `(B, M)` output tensor.
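
Not part of the PR text, but a minimal sketch of the reference implementation described above, assuming a jagged-layout `torch.nested` tensor as input (the function name is illustrative):

```python
import torch


def jagged_mean_unbind_torch_mean(nt: torch.Tensor) -> torch.Tensor:
    """Reduce a nested tensor of logical shape (B, *, M) to (B, M) by taking
    torch.mean over the ragged dimension of each unbound (*, M) component."""
    return torch.stack([t.mean(dim=0) for t in nt.unbind()])


# Example: two components with ragged lengths 3 and 5, M = 4 -> output of shape (2, 4)
nt = torch.nested.nested_tensor(
    [torch.randn(3, 4), torch.randn(5, 4)], layout=torch.jagged
)
out = jagged_mean_unbind_torch_mean(nt)
```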

Add plotting functionality to the `jagged_mean` operator in TritonBench, enabling the creation of line plots for any set of benchmarks varied along one of the following input parameters: `B`, `M`, `seqlen`, or `sparsity`. This diff sets the groundwork for visualizing the differences in `latency` among the different benchmarks in the `jagged_mean` operator.
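
TritonBench drives its own plotting; purely to illustrate the kind of line plot described above, a standalone sketch might look like the following (all names are illustrative):

```python
import matplotlib.pyplot as plt


def plot_latency(x_values, latencies_by_benchmark, x_param="seqlen"):
    """Line plot of latency for each benchmark as one input parameter varies."""
    for name, latencies in latencies_by_benchmark.items():
        plt.plot(x_values, latencies, marker="o", label=name)
    plt.xlabel(x_param)
    plt.ylabel("latency (ms)")
    plt.legend()
    plt.savefig(f"jagged_mean_latency_vs_{x_param}.png")
```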

Measure the performance of the basic PyTorch benchmark using the `latency` and `gbps` metrics, as well as the `latency` plot varied along one input parameter. Display the nested tensor parameters in the benchmark output.
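
As a rough illustration of the `gbps` metric (an assumption about how such a metric is commonly derived, not necessarily TritonBench's exact formula), achieved bandwidth can be estimated from the bytes read and written and the measured latency:

```python
import torch


def gbps(nt: torch.Tensor, output: torch.Tensor, latency_ms: float) -> float:
    """Estimated bandwidth in GB/s: read every value of the jagged tensor once
    and write the (B, M) output once."""
    bytes_moved = (
        nt.values().numel() * nt.values().element_size()
        + output.numel() * output.element_size()
    )
    return bytes_moved / (latency_ms * 1e-3) / 1e9
```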

This diff follows the general framework found in the `jagged_sum` operator (D58396957, D59034792).

Differential Revision: D59144906

Reviewed By: davidberard98

Summary:
Add to TritonBench a `jagged_mean` reduction operator benchmark for nested tensors using the PyTorch `torch.sum`, `torch.div`, and `unbind` functions together with the tensor `shape` attribute. This diff implements a basic benchmark for reducing along the ragged dimension of 3-dimensional jagged tensors. For a 3-dimensional tensor of shape `(B, *, M)`, where `*` is the ragged dimension, this benchmark `unbind`s the nested tensor into `B` 2-dimensional `(*, M)` tensors. For each `(*, M)` tensor, the benchmark divides the `sum` along the ragged dimension `0` by its `shape` along dimension `0`, which yields the `mean` for that `(*, M)` tensor.
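
Again not from the PR itself, a minimal sketch of this second reference implementation under the same jagged-layout assumption as before:

```python
import torch


def jagged_mean_unbind_sum_div(nt: torch.Tensor) -> torch.Tensor:
    """For each unbound (*, M) component, divide the sum along the ragged
    dimension 0 by its length (shape[0]), stacking into a (B, M) output."""
    return torch.stack(
        [torch.div(torch.sum(t, dim=0), t.shape[0]) for t in nt.unbind()]
    )
```

The `accuracy` metric mentioned below can then, for example, compare this output against the `torch.mean` baseline with `torch.allclose`.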

Extend plotting functionality for the `jagged_mean` operator to account for the new benchmark. Add an `accuracy` metric to verify that the results of all existing benchmarks match.

This diff follows the general framework found in the `jagged_sum` operator (D58396957, D59034792).

Differential Revision: D59146024
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D59146627

jananisriram added a commit to jananisriram/benchmark that referenced this pull request Jul 2, 2024
…torch#2355)

Summary:
Pull Request resolved: pytorch#2355

Add Triton kernel benchmark implementing a simple fused `mean` for the `jagged_mean` operator. The Triton kernels perform a `mean` along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension. They load in blocks of the values tensor along its last dimension `M`, reduce each block of variable length along its first dimension `*`, and store each of `B` reductions in an output tensor of shape `(B, M)`. The first kernel, `sum_then_buffer`, performs a `sum` on each block of input, then accumulates into a buffer. The second kernel, `buffer_then_sum`, is a faster implementation which accumulates blocks into a buffer, then performs a cumulative `sum`.

This diff is particularly useful in emulating the loop in Inductor-generated (`torch.compile`) kernels and serves as a benchmark proxy for Inductor kernels.

Use the command-line argument `sum_then_buffer`, defaulted to `0` (as `buffer_then_sum` is faster, shown below), to decide which Triton kernel to benchmark.

These Triton kernels are benchmarked against two PyTorch implementations, one of which uses `torch.mean`, and the other `torch.div`, `torch.sum`, and `shape`.

This diff follows the general framework found in the jagged_sum operator (D58549297, D59034792).

Reviewed By: davidberard98

Differential Revision: D59146627
@facebook-github-bot (Contributor)

This pull request was exported from Phabricator. Differential Revision: D59146627

@facebook-github-bot (Contributor)

This pull request has been merged in 1e79c04.
