Commit

Extend support to varying block sizes on both dimensions for 2D matrices (#2302)

Summary:
Pull Request resolved: #2302

Extend support for reducing across individual dimensions of 2-dimensional matrices by allowing varying block sizes on both the `M` (first) and `N` (second) dimensions.

The existing kernel performed a simplified reduction that assumed the entire reduction dimension fit within a single thread block. The new kernel removes this assumption, allowing both the reduction and the non-reduction dimensions to span multiple thread blocks. It also enables autotuning on the block sizes for both the `M` and `N` dimensions.
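
As a rough illustration, here is a minimal NumPy model of the blocked indexing the new kernel performs for `dim == 0`; the names mirror the kernel, but the shapes, block sizes, and program id are hypothetical:

import numpy as np

M, N = 6, 5                            # reduce_dim_len, non_reduce_dim_len for dim == 0
BLOCK_REDUCE, BLOCK_NON_REDUCE = 4, 2  # per-block tile sizes (hypothetical)

x = np.arange(M * N, dtype=np.float32)  # flattened row-major input matrix
pid = 1                                 # this program handles the second column block

offsets_non_reduce = pid * BLOCK_NON_REDUCE + np.arange(BLOCK_NON_REDUCE)
mask_non_reduce = offsets_non_reduce < N
buffer = np.zeros(BLOCK_NON_REDUCE, dtype=np.float32)

# loop over the reduction dimension one block at a time, as the kernel does
for block_start in range(0, M, BLOCK_REDUCE):
    offsets_reduce = block_start + np.arange(BLOCK_REDUCE)
    # same flattened-index arithmetic as the kernel: row * non_reduce_dim_len + col
    idxs = offsets_reduce[:, None] * N + offsets_non_reduce
    mask = (offsets_reduce[:, None] < M) & mask_non_reduce
    tile = np.where(mask, x[np.clip(idxs, 0, M * N - 1)], 0.0)  # masked load
    buffer += tile.sum(axis=0)  # "sum then buffer"

assert np.allclose(buffer, x.reshape(M, N).sum(axis=0)[offsets_non_reduce])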

For 1D results, add a `sum_then_buffer` flag that selects which kernel to run: `sum_then_buffer` reduces each block of input first and accumulates the partial sums into a buffer, while `buffer_then_sum` accumulates blocks of raw input into a buffer and then reduces the buffer.
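
To make the difference concrete, a minimal NumPy sketch of the two accumulation orders (illustrative only, not part of the commit):

import numpy as np

x = np.random.rand(8, 4).astype(np.float32)  # sum over dim 0, block size 2
BLOCK = 2

# sum_then_buffer: reduce each block first, accumulate the 1D partial sums
buf1 = np.zeros(4, dtype=np.float32)
for i in range(0, 8, BLOCK):
    buf1 += x[i : i + BLOCK].sum(axis=0)

# buffer_then_sum: accumulate raw 2D blocks, reduce the buffer once at the end
buf2 = np.zeros((BLOCK, 4), dtype=np.float32)
for i in range(0, 8, BLOCK):
    buf2 += x[i : i + BLOCK]

assert np.allclose(buf1, buf2.sum(axis=0))
assert np.allclose(buf1, x.sum(axis=0))

The trade-off is accumulator size versus work per iteration: sum_then_buffer keeps only a one-row accumulator but performs a reduction on every loop iteration, while buffer_then_sum keeps a 2D accumulator and performs a single reduction at the end.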

Reviewed By: davidberard98

Differential Revision: D58313958

fbshipit-source-id: 639ea6b7d7b92f478c0f5669a1cdc0dcb68004e3
jananisriram authored and facebook-github-bot committed Jun 14, 2024
1 parent 10e4bc4 commit f4cbf78
Showing 2 changed files with 150 additions and 51 deletions.
154 changes: 118 additions & 36 deletions torchbenchmark/operators/sum/kernels.py
@@ -44,7 +44,10 @@ def triton_sum_kernel_scalar_result(
 @triton.autotune(
     configs=[
         triton.Config(
-            {"BLOCK_SIZE_NON_REDUCE_DIM": b},
+            {
+                "BLOCK_SIZE_NON_REDUCE_DIM": b,
+                "BLOCK_SIZE_REDUCE_DIM": b,
+            },
             num_warps=w,
         )
         for b, w in itertools.product(
@@ -54,7 +57,7 @@ def triton_sum_kernel_scalar_result(
     key=["M", "N"],
 )
 @triton.jit
-def triton_sum_kernel_1D_result(
+def triton_sum_kernel_1D_result_sum_then_buffer(
     input_ptr,  # pointer to input matrix
     output_ptr,  # pointer to output matrix
     # matrix dimensions (input)
@@ -66,49 +69,128 @@ def triton_sum_kernel_1D_result(
     # reduction dimension
     dim: tl.constexpr,  # dimension along which to sum
 ):
+    """
+    Sum blocks of input using Triton and store in buffer
+    """
+
     pid = tl.program_id(axis=0)  # i-th block of input
 
-    block_start_m, block_start_n = 0, 0
-    offsets_m, offsets_n = None, None
-    if dim == 0:
-        block_start_n = pid * BLOCK_SIZE_REDUCE_DIM
-        # offsets have shape equal to input shape
-        offsets_m = block_start_m + tl.arange(
-            0, BLOCK_SIZE_REDUCE_DIM
-        )  # create 1D vector for offsets on M-th dimension
-        offsets_n = block_start_n + tl.arange(
-            0, BLOCK_SIZE_NON_REDUCE_DIM
-        )  # create 1D vector for offsets on N-th dimension
-    elif dim == 1:
-        block_start_m = pid * BLOCK_SIZE_REDUCE_DIM
-        # offsets have shape equal to input shape
-        offsets_m = block_start_m + tl.arange(
-            0, BLOCK_SIZE_NON_REDUCE_DIM
-        )  # create 1D vector for offsets on M-th dimension
-        offsets_n = block_start_n + tl.arange(
-            0, BLOCK_SIZE_REDUCE_DIM
-        )  # create 1D vector for offsets on N-th dimension
-
-    # mask has shape equal to input shape
-    mask_m = offsets_m < M
-    mask_n = offsets_n < N
-
-    # create 2D matrices of pointers and masks, using above M and N vectors
-    idxs = (offsets_m[:, None] * N) + offsets_n
-    mask = mask_m[:, None] & mask_n
-
-    # loaded pointers have shape equal to input shape
-    input = tl.load(
-        input_ptr + idxs, mask=mask, other=mask
-    )  # other=mask zeros out masked values from input
-
-    output = tl.sum(input, axis=dim)
-
-    # stored pointers have shape equal to output shape
-    if dim == 0:  # store output along N-th dimension
-        tl.store(output_ptr + offsets_n, output, mask=mask_n)
-    elif dim == 1:  # store output along M-th dimension
-        tl.store(output_ptr + offsets_m, output, mask=mask_m)
+    reduce_dim_len = M if dim == 0 else N
+    non_reduce_dim_len = N if dim == 0 else M
+
+    buffer = tl.zeros(
+        (1, BLOCK_SIZE_NON_REDUCE_DIM), dtype=tl.float32
+    )  # create buffer as a row tensor
+
+    block_start_non_reduce_dim = pid * BLOCK_SIZE_NON_REDUCE_DIM
+    offsets_non_reduce_dim = block_start_non_reduce_dim + tl.arange(
+        0, BLOCK_SIZE_NON_REDUCE_DIM
+    )
+    mask_non_reduce_dim = offsets_non_reduce_dim < non_reduce_dim_len
+
+    for block_start_reduce_dim in range(0, reduce_dim_len, BLOCK_SIZE_REDUCE_DIM):
+        offsets_reduce_dim = block_start_reduce_dim + tl.arange(
+            0, BLOCK_SIZE_REDUCE_DIM
+        )
+        mask_reduce_dim = offsets_reduce_dim < reduce_dim_len
+
+        idxs, mask = None, None
+        if dim == 0:
+            idxs = (
+                offsets_reduce_dim[:, None] * non_reduce_dim_len
+            ) + offsets_non_reduce_dim
+            mask = mask_reduce_dim[:, None] & mask_non_reduce_dim
+        elif dim == 1:
+            idxs = (
+                offsets_non_reduce_dim[:, None] * reduce_dim_len
+            ) + offsets_reduce_dim
+            mask = mask_non_reduce_dim[:, None] & mask_reduce_dim
+
+        input = tl.load(input_ptr + idxs, mask=mask, other=mask)
+
+        buffer += tl.sum(input, axis=dim)
+
+    buffer_view = buffer.reshape(
+        (BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True
+    )  # reshape buffer to 1D, as tl.sum may return a 2D tensor
+
+    tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim)
+
+
+@triton.autotune(
+    configs=[
+        triton.Config(
+            {
+                "BLOCK_SIZE_NON_REDUCE_DIM": b,
+                "BLOCK_SIZE_REDUCE_DIM": b,
+            },
+            num_warps=w,
+        )
+        for b, w in itertools.product(
+            [2, 4, 8, 16], [2, 4, 8]  # block sizes  # number of warps
+        )
+    ],
+    key=["M", "N"],
+)
+@triton.jit
+def triton_sum_kernel_1D_result_buffer_then_sum(
+    input_ptr,  # pointer to input matrix
+    output_ptr,  # pointer to output matrix
+    # matrix dimensions (input)
+    M,  # number of rows
+    N,  # number of columns
+    # block sizes (input)
+    BLOCK_SIZE_NON_REDUCE_DIM: tl.constexpr,  # number of elements in non-reduction dimension per block
+    BLOCK_SIZE_REDUCE_DIM: tl.constexpr,  # number of elements in reduction dimension per block
+    # reduction dimension
+    dim: tl.constexpr,  # dimension along which to sum
+):
+    """
+    Add blocks of input to a buffer and sum the buffer using Triton
+    """
+
+    pid = tl.program_id(axis=0)  # i-th block of input
+
+    reduce_dim_len = M if dim == 0 else N
+    non_reduce_dim_len = N if dim == 0 else M
+
+    buffer = tl.zeros(
+        (BLOCK_SIZE_REDUCE_DIM, BLOCK_SIZE_NON_REDUCE_DIM), dtype=tl.float32
+    )  # create buffer as a 2D tensor
+
+    block_start_non_reduce_dim = pid * BLOCK_SIZE_NON_REDUCE_DIM
+    offsets_non_reduce_dim = block_start_non_reduce_dim + tl.arange(
+        0, BLOCK_SIZE_NON_REDUCE_DIM
+    )
+    mask_non_reduce_dim = offsets_non_reduce_dim < non_reduce_dim_len
+
+    for block_start_reduce_dim in range(0, reduce_dim_len, BLOCK_SIZE_REDUCE_DIM):
+        offsets_reduce_dim = block_start_reduce_dim + tl.arange(
+            0, BLOCK_SIZE_REDUCE_DIM
+        )
+        mask_reduce_dim = offsets_reduce_dim < reduce_dim_len
+
+        idxs, mask = None, None
+        if dim == 0:
+            idxs = (
+                offsets_reduce_dim[:, None] * non_reduce_dim_len
+            ) + offsets_non_reduce_dim
+            mask = mask_reduce_dim[:, None] & mask_non_reduce_dim
+        elif dim == 1:
+            idxs = (
+                offsets_non_reduce_dim[:, None] * reduce_dim_len
+            ) + offsets_reduce_dim
+            mask = mask_non_reduce_dim[:, None] & mask_reduce_dim
+
+        buffer += tl.load(input_ptr + idxs, mask=mask, other=mask)
+
+    buffer_sum = tl.sum(buffer, axis=dim)
+
+    buffer_view = buffer_sum.reshape(
+        (BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True
+    )  # reshape buffer to 1D, as tl.sum may return a 2D tensor
+
+    tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim)
 
 
 @triton.autotune(
47 changes: 32 additions & 15 deletions torchbenchmark/operators/sum/operator.py
@@ -13,7 +13,8 @@
 )
 
 from .kernels import (
-    triton_sum_kernel_1D_result,
+    triton_sum_kernel_1D_result_buffer_then_sum,
+    triton_sum_kernel_1D_result_sum_then_buffer,
     triton_sum_kernel_2D_result_dim_1,
     triton_sum_kernel_scalar_result,
 )
@@ -28,6 +29,12 @@ def parse_op_args(args: List[str]):
         default=None,
         help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1",
     )
+    parser.add_argument(
+        "--sum-then-buffer",
+        type=int,  # 1: sum then buffer, 0: buffer then sum
+        default=1,
+        help="[Optional] For 1D results, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum",
+    )
     return parser.parse_args(args)
 
 
@@ -41,6 +48,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
         self.reduce_dim = (
             args.reduce_dim if args.reduce_dim else None
         )  # for 2D case, guaranteed to be a list with 1 integer
+        self.sum_then_buffer = args.sum_then_buffer
         self.sizes = range(1, 11)
 
     @register_benchmark()
@@ -61,13 +69,12 @@ def triton_sum(self, x: torch.Tensor):
             )  # race condition in cases where BLOCK_SIZE < n_elements^2
         elif x.dim() == 2 and num_output_dims == 1:
             M, N = x.shape
-            BLOCK_SIZE_M, BLOCK_SIZE_N = triton.next_power_of_2(
-                M
-            ), triton.next_power_of_2(N)
             grid = lambda meta: (
                 max(
-                    triton.cdiv(M, meta["BLOCK_SIZE_REDUCE_DIM"]),
-                    triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
+                    triton.cdiv(M, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
+                    triton.cdiv(N, meta["BLOCK_SIZE_REDUCE_DIM"]),
                 ),
             )
         elif x.dim() == 3 and num_output_dims == 2 and self.reduce_dim[0] == 1:
@@ -94,23 +101,29 @@ def _inner():
             elif kernel_input.dim() == 2 and num_output_dims == 1:
                 if self.reduce_dim[0] == 0:
                     kernel_output = torch.empty(N, device=self.device)
-                    BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_M
                 elif self.reduce_dim[0] == 1:
                     kernel_output = torch.empty(M, device=self.device)
-                    BLOCK_SIZE_REDUCE_DIM = BLOCK_SIZE_N
                 else:
                     raise Exception(
                         f"Existing sum Triton kernels do not support reducing input with shape {kernel_input.size} along dimension(s) {self.reduce_dim}"
                     )
 
-                triton_sum_kernel_1D_result[grid](
-                    kernel_input,
-                    kernel_output,
-                    M=M,
-                    N=N,
-                    BLOCK_SIZE_REDUCE_DIM=BLOCK_SIZE_REDUCE_DIM,
-                    dim=self.reduce_dim[0],
-                )
+                if self.sum_then_buffer:
+                    triton_sum_kernel_1D_result_sum_then_buffer[grid](
+                        kernel_input,
+                        kernel_output,
+                        M=M,
+                        N=N,
+                        dim=self.reduce_dim[0],
+                    )
+                else:
+                    triton_sum_kernel_1D_result_buffer_then_sum[grid](
+                        kernel_input,
+                        kernel_output,
+                        M=M,
+                        N=N,
+                        dim=self.reduce_dim[0],
+                    )
             elif (
                 kernel_input.dim() == 3
                 and num_output_dims == 2
@@ -201,6 +214,10 @@ def best_config(
         if example_inputs[0].dim() == 3 and self.reduce_dim and self.reduce_dim[0] == 1:
             return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
         elif self.reduce_dim and len(self.reduce_dim) < example_inputs[0].dim():
-            return dump_autotuner_best_config(triton_sum_kernel_1D_result)
+            if self.sum_then_buffer:
+                return dump_autotuner_best_config(
+                    triton_sum_kernel_1D_result_sum_then_buffer
+                )
+            return dump_autotuner_best_config(triton_sum_kernel_1D_result_buffer_then_sum)
         else:
            return ""
