Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add simple fused Triton kernel for jagged_sum operator #2322

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Add simple fused Triton kernel for jagged_sum operator
Summary:
Add Triton kernel implementation to `jagged_sum` operator in TritonBench. This Triton kernel performs a sum along the ragged dimension of a nested tensor of logical dimensions `(B, *, M)`, where `*` is the ragged dimension.  It loads in blocks of the `values` tensor along its last dimension `M`, reduces each block of variable length along its first dimension `*`, and stores each of `B` reductions in an output tensor of shape `(B, M)`.

This Triton kernel is benchmarked against two PyTorch implementations, one which does not pad blocks of variable length and one which does pad.

Reviewed By: davidberard98

Differential Revision: D58549297
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 20, 2024
commit e16c06340b4c418fded96e35b9a33fef68affec6
155 changes: 155 additions & 0 deletions torchbenchmark/operators/jagged_sum/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import itertools

import triton
import triton.language as tl


BLOCK_SIZES = [2**n for n in range(2, 11, 3)]
NUM_WARPS = [2, 4, 8]
NUM_STAGES = [2, 4, 8]


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_RAGGED": b_r,
"BLOCK_SIZE_M": b_m,
},
num_warps=w,
num_stages=s,
)
for b_r, b_m, w, s in itertools.product(
BLOCK_SIZES, # block sizes on non-reduction dimension
BLOCK_SIZES, # block sizes on reduction dimension
NUM_WARPS, # number of warps
NUM_STAGES, # number of stages
)
],
key=["M"],
)
@triton.jit
def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
input_ptr_values, # pointer to input values (2D tensor)
input_ptr_offsets, # pointer to input offsets (1D tensor)
output_ptr, # pointer to output tensor (2D tensor)
# matrix dimensions (input)
M, # number of elements in M-th dimension, with logical dimensions (B, *, M)
MAX_SEQLEN, # max length of ragged dimension
# block sizes (input)
BLOCK_SIZE_RAGGED: tl.constexpr, # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
BLOCK_SIZE_M: tl.constexpr, # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
):
pid = tl.program_id(axis=0) # i-th tensor in nested tensor
pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)

buffer = tl.zeros(
(1, BLOCK_SIZE_M), dtype=tl.float32
) # create buffer as a row tensor

block_start_m = pid_m * BLOCK_SIZE_M
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
): # loop over ragged dimension, ranging until maximum seqlen
block_start_ragged = ragged_start + block_pos # offset block position by start of current program
offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
mask_ragged = offsets_ragged < ragged_end

idxs = (offsets_ragged[:, None] * M) + offsets_m
mask = mask_ragged[:, None] & mask_m

input = tl.load(input_ptr_values + idxs, mask=mask, other=0)

buffer += tl.sum(input, axis=0)

buffer_view = buffer.reshape(
(BLOCK_SIZE_M,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

output_offsets = offsets_m + (
pid_ragged * M
) # output is offset by both ragged dimension and M-th dimension
output_mask = output_offsets < (M * (pid_ragged + 1))

tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_RAGGED": b_r,
"BLOCK_SIZE_M": b_m,
},
num_warps=w,
num_stages=s,
)
for b_r, b_m, w, s in itertools.product(
BLOCK_SIZES, # block sizes on non-reduction dimension
BLOCK_SIZES, # block sizes on reduction dimension
NUM_WARPS, # number of warps
NUM_STAGES, # number of stages
)
],
key=["M"],
)
@triton.jit
def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
input_ptr_values, # pointer to input values (2D tensor)
input_ptr_offsets, # pointer to input offsets (1D tensor)
output_ptr, # pointer to output tensor (2D tensor)
# matrix dimensions (input)
M, # number of elements in M-th dimension, with logical dimensions (B, *, M)
MAX_SEQLEN, # max length of ragged dimension
# block sizes (input)
BLOCK_SIZE_RAGGED: tl.constexpr, # number of elements in ragged dimension per block, with logical dimensions (B, *, M)
BLOCK_SIZE_M: tl.constexpr, # number of elements in M-th dimension per block, with logical dimensions (B, *, M)
):
pid = tl.program_id(axis=0) # i-th tensor in nested tensor
pid_ragged = pid // tl.cdiv(M, BLOCK_SIZE_M)
pid_m = pid % tl.cdiv(M, BLOCK_SIZE_M)

buffer = tl.zeros(
(BLOCK_SIZE_RAGGED, BLOCK_SIZE_M), dtype=tl.float32
) # create buffer as a row tensor

block_start_m = pid_m * BLOCK_SIZE_M
offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < M

ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
input_ptr_offsets + (pid_ragged + 1)
) # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

for block_pos in range(
0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
): # loop over ragged dimension, ranging until maximum seqlen
block_start_ragged = ragged_start + block_pos # offset block position by start of current program
offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
mask_ragged = offsets_ragged < ragged_end

idxs = (offsets_ragged[:, None] * M) + offsets_m
mask = mask_ragged[:, None] & mask_m

buffer += tl.load(input_ptr_values + idxs, mask=mask, other=0)

buffer_sum = tl.sum(buffer, axis=0)

buffer_view = buffer_sum.reshape(
(BLOCK_SIZE_M,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

output_offsets = offsets_m + (
pid_ragged * M
) # output is offset by both ragged dimension and M-th dimension
output_mask = output_offsets < (M * (pid_ragged + 1))

tl.store(output_ptr + output_offsets, buffer_view, mask=output_mask)
66 changes: 63 additions & 3 deletions torchbenchmark/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
register_metric,
)

from .kernels import (
triton_jagged_sum_kernel_simple_fused_buffer_then_sum,
triton_jagged_sum_kernel_simple_fused_sum_then_buffer,
)

seed = 16
random.seed(seed)
torch.manual_seed(seed)
Expand All @@ -38,21 +43,57 @@ def parse_op_args(args: List[str]):
default=0.5,
help="Average sparsity for nested tensor (float, (0.0-1.0))",
)
parser.add_argument(
"--sum-then-buffer",
type=int, # 1: sum then buffer, 0: buffer then sum
default=1,
help="[Optional] For Triton kernels, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum",
)
return parser.parse_args(args)


def execute_kernel_simple_fused(x, max_seqlen, sum_then_buffer):
B, M = x.shape[0], x.shape[2]
grid = lambda meta: ((len(x.offsets()) - 1) * triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
kernel_output = torch.zeros((B, M), device=x.device)

if sum_then_buffer:
triton_jagged_sum_kernel_simple_fused_sum_then_buffer[grid](
x.values(),
x.offsets(),
kernel_output,
M=M,
MAX_SEQLEN=max_seqlen,
)
else:
triton_jagged_sum_kernel_simple_fused_buffer_then_sum[grid](
x.values(),
x.offsets(),
kernel_output,
M=M,
MAX_SEQLEN=max_seqlen,
)

return kernel_output


class Operator(BenchmarkOperator):

DEFAULT_METRICS = ["latency", "accuracy"]
use_cuda_graphs = False # enables GPU/CPU sync (for methods like NestedTensor unbind)
use_cuda_graphs = (
False # enables GPU/CPU sync (for methods like NestedTensor unbind)
)

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.sizes = range(4, 10, 2)
self.sizes = list(range(2, 8, 2)) + list(
range(8, 12)
) # bias towards larger sizes, which are more representative of real-world shapes

args = parse_op_args(self.extra_args)
self.seqlen = args.seqlen
self.sparsity = args.sparsity
self.sum_then_buffer = args.sum_then_buffer

@register_benchmark(baseline=True)
def torch_jagged_sum_no_pad(self, x: torch.Tensor):
Expand All @@ -75,6 +116,13 @@ def torch_jagged_sum_pad(self, x: torch.Tensor):
dim=1,
) # sum along ragged dimension (dim == 1)

@register_benchmark()
def triton_jagged_sum_no_pad(self, x: torch.Tensor):
def _inner():
return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)

return _inner

def get_x_val(self, example_inputs):
return len(example_inputs[0])

Expand Down Expand Up @@ -156,7 +204,7 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
/ metrics.latency
* GIGABYTES_PER_BYTE
)

@register_metric(x_only=True)
def input_shape(
self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
Expand All @@ -166,3 +214,15 @@ def input_shape(
"*",
example_inputs[0].shape[2],
) # return (B, '*', M) for each example input

@register_metric(skip_baseline=True)
def best_config(
self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
) -> str:
if self.sum_then_buffer:
return dump_autotuner_best_config(
triton_jagged_sum_kernel_simple_fused_sum_then_buffer
)
return dump_autotuner_best_config(
triton_jagged_sum_kernel_simple_fused_buffer_then_sum
)
Loading