Refactor code for sum Triton kernels (#2303)
Summary:
Pull Request resolved: #2303

Refactor code to improve readability and logical flow for the cases that select the `sum` Triton kernel implementation to run. Create helper functions for the following cases (see the sketch after this list):
- Reduce N-dimensional input to scalar output
- Reduce 2-dimensional input to 1-dimensional output
- Reduce 3-dimensional input along dimension 1 to 2-dimensional output
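For orientation, the sketch below shows the resulting dispatch structure in miniature. It is illustrative only: the real helpers launch the Triton kernels shown in the diff below, and are replaced here by torch.sum stand-ins so the sketch runs anywhere.

    import torch

    # Stand-ins for the real helpers, which launch Triton kernels.
    def execute_kernel_scalar_result(x):
        return x.sum()  # reduce N-D input to a scalar

    def execute_kernel_1D_result(x, reduce_dim, sum_then_buffer):
        return x.sum(dim=reduce_dim)  # reduce 2-D input to a 1-D output

    def execute_kernel_2D_result(x):
        return x.sum(dim=1)  # reduce 3-D input along dim 1 to a 2-D output

    # Dispatch mirroring the refactored Operator.triton_sum in the diff below.
    def dispatch(x, input_dim, reduce_dim=None, sum_then_buffer=False):
        if reduce_dim is None or input_dim == 1:
            return execute_kernel_scalar_result(x)
        elif input_dim == 2:
            return execute_kernel_1D_result(x, reduce_dim, sum_then_buffer)
        elif input_dim == 3:
            assert reduce_dim == 1, "3-D inputs only reduce along dimension 1"
            return execute_kernel_2D_result(x)
        raise NotImplementedError(f"unsupported input_dim: {input_dim}")

    print(dispatch(torch.randn(4, 5, 6), input_dim=3, reduce_dim=1).shape)  # torch.Size([4, 6])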

Add command-line argument parsing for the `input_dim` parameter, which specifies the number of dimensions desired in kernel input tensors (example below).
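A self-contained approximation of the new flag, mirroring the argparse hunk in operator.py below (the argument values passed here are illustrative):

    import argparse

    # Minimal reproduction of parse_op_args for the two flags; defaults match the diff.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-dim", type=int, default=1)
    parser.add_argument("--reduce-dim", type=int, default=None)
    args = parser.parse_args(["--input-dim", "3", "--reduce-dim", "1"])
    print(args.input_dim, args.reduce_dim)  # 3 1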

Loosen the absolute tolerance used in accuracy checks (from 1e-4 to 1e-3) to account for accumulated floating-point error in large reductions.
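The motivation: the Triton kernels and the torch.sum baseline may accumulate partial sums in different orders, and the rounding gap between two orderings grows with input size. A standalone illustration (not code from this PR; sizes are arbitrary):

    import torch

    # The same reduction computed with two accumulation orders; the difference
    # is pure floating-point reordering error, which atol must absorb.
    x = torch.randn(1024, 1024, dtype=torch.float32)
    baseline = x.sum(dim=0)
    chunked = sum(chunk.sum(dim=0) for chunk in x.chunk(8, dim=0))
    print((chunked - baseline).abs().max().item())
    print(torch.allclose(chunked, baseline, atol=1e-3))  # expected: True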

Reviewed By: jbschlosser

Differential Revision: D58488137

fbshipit-source-id: 01e1f6104383cb0ec5338c4d0427b3a30c2bffd1
jananisriram authored and facebook-github-bot committed Jun 14, 2024
1 parent f4cbf78 commit 339ccfd
Showing 2 changed files with 130 additions and 120 deletions.
10 changes: 5 additions & 5 deletions torchbenchmark/operators/sum/kernels.py
@@ -45,13 +45,13 @@ def triton_sum_kernel_scalar_result(
     configs=[
         triton.Config(
             {
-                "BLOCK_SIZE_NON_REDUCE_DIM": b,
-                "BLOCK_SIZE_REDUCE_DIM": b,
+                "BLOCK_SIZE_NON_REDUCE_DIM": b_nr,
+                "BLOCK_SIZE_REDUCE_DIM": b_r,
             },
             num_warps=w,
         )
-        for b, w in itertools.product(
-            [2, 4, 8, 16], [2, 4, 8]  # block sizes  # number of warps
+        for b_nr, b_r, w in itertools.product(
+            [2, 4, 8, 16], [2, 4, 8, 16], [2, 4, 8]  # block sizes on non-reduction dimension, block sizes on reduction dimension, number of warps
         )
     ],
     key=["M", "N"],
@@ -200,7 +200,7 @@ def triton_sum_kernel_1D_result_buffer_then_sum(
             num_warps=w,
         )
         for b, w in itertools.product(
-            [2, 4, 16, 32, 128, 256], [2, 4, 8]  # block sizes  # number of warps
+            [2, 4, 16, 32, 128, 256], [2, 4, 8]  # block sizes, number of warps
         )
     ],
     key=["N"],
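As a side note on the first hunk above: decoupling the two block sizes enlarges the autotuning search space from 4 × 3 = 12 configurations to 4 × 4 × 3 = 48. A quick standalone check (illustrative, not code from the PR):

    import itertools

    # Coupled block size (old) vs. decoupled block sizes (new).
    old = list(itertools.product([2, 4, 8, 16], [2, 4, 8]))
    new = list(itertools.product([2, 4, 8, 16], [2, 4, 8, 16], [2, 4, 8]))
    print(len(old), len(new))  # 12 48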
240 changes: 125 additions & 115 deletions torchbenchmark/operators/sum/operator.py
@@ -22,12 +22,17 @@

 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--input-dim",
+        type=int,
+        default=1,
+        help="Number of dimensions desired in input tensor; e.g. --input-dim 2 for a 2D input tensor",
+    )
     parser.add_argument(
         "--reduce-dim",
         type=int,
-        nargs="*",
-        default=None,
-        help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1",
+        default=None,  # reduce to a scalar result
+        help="[Optional] Dimension on which kernel performs reduction; e.g. --reduce-dim 0",
     )
     parser.add_argument(
         "--sum-then-buffer",
@@ -38,106 +43,122 @@ def parse_op_args(args: List[str]):
     return parser.parse_args(args)
 
 
+# helper functions to get kernel parameters based on output dimension
+
+
+def execute_kernel_scalar_result(x):
+    kernel_input = x.view(-1)
+    M = kernel_input.shape[0]
+    BLOCK_SIZE_M = triton.next_power_of_2(
+        M
+    )  # race condition in cases where BLOCK_SIZE < n_elements^2
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
+    kernel_output = torch.zeros(
+        (), device=x.device, dtype=x.dtype
+    )  # scalar tensor output
+
+    triton_sum_kernel_scalar_result[grid](
+        kernel_input,
+        kernel_output,
+        M=M,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+    )
+
+    return kernel_output
+
+
+def execute_kernel_1D_result(x, reduce_dim, sum_then_buffer):
+    kernel_input = x
+    M, N = x.shape
+    grid = lambda meta: (
+        max(
+            triton.cdiv(M, meta["BLOCK_SIZE_REDUCE_DIM"]),
+            triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
+            triton.cdiv(M, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
+            triton.cdiv(N, meta["BLOCK_SIZE_REDUCE_DIM"]),
+        ),
+    )
+    if reduce_dim == 0:
+        kernel_output = torch.empty(N, device=x.device)
+    else:  # reduce_dim == 1
+        kernel_output = torch.empty(M, device=x.device)
+
+    if sum_then_buffer:
+        triton_sum_kernel_1D_result_sum_then_buffer[grid](
+            kernel_input,
+            kernel_output,
+            M=M,
+            N=N,
+            dim=reduce_dim,
+        )
+    else:
+        triton_sum_kernel_1D_result_buffer_then_sum[grid](
+            kernel_input,
+            kernel_output,
+            M=M,
+            N=N,
+            dim=reduce_dim,
+        )
+
+    return kernel_output
+
+
+def execute_kernel_2D_result(x):
+    kernel_input = x
+    M, N, K = x.shape
+    BLOCK_SIZE_N = triton.next_power_of_2(N)
+    grid = lambda meta: (M * triton.cdiv(K, meta["BLOCK_SIZE_K"]),)
+    kernel_output = torch.empty((M, K), device=x.device)
+
+    triton_sum_kernel_2D_result_dim_1[grid](
+        kernel_input,
+        kernel_output,
+        M=M,
+        N=N,
+        K=K,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+    )
+
+    return kernel_output


 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
 
     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
         args = parse_op_args(self.extra_args)
-        self.reduce_dim = (
-            args.reduce_dim if args.reduce_dim else None
-        )  # for 2D case, guaranteed to be a list with 1 integer
+        self.input_dim = args.input_dim
+        self.reduce_dim = args.reduce_dim
         self.sum_then_buffer = args.sum_then_buffer
         self.sizes = range(1, 11)

     @register_benchmark()
     def triton_sum(self, x: torch.Tensor):
-        num_output_dims = 0 if not self.reduce_dim else x.dim() - len(self.reduce_dim)
-        kernel_input = x
 
         assert (
             x.is_contiguous()
         ), "Existing sum Triton kernels only support contiguous tensors"
 
-        if num_output_dims == 0:
-            kernel_input = x.view(-1)
-            M = kernel_input.shape[0]
-            grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
-            BLOCK_SIZE_M = triton.next_power_of_2(
-                M
-            )  # race condition in cases where BLOCK_SIZE < n_elements^2
-        elif x.dim() == 2 and num_output_dims == 1:
-            M, N = x.shape
-            grid = lambda meta: (
-                max(
-                    triton.cdiv(M, meta["BLOCK_SIZE_REDUCE_DIM"]),
-                    triton.cdiv(N, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
-                    triton.cdiv(M, meta["BLOCK_SIZE_NON_REDUCE_DIM"]),
-                    triton.cdiv(N, meta["BLOCK_SIZE_REDUCE_DIM"]),
-                ),
-            )
-        elif x.dim() == 3 and num_output_dims == 2 and self.reduce_dim[0] == 1:
-            M, N, K = x.shape
-            BLOCK_SIZE_N = triton.next_power_of_2(N)
-            grid = lambda meta: (M * triton.cdiv(K, meta["BLOCK_SIZE_K"]),)
-        else:
-            raise Exception(
-                f"Existing sum Triton kernels do not support input shape {x.shape} and reduction dimension(s) {self.reduce_dim}"
-            )
+        assert (
+            self.reduce_dim is None or self.reduce_dim <= 1
+        ), f"Existing sum Triton kernels do not support reducing along dimension {self.reduce_dim}"
 
         def _inner():
-            if num_output_dims == 0:
-                kernel_output = torch.zeros(
-                    (), device=x.device, dtype=x.dtype
-                )  # scalar tensor output
-
-                triton_sum_kernel_scalar_result[grid](
-                    kernel_input,
-                    kernel_output,
-                    M=M,
-                    BLOCK_SIZE_M=BLOCK_SIZE_M,
-                )
-            elif kernel_input.dim() == 2 and num_output_dims == 1:
-                if self.reduce_dim[0] == 0:
-                    kernel_output = torch.empty(N, device=self.device)
-                elif self.reduce_dim[0] == 1:
-                    kernel_output = torch.empty(M, device=self.device)
-                else:
-                    raise Exception(
-                        f"Existing sum Triton kernels do not support reducing input with shape {kernel_input.size} along dimension(s) {self.reduce_dim}"
-                    )
-
-                if self.sum_then_buffer:
-                    triton_sum_kernel_1D_result_sum_then_buffer[grid](
-                        kernel_input,
-                        kernel_output,
-                        M=M,
-                        N=N,
-                        dim=self.reduce_dim[0],
-                    )
-                else:
-                    triton_sum_kernel_1D_result_buffer_then_sum[grid](
-                        kernel_input,
-                        kernel_output,
-                        M=M,
-                        N=N,
-                        dim=self.reduce_dim[0],
-                    )
-            elif (
-                kernel_input.dim() == 3
-                and num_output_dims == 2
-                and self.reduce_dim[0] == 1
-            ):
-                kernel_output = torch.empty((M, K), device=self.device)
-
-                triton_sum_kernel_2D_result_dim_1[grid](
-                    kernel_input,
-                    kernel_output,
-                    M=M,
-                    N=N,
-                    K=K,
-                    BLOCK_SIZE_N=BLOCK_SIZE_N,
-                )
+            if self.reduce_dim is None or self.input_dim == 1:
+                kernel_output = execute_kernel_scalar_result(x)
+            elif self.input_dim == 2:
+                kernel_output = execute_kernel_1D_result(
+                    x, self.reduce_dim, self.sum_then_buffer
+                )
+            elif self.input_dim == 3:
+                assert (
+                    self.reduce_dim == 1
+                ), f"Existing sum Triton kernels do not support reducing {self.input_dim}-D input along dimension {self.reduce_dim}"
+                kernel_output = execute_kernel_2D_result(x)
+            else:
+                raise NotImplementedError(
+                    f"Existing sum Triton kernels do not support {self.input_dim}-D inputs"
+                )
 
             return kernel_output
@@ -166,37 +187,24 @@ def get_x_vals(self) -> List[int]:
         return x_vals
 
     def get_input_iter(self) -> Generator:
-        if not self.reduce_dim:
-            for size in self.get_x_vals():  # 1D tensor
-                input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
-                yield (input_1d,)
-
-        if not self.reduce_dim or (self.reduce_dim and len(self.reduce_dim) <= 2):
-            for size in self.get_x_vals():  # 2D tensor
-                input_2d = torch.randn(
-                    (size, size), device=self.device, dtype=self.dtype
-                )
-                yield (input_2d,)
-
-        if not self.reduce_dim or (
-            self.reduce_dim and len(self.reduce_dim) <= 3 and 0 not in self.reduce_dim
-        ):  # in current kernels, cannot reduce a 3D tensor on the 0-th dimension
-            for size in self.get_x_vals():  # 3D tensor
-                input_3d = torch.randn(
-                    (size, size, size), device=self.device, dtype=self.dtype
-                )
-                yield (input_3d,)
+        assert (
+            self.input_dim <= 3
+        ), f"Existing sum Triton kernels do not support input dimension {self.input_dim}"
+
+        for size in self.get_x_vals():
+            input_tensor = torch.randn(
+                tuple(
+                    [size for _ in range(self.input_dim)]
+                ),  # tuple with self.input_dim dimensions
+                device=self.device,
+                dtype=self.dtype,
+            )
+            yield (input_tensor,)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
         baseline_output = baseline_fn()
-        return torch.allclose(output, baseline_output, atol=1e-4)
-
-    @register_metric(skip_baseline=True)
-    def input_dims(
-        self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
-    ):
-        return [ex.dim() for ex in example_inputs]
+        return torch.allclose(output, baseline_output, atol=1e-3)

     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
@@ -211,13 +219,15 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
     def best_config(
         self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
     ) -> str:
-        if example_inputs[0].dim() == 3 and self.reduce_dim and self.reduce_dim[0] == 1:
-            return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
-        elif self.reduce_dim and len(self.reduce_dim) < example_inputs[0].dim():
+        if self.input_dim == 2:
             if self.sum_then_buffer:
                 return dump_autotuner_best_config(
                     triton_sum_kernel_1D_result_sum_then_buffer
                 )
-            return dump_autotuner_best_config(triton_sum_kernel_1D_result_buffer_then_sum)
+            return dump_autotuner_best_config(
+                triton_sum_kernel_1D_result_buffer_then_sum
+            )
+        elif self.input_dim == 3:
+            return dump_autotuner_best_config(triton_sum_kernel_2D_result_dim_1)
         else:
             return ""
