Add support for reducing across the middle dimension for 3D matrices using the sum Triton kernel (#2297)

Summary:
Pull Request resolved: #2297

Support reducing 3-dimensional tensors across the middle dimension (`dim == 1`) such that the result has shape `(M, K)`. This kernel assumes that `BLOCK_SIZE_M == 1`, as Triton is currently unable to perform reductions on a middle dimension, and that the entire reduction dimension of the tensor fits in a thread block (`BLOCK_SIZE_N >= N`).
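
For context, a minimal PyTorch sketch of the operation this kernel targets; the shapes below are illustrative, not taken from the benchmark:

import torch

# Any (M, N, K) works as long as the reduction dimension N fits in one block.
M, N, K = 8, 128, 16
x = torch.randn(M, N, K)

out = x.sum(dim=1)  # reduce the middle dimension
assert out.shape == (M, K)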

Reviewed By: davidberard98

Differential Revision: D58307854

fbshipit-source-id: 77a4225cad80c1f3ac6d6355f4c5e0e221e62ff5
jananisriram authored and facebook-github-bot committed Jun 12, 2024
1 parent c8d6c2a commit 55c975e
Showing 2 changed files with 116 additions and 21 deletions.
73 changes: 70 additions & 3 deletions torchbenchmark/operators/sum/kernels.py
@@ -46,9 +46,9 @@ def triton_sum_kernel_scalar_result(
triton.Config(
{"BLOCK_SIZE_NON_REDUCE_DIM": b},
num_warps=w,
        )
-        for b, w in itertools.product(
-            [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024], [2, 4, 8, 16]  # block sizes  # number of warps
-        )
+        for b, w in itertools.product(
+            [2, 4, 8, 16], [2, 4, 8]  # block sizes  # number of warps
+        )
],
key=["M", "N"],
@@ -109,3 +109,70 @@ def triton_sum_kernel_1D_result(
tl.store(output_ptr + offsets_n, output, mask=mask_n)
elif dim == 1: # store output along M-th dimension
tl.store(output_ptr + offsets_m, output, mask=mask_m)


@triton.autotune(
configs=[
triton.Config(
{"BLOCK_SIZE_K": b},
num_warps=w,
)
for b, w in itertools.product(
[2, 4, 16, 32, 128, 256], [2, 4, 8] # block sizes # number of warps
)
],
key=["N"],
)
@triton.jit
def triton_sum_kernel_2D_result_dim_1(
input_ptr, # pointer to input matrix
output_ptr, # pointer to output matrix
# matrix dimensions (input)
M: tl.constexpr, # number of elements in M-th dimension
N: tl.constexpr, # number of elements in N-th dimension
K: tl.constexpr, # number of elements in K-th dimension
# block sizes (input)
BLOCK_SIZE_N: tl.constexpr, # number of elements in block on N-th dimension
BLOCK_SIZE_K: tl.constexpr, # number of elements in block on K-th dimension
):
# input block shape: (1, N, BLOCK_SIZE_K)

pid = tl.program_id(axis=0) # i-th block of input

pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
pid_k = pid % tl.cdiv(K, BLOCK_SIZE_K)

block_start_n = (
0 # assuming that the entire reduction dimension fits within one thread block
)
block_start_k = pid_k * BLOCK_SIZE_K

offsets_n = block_start_n + tl.arange(0, BLOCK_SIZE_N)
offsets_k = block_start_k + tl.arange(0, BLOCK_SIZE_K)

mask_n = offsets_n < N
mask_k = offsets_k < K

# idxs has shape (N, BLOCK_SIZE_K)
idxs_base = (offsets_n[:, None] * K) + offsets_k
idxs = idxs_base + (
pid_m * N * K
) # increment idxs by the number of elements in all previous blocks

# mask has shape (N, BLOCK_SIZE_K)
mask = mask_n[:, None] & mask_k

# loaded values have shape (N, BLOCK_SIZE_K)
input = tl.load(
input_ptr + idxs, mask=mask, other=0
) # zero out masked values from input

# output has shape (1, BLOCK_SIZE_K)
output = tl.sum(input, axis=0)

output_offsets = (pid_m * K) + offsets_k

# stored pointers have shape (1, BLOCK_SIZE_K)
tl.store(
output_ptr + output_offsets, output, mask=mask_k
) # store a 1D vector into a specific row of 2D output
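
As a sanity check on the index arithmetic above, here is a small host-side sketch (plain PyTorch, not part of the diff) that mirrors `idxs_base + pid_m * N * K` for a contiguous row-major input, assuming `BLOCK_SIZE_N >= N` and `BLOCK_SIZE_K >= K` so one program covers a full (N, K) slice:

import torch

# Toy sizes for the check.
M, N, K = 3, 5, 4
x = torch.arange(M * N * K).reshape(M, N, K)
flat = x.flatten()

pid_m = 2                                  # one program instance per (row, K-block)
offsets_n = torch.arange(N)                # stands in for tl.arange(0, BLOCK_SIZE_N)
offsets_k = torch.arange(K)                # stands in for tl.arange(0, BLOCK_SIZE_K)
idxs = (offsets_n[:, None] * K) + offsets_k + pid_m * N * K

# Gathering through the flat indices recovers exactly the (N, K) slice x[pid_m],
# so summing it over axis 0 matches x.sum(dim=1)[pid_m].
assert torch.equal(flat[idxs], x[pid_m])
assert torch.equal(flat[idxs].sum(dim=0), x.sum(dim=1)[pid_m])
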
64 changes: 46 additions & 18 deletions torchbenchmark/operators/sum/operator.py
@@ -12,7 +12,11 @@
register_metric,
)

-from .kernels import triton_sum_kernel_1D_result, triton_sum_kernel_scalar_result
+from .kernels import (
+    triton_sum_kernel_1D_result,
+    triton_sum_kernel_2D_result_dim_1,
+    triton_sum_kernel_scalar_result,
+)


def parse_op_args(args: List[str]):
@@ -37,7 +41,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
self.reduce_dim = (
args.reduce_dim if args.reduce_dim else None
) # for 2D case, guaranteed to be a list with 1 integer
-        self.sizes = range(1, 9)
+        self.sizes = range(1, 11)

@register_benchmark()
def triton_sum(self, x: torch.Tensor):
@@ -55,7 +59,7 @@ def triton_sum(self, x: torch.Tensor):
BLOCK_SIZE_M = triton.next_power_of_2(
M
) # race condition in cases where BLOCK_SIZE < n_elements^2
-        elif num_output_dims == 1:
+        elif x.dim() == 2 and num_output_dims == 1:
M, N = x.shape
BLOCK_SIZE_M, BLOCK_SIZE_N = triton.next_power_of_2(
M
@@ -66,6 +70,14 @@
triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
),
)
elif x.dim() == 3 and num_output_dims == 2 and self.reduce_dim[0] == 1:
M, N, K = x.shape
BLOCK_SIZE_N = triton.next_power_of_2(N)
grid = lambda meta: (M * triton.cdiv(K, meta["BLOCK_SIZE_K"]),)
else:
raise Exception(
f"Existing sum Triton kernels do not support input shape {x.shape} and reduction dimension(s) {self.reduce_dim}"
)

def _inner():
if num_output_dims == 0:
@@ -79,7 +91,7 @@ def _inner():
M=M,
BLOCK_SIZE_M=BLOCK_SIZE_M,
)
-            elif num_output_dims == 1:
+            elif kernel_input.dim() == 2 and num_output_dims == 1:
if self.reduce_dim[0] == 0:
kernel_output = torch.empty(N, device=self.device)
BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_M
@@ -99,6 +111,21 @@ def _inner():
BLOCK_SIZE_REDUCE_DIM=BLOCK_SIZE_REDUCE_DIM,
dim=self.reduce_dim[0],
)
elif (
kernel_input.dim() == 3
and num_output_dims == 2
and self.reduce_dim[0] == 1
):
kernel_output = torch.empty((M, K), device=self.device)

triton_sum_kernel_2D_result_dim_1[grid](
kernel_input,
kernel_output,
M=M,
N=N,
K=K,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)

return kernel_output

@@ -126,27 +153,26 @@ def get_x_vals(self) -> List[int]:
return x_vals

def get_input_iter(self) -> Generator:
-        if not self.reduce_dim:  # reduce to a scalar value
-            for size in self.get_x_vals():  # 1D matrix
+        if not self.reduce_dim:
+            for size in self.get_x_vals():  # 1D tensor
input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
yield (input_1d,)

-        for size in self.get_x_vals():  # 2D matrix
-            if size < pow(2, 6):  # ensure we don't exceed floating point limitations
+        if not self.reduce_dim or (self.reduce_dim and len(self.reduce_dim) <= 2):
+            for size in self.get_x_vals():  # 2D tensor
input_2d = torch.randn(
(size, size), device=self.device, dtype=self.dtype
)
yield (input_2d,)

-        if not self.reduce_dim:
-            for size in self.get_x_vals():  # 3D matrix
-                if size < pow(2, 4):  # ensure we don't exceed floating point limitations
-                    input_2d = torch.randn(
-                        (size, size, size), device=self.device, dtype=self.dtype
-                    )
-                    yield (input_2d,)
+        if not self.reduce_dim or (
+            self.reduce_dim and len(self.reduce_dim) <= 3 and 0 not in self.reduce_dim
+        ):  # in current kernels, cannot reduce a 3D tensor on the 0-th dimension
+            for size in self.get_x_vals():  # 3D tensor
+                input_3d = torch.randn(
+                    (size, size, size), device=self.device, dtype=self.dtype
+                )
+                yield (input_3d,)

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
output = fn()
@@ -173,7 +199,9 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
def best_config(
self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
) -> str:
-        if self.reduce_dim:
+        if example_inputs[0].dim() == 3 and self.reduce_dim and self.reduce_dim[0] == 1:
+            return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
+        elif self.reduce_dim and len(self.reduce_dim) < example_inputs[0].dim():
return dump_autotuner_best_config(triton_sum_kernel_1D_result)
else:
return ""
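
Putting the two files together, a hedged end-to-end launch sketch; it assumes a CUDA device, imports the kernel from the module path shown in the diff, and leaves `BLOCK_SIZE_K` to the autotuner:

import torch
import triton

from torchbenchmark.operators.sum.kernels import triton_sum_kernel_2D_result_dim_1

M, N, K = 8, 128, 64
x = torch.randn(M, N, K, device="cuda")
out = torch.empty((M, K), device="cuda")

# The entire reduction dimension must fit in one block (BLOCK_SIZE_N >= N).
BLOCK_SIZE_N = triton.next_power_of_2(N)

# One program per (row, K-block) pair, matching the kernel's pid_m / pid_k split.
grid = lambda meta: (M * triton.cdiv(K, meta["BLOCK_SIZE_K"]),)

triton_sum_kernel_2D_result_dim_1[grid](
    x, out, M=M, N=N, K=K, BLOCK_SIZE_N=BLOCK_SIZE_N
)

torch.testing.assert_close(out, x.sum(dim=1))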
