Skip to content

Commit

Permalink
Extend support to varying block sizes on middle dimension of 3D tenso…
Browse files Browse the repository at this point in the history
…rs for sum operator

Summary: Add varying block sizes on the middle dimension, `N`, for `sum` Triton kernels which reduce a 3-dimensional input tensor of shape `(M, N, K)` to a 2-dimensional output along the middle dimension (`dim == 1`). This diff adds functionality for the `sum_then_buffer` approach, which proved to be faster than the `buffer_then_sum` approach, and is primarily beneficial for reducing 3-dimensional tensors with large middle dimensions, on the order of 2^12 and above. As seen below, Triton outperforms PyTorch for the large middle dimensions in the mid range of large inputs, on the order of approximately 2^16.

Reviewed By: jbschlosser

Differential Revision: D58892972

fbshipit-source-id: ddd8e051dfec13e61bd29fc6ea99fa00c32ee8cf
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 26, 2024
1 parent 2643d15 commit 2f6ea58
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 12 deletions.
161 changes: 158 additions & 3 deletions torchbenchmark/operators/sum/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def triton_sum_kernel_scalar_result(
num_warps=w,
)
for b_nr, b_r, w in itertools.product(
[2, 4, 8, 16], [2, 4, 8, 16], [2, 4, 8] # block sizes on non-reduction dimension, block sizes on reduction dimension, number of warps
[2, 4, 8, 16],
[2, 4, 8, 16],
[
2,
4,
8,
], # block sizes on non-reduction dimension, block sizes on reduction dimension, number of warps
)
],
key=["M", "N"],
Expand Down Expand Up @@ -111,7 +117,7 @@ def triton_sum_kernel_1D_result_sum_then_buffer(
buffer += tl.sum(input, axis=dim)

buffer_view = buffer.reshape(
(BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True
(BLOCK_SIZE_NON_REDUCE_DIM,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim)
Expand Down Expand Up @@ -187,7 +193,7 @@ def triton_sum_kernel_1D_result_buffer_then_sum(
buffer_sum = tl.sum(buffer, axis=dim)

buffer_view = buffer_sum.reshape(
(BLOCK_SIZE_NON_REDUCE_DIM,), can_reorder=True
(BLOCK_SIZE_NON_REDUCE_DIM,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

tl.store(output_ptr + offsets_non_reduce_dim, buffer_view, mask=mask_non_reduce_dim)
Expand Down Expand Up @@ -258,3 +264,152 @@ def triton_sum_kernel_2D_result_dim_1(
tl.store(
output_ptr + output_offsets, output, mask=mask_k
) # store a 1D vector into a specific row of 2D output


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_N": b_n,
"BLOCK_SIZE_K": b_k,
},
num_warps=w,
)
for b_n, b_k, w in itertools.product(
[4**n for n in range(6)], [4**n for n in range(4)], [2, 4, 8]
)
],
key=["N"],
)
@triton.jit
def triton_sum_kernel_2D_result_dim_1_sum_then_buffer(
input_ptr, # pointer to input matrix
output_ptr, # pointer to output matrix
# matrix dimensions (input)
M: tl.constexpr, # number of elements in M-th dimension
N: tl.constexpr, # number of elements in N-th dimension
K: tl.constexpr, # number of elements in K-th dimension
# block sizes (input)
BLOCK_SIZE_N: tl.constexpr, # number of elements in block on N-th dimension
BLOCK_SIZE_K: tl.constexpr, # number of elements in block on K-th dimension
):
"""
Modification to triton_sum_kernel_2D_result_dim_1() which uses a buffer to store intermediate results,
enabling reducing over a large middle dimension for 3D input tensors
"""

# input block shape: (1, N, BLOCK_SIZE_K)

pid = tl.program_id(axis=0) # i-th block of input

pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
pid_k = pid % tl.cdiv(K, BLOCK_SIZE_K)

buffer = tl.zeros((1, BLOCK_SIZE_K), dtype=tl.float32)

block_start_k = pid_k * BLOCK_SIZE_K
offsets_k = block_start_k + tl.arange(0, BLOCK_SIZE_K)
mask_k = offsets_k < K

for block_start_n in range(0, N, BLOCK_SIZE_N):
offsets_n = block_start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < N

# idxs has shape (N, BLOCK_SIZE_K)
idxs_base = (offsets_n[:, None] * K) + offsets_k
idxs = idxs_base + (
pid_m * N * K
) # increment idxs by the number of elements in all previous blocks

# mask has shape (N, BLOCK_SIZE_K)
mask = mask_n[:, None] & mask_k

# loaded pointers have shape (N, K)
input = tl.load(
input_ptr + idxs, mask=mask, other=0
) # zero out masked values from input

buffer += tl.sum(input, axis=0)

# buffer has shape (1, BLOCK_SIZE_K)
buffer_view = buffer.reshape(
(BLOCK_SIZE_K,),
) # reshape buffer to 1D, as tl.sum may return a 2D tensor

output_offsets = (pid_m * K) + offsets_k

# stored pointers have shape (1, BLOCK_SIZE_K)
tl.store(
output_ptr + output_offsets, buffer_view, mask=mask_k
) # store a 1D vector into a specific row of 2D output


@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_N": b_n,
"BLOCK_SIZE_K": b_k,
},
num_warps=w,
)
for b_n, b_k, w in itertools.product(
[4**n for n in range(7)], [4**n for n in range(4)], [2, 4, 8]
)
],
key=["N"],
)
@triton.jit
def triton_sum_kernel_2D_result_dim_1_buffer_then_sum(
input_ptr, # pointer to input matrix
output_ptr, # pointer to output matrix
# matrix dimensions (input)
M: tl.constexpr, # number of elements in M-th dimension
N: tl.constexpr, # number of elements in N-th dimension
K: tl.constexpr, # number of elements in K-th dimension
# block sizes (input)
BLOCK_SIZE_N: tl.constexpr, # number of elements in block on N-th dimension
BLOCK_SIZE_K: tl.constexpr, # number of elements in block on K-th dimension
):
# input block shape: (1, N, BLOCK_SIZE_K)

pid = tl.program_id(axis=0) # i-th block of input

pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)
pid_k = pid % tl.cdiv(K, BLOCK_SIZE_K)

buffer = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32)

block_start_k = pid_k * BLOCK_SIZE_K
offsets_k = block_start_k + tl.arange(0, BLOCK_SIZE_K)
mask_k = offsets_k < K

for block_start_n in range(0, N, BLOCK_SIZE_N):
offsets_n = block_start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < N

# idxs has shape (N, BLOCK_SIZE_K)
idxs_base = (offsets_n[:, None] * K) + offsets_k
idxs = idxs_base + (
pid_m * N * K
) # increment idxs by the number of elements in all previous blocks

# mask has shape (N, BLOCK_SIZE_K)
mask = mask_n[:, None] & mask_k

# loaded pointers have shape (N, K)
input = tl.load(
input_ptr + idxs, mask=mask, other=0
) # zero out masked values from input

buffer += input

# output has shape (1, BLOCK_SIZE_K)
output = tl.sum(buffer, axis=0)

output_offsets = (pid_m * K) + offsets_k

# stored pointers have shape (1, BLOCK_SIZE_K)
tl.store(
output_ptr + output_offsets, output, mask=mask_k
) # store a 1D vector into a specific row of 2D output
41 changes: 32 additions & 9 deletions torchbenchmark/operators/sum/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@
triton_sum_kernel_1D_result_buffer_then_sum,
triton_sum_kernel_1D_result_sum_then_buffer,
triton_sum_kernel_2D_result_dim_1,
triton_sum_kernel_2D_result_dim_1_sum_then_buffer,
triton_sum_kernel_scalar_result,
)

GIGABYTES_PER_BYTE = 1e-6
ABSOLUTE_TOLERANCE = 1e-4
RELATIVE_TOLERANCE = 1e-3
TENSOR_BYTES_LIMIT = 1e10 # allocate tensors no greater than 10GB
TENSOR_BYTES_LIMIT = 8 * 1e9 # allocate tensors no greater than 10GB


def parse_op_args(args: List[str]):
Expand All @@ -47,8 +48,8 @@ def parse_op_args(args: List[str]):
parser.add_argument(
"--sum-then-buffer",
type=int, # 1: sum then buffer, 0: buffer then sum
default=1,
help="[Optional] For 1D results, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum",
default=0,
help="[Optional] For 1D results, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum; default 0",
)
parser.add_argument(
"--M",
Expand Down Expand Up @@ -131,17 +132,17 @@ def execute_kernel_1D_result(x, reduce_dim, sum_then_buffer):
def execute_kernel_2D_result(x):
kernel_input = x
M, N, K = x.shape
BLOCK_SIZE_N = triton.next_power_of_2(N)
grid = lambda meta: (M * triton.cdiv(K, meta["BLOCK_SIZE_K"]),)
kernel_output = torch.empty((M, K), device=x.device, dtype=x.dtype)

triton_sum_kernel_2D_result_dim_1[grid](
triton_sum_kernel_2D_result_dim_1_sum_then_buffer[
grid
]( # variable block sizes on N and K dimensions
kernel_input,
kernel_output,
M=M,
N=N,
K=K,
BLOCK_SIZE_N=BLOCK_SIZE_N,
)

return kernel_output
Expand Down Expand Up @@ -206,6 +207,11 @@ def get_x_val(self, example_inputs):

def get_x_vals(self):
M_vals, N_vals, K_vals = [], [], []
M_vals_large_middle_dim, N_vals_large_middle_dim, K_vals_large_middle_dim = (
[],
[],
[],
)

def get_dim_vals():
vals = []
Expand All @@ -221,24 +227,37 @@ def get_dim_vals():

if self.M is None:
M_vals.extend(get_dim_vals())
M_vals_large_middle_dim.extend([8, 16])
else:
M_vals.extend([self.M])
M_vals_large_middle_dim.extend([self.M])

if self.N is None:
N_vals.extend(get_dim_vals())
N_vals_large_middle_dim.extend([2**n for n in range(12, 22, 2)])
else:
N_vals.extend([self.N])
N_vals_large_middle_dim.extend([self.N])

if self.K is None:
K_vals.extend(get_dim_vals())
K_vals_large_middle_dim.extend([8, 16])
else:
K_vals.extend([self.K])
K_vals_large_middle_dim.extend([self.K])

if self.input_dim == 1:
return M_vals
if self.input_dim == 2:
return M_vals, N_vals
return M_vals, N_vals, K_vals
return (
M_vals,
N_vals,
K_vals,
M_vals_large_middle_dim,
N_vals_large_middle_dim,
K_vals_large_middle_dim,
)

def get_input_iter(self) -> Generator:
assert (
Expand All @@ -256,7 +275,9 @@ def get_size_in_bytes(shape) -> int:
elif self.input_dim == 2:
sizes = itertools.product(x_vals[0], x_vals[1])
else:
sizes = itertools.product(x_vals[0], x_vals[1], x_vals[2])
sizes = list(itertools.product(x_vals[0], x_vals[1], x_vals[2])) + list(
itertools.product(x_vals[3], x_vals[4], x_vals[5])
) # small- to mid-range dimensions + large middle dimension

for size in sizes:
if get_size_in_bytes(size) < TENSOR_BYTES_LIMIT:
Expand Down Expand Up @@ -294,7 +315,9 @@ def best_config(
triton_sum_kernel_1D_result_buffer_then_sum
)
elif self.input_dim == 3:
return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
return dump_autotuner_best_config(
triton_sum_kernel_2D_result_dim_1_sum_then_buffer
)
else:
return ""

Expand Down

0 comments on commit 2f6ea58

Please sign in to comment.