Add support for reducing across individual dimensions for 2D matrices using the sum Triton kernel (#2295)

Summary:
Pull Request resolved: #2295

Support reducing a 2-dimensional matrix across one dimension, where the `BLOCK_SIZE` in the reduced dimension is at least as large as that dimension. This kernel performs a simplified reduction that assumes the entire reduction dimension of the tensor fits in a single thread block. The implementation swaps which of the `M` and `N` block sizes serves as the reduction block size, depending on the reduction dimension. For example, the kernel reduces an (M, N) = (16, 16) matrix across dimension 0 with `BLOCK_SIZE_M >= 16`, while `BLOCK_SIZE_N` is autotuned.
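
For reference (this snippet is illustrative and not part of the diff), the kernel computes the same result as `torch.sum` along a single dimension:

import torch

x = torch.randn(16, 16)      # (M, N) input
out0 = torch.sum(x, dim=0)   # reduce across dim 0 -> shape (N,) == (16,)
out1 = torch.sum(x, dim=1)   # reduce across dim 1 -> shape (M,) == (16,)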

Add a `best_config` metric that reports the autotuned `BLOCK_SIZE` for the non-reduction dimension and `num_warps` for a given input size.
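
As a rough sketch of what that metric boils down to (assuming, as recent Triton versions do, that an autotuned kernel records the winning `triton.Config` on its autotuner wrapper after the first launch; `report_best_config` below is a hypothetical helper, not the repo's `dump_autotuner_best_config`):

def report_best_config(kernel) -> str:
    # After at least one launch, the autotuner has picked a config.
    cfg = getattr(kernel, "best_config", None)  # triton.Config or None
    return str(cfg) if cfg is not None else ""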

Reviewed By: jbschlosser

Differential Revision: D58261858

fbshipit-source-id: 8995c91c54e9792b52f4608446e8e940027a604d
jananisriram authored and facebook-github-bot committed Jun 12, 2024
1 parent c13df57 commit 3ecaae9
Showing 2 changed files with 167 additions and 36 deletions.
78 changes: 75 additions & 3 deletions torchbenchmark/operators/sum/kernels.py
@@ -1,12 +1,14 @@
import itertools

import torch
import triton
import triton.language as tl


@triton.jit
-def triton_sum_kernel_scalar(
-    input_ptr,
-    output_ptr,
def triton_sum_kernel_scalar_result(
input_ptr, # pointer to input matrix
output_ptr, # pointer to output matrix
M, # number of elements
BLOCK_SIZE_M: tl.constexpr, # number of elements per block
):
@@ -37,3 +39,73 @@ def triton_sum_kernel_scalar(
tl.store(
output_ptr + output_offsets, output
) # store output, where the stored pointers are in the desired output shape


@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_NON_REDUCE_DIM": b},
num_warps=w,
) for b, w in itertools.product(
[2, 4, 8, 16, 32, 64, 128, 256, 512, 1024], # block sizes
[2, 4, 8, 16] # number of warps
)
],
key=["M", "N"],
)
@triton.jit
def triton_sum_kernel_1D_result(
input_ptr, # pointer to input matrix
output_ptr, # pointer to output matrix
# matrix dimensions (input)
M, # number of rows
N, # number of columns
# block sizes (input)
BLOCK_SIZE_NON_REDUCE_DIM: tl.constexpr, # number of elements in non-reduction dimension per block
BLOCK_SIZE_REDUCE_DIM: tl.constexpr, # number of elements in reduction dimension per block
# reduction dimension
dim: tl.constexpr, # dimension along which to sum
):
pid = tl.program_id(axis=0) # i-th block of input

block_start_m, block_start_n = 0, 0
offsets_m, offsets_n = None, None
if dim == 0:
        block_start_n = pid * BLOCK_SIZE_NON_REDUCE_DIM  # step along the non-reduction (N) dim per program
# offsets have shape equal to input shape
offsets_m = block_start_m + tl.arange(
0, BLOCK_SIZE_REDUCE_DIM
) # create 1D vector for offsets on M-th dimension
offsets_n = block_start_n + tl.arange(
0, BLOCK_SIZE_NON_REDUCE_DIM
) # create 1D vector for offsets on N-th dimension
elif dim == 1:
        block_start_m = pid * BLOCK_SIZE_NON_REDUCE_DIM  # step along the non-reduction (M) dim per program
# offsets have shape equal to input shape
offsets_m = block_start_m + tl.arange(
0, BLOCK_SIZE_NON_REDUCE_DIM
) # create 1D vector for offsets on M-th dimension
offsets_n = block_start_n + tl.arange(
0, BLOCK_SIZE_REDUCE_DIM
) # create 1D vector for offsets on N-th dimension

# mask has shape equal to input shape
mask_m = offsets_m < M
mask_n = offsets_n < N

# create 2D matrices of pointers and masks, using above M and N vectors
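    # assumes a contiguous, row-major input: element (m, n) sits at linear offset m * N + n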
idxs = (offsets_m[:, None] * N) + offsets_n
mask = mask_m[:, None] & mask_n

# loaded pointers have shape equal to input shape
input = tl.load(
input_ptr + idxs, mask=mask, other=mask
) # other=mask zeros out masked values from input

output = tl.sum(input, axis=dim)

# stored pointers have shape equal to output shape
if dim == 0: # store output along N-th dimension
tl.store(output_ptr + offsets_n, output, mask=mask_n)
elif dim == 1: # store output along M-th dimension
tl.store(output_ptr + offsets_m, output, mask=mask_m)
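
The snippet below is an illustrative host-side sketch (not part of this diff) of how the new kernel can be launched for a dim-0 reduction; `sum_dim0` is a hypothetical wrapper, and the benchmark in operator.py below does the equivalent while handling both reduction dimensions:

import torch
import triton

def sum_dim0(x: torch.Tensor) -> torch.Tensor:
    assert x.is_contiguous()  # the kernel's index math assumes row-major layout
    M, N = x.shape
    out = torch.empty(N, device=x.device, dtype=x.dtype)
    # the whole reduction dimension (M) must fit in a single block
    BLOCK_SIZE_REDUCE_DIM = triton.next_power_of_2(M)
    # one program per block of the non-reduction dimension (its size is autotuned)
    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),)
    triton_sum_kernel_1D_result[grid](
        x, out, M=M, N=N, BLOCK_SIZE_REDUCE_DIM=BLOCK_SIZE_REDUCE_DIM, dim=0
    )
    return out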
125 changes: 92 additions & 33 deletions torchbenchmark/operators/sum/operator.py
@@ -7,11 +7,12 @@
from torchbenchmark.util.triton_op import (
BenchmarkOperator,
BenchmarkOperatorMetrics,
dump_autotuner_best_config,
register_benchmark,
register_metric,
)

-from .kernels import triton_sum_kernel_scalar
from .kernels import triton_sum_kernel_1D_result, triton_sum_kernel_scalar_result


def parse_op_args(args: List[str]):
@@ -36,35 +37,76 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None
self.reduce_dim = (
args.reduce_dim if args.reduce_dim else None
) # for 2D case, guaranteed to be a list with 1 integer
-        self.sizes = range(1, 17)
self.sizes = range(1, 9)

@register_benchmark()
def triton_sum(self, x: torch.Tensor):
-        x_1d = x.view(-1)
-        M = x_1d.shape[0]
-        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
-        BLOCK_SIZE_M = triton.next_power_of_2(
-            M
-        )  # race condition in cases where BLOCK_SIZE < n_elements^2
num_output_dims = 0 if not self.reduce_dim else x.dim() - len(self.reduce_dim)
kernel_input = x

assert (
x.is_contiguous()
), "Existing sum Triton kernels only support contiguous tensors"

if num_output_dims == 0:
kernel_input = x.view(-1)
M = kernel_input.shape[0]
grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
BLOCK_SIZE_M = triton.next_power_of_2(
M
) # race condition in cases where BLOCK_SIZE < n_elements^2
elif num_output_dims == 1:
M, N = x.shape
BLOCK_SIZE_M, BLOCK_SIZE_N = triton.next_power_of_2(
M
), triton.next_power_of_2(N)
grid = lambda meta: (
max(
triton.cdiv(M, meta["BLOCK_SIZE_REDUCE_DIM"]),
triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
),
)
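            # the reduction dimension fits in one block, so the launch is sized by the non-reduction dimension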

def _inner():
-            output = torch.zeros(1, device=x.device, dtype=x.dtype)

-            triton_sum_kernel_scalar[grid](
-                x_1d,
-                output,
-                M=M,
-                BLOCK_SIZE_M=BLOCK_SIZE_M,
-            )
if num_output_dims == 0:
kernel_output = torch.zeros(
(), device=x.device, dtype=x.dtype
) # scalar tensor output

triton_sum_kernel_scalar_result[grid](
kernel_input,
kernel_output,
M=M,
BLOCK_SIZE_M=BLOCK_SIZE_M,
)
elif num_output_dims == 1:
if self.reduce_dim[0] == 0:
kernel_output = torch.empty(N, device=self.device)
BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_M
elif self.reduce_dim[0] == 1:
kernel_output = torch.empty(M, device=self.device)
BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_N
else:
raise Exception(
f"Existing sum Triton kernels do not support reducing input with shape {kernel_input.size} along dimension(s) {self.reduce_dim}"
)
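                # BLOCK_SIZE_NON_REDUCE_DIM is chosen by the autotuner, so it is not passed explicitly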

triton_sum_kernel_1D_result[grid](
kernel_input,
kernel_output,
M=M,
N=N,
BLOCK_SIZE_REDUCE_DIM=BLOCK_SIZE_REDUCE_DIM,
dim=self.reduce_dim[0],
)

-            return output
return kernel_output

return _inner

@register_benchmark(baseline=True)
def torch_sum(self, x: torch.Tensor):
-        result = torch.sum(x)
-        return lambda: result
return lambda: torch.sum(x, dim=self.reduce_dim)

def get_x_val(self, example_inputs):
return len(example_inputs[0])
@@ -73,29 +115,38 @@ def get_x_vals(self) -> List[int]:
x_vals = []

x_vals.extend([2**n for n in self.sizes])
-        x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])
x_vals.extend(
[
(n - 1) * (n + 1)
for n in self.sizes
if n - 1 > 0 and (n - 1) * (n + 1) not in x_vals
]
)
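        # x_vals: powers of two plus values of the form n**2 - 1, deduplicated so no shape repeats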

return x_vals

def get_input_iter(self) -> Generator:
-        # reduce to a scalar value
-        for size in self.get_x_vals():  # 1D matrix
-            input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
-            yield (input_1d,)
if not self.reduce_dim: # reduce to a scalar value
for size in self.get_x_vals(): # 1D matrix
input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
yield (input_1d,)

for size in self.get_x_vals(): # 2D matrix
-            if size < pow(2, 8):  # ensure we don't exceed floating point limitations
if size < pow(2, 6): # ensure we don't exceed floating point limitations
input_2d = torch.randn(
(size, size), device=self.device, dtype=self.dtype
)
yield (input_2d,)

-        for size in self.get_x_vals():  # 3D matrix
-            if size < pow(2, 4):  # ensure we don't exceed floating point limitations
-                input_2d = torch.randn(
-                    (size, size, size), device=self.device, dtype=self.dtype
-                )
-                yield (input_2d,)
if not self.reduce_dim:
for size in self.get_x_vals(): # 3D matrix
if size < pow(
2, 4
): # ensure we don't exceed floating point limitations
input_2d = torch.randn(
(size, size, size), device=self.device, dtype=self.dtype
)
yield (input_2d,)

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
@@ -111,10 +162,18 @@ def input_dims(
@register_metric()
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
gbps = (
-            lambda ms: 3
-            * example_inputs[0].element_size()
lambda ms: example_inputs[0].element_size()
* example_inputs[0].numel()
/ ms
* 1e-6
)
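        # GB/s: bytes read from the input tensor / latency; latency is in ms, so 1e-6 converts bytes/ms to GB/s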
return list(map(gbps, metrics.latency if metrics.latency else [0]))

@register_metric(skip_baseline=True)
def best_config(
self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
) -> str:
if self.reduce_dim:
return dump_autotuner_best_config(triton_sum_kernel_1D_result)
else:
return ""
