Add command line argument parsing for reduction dimensions in Triton sum kernel (#2284)

Summary:
Pull Request resolved: #2284

Add command-line argument parsing to pass in the dimension(s) across which the kernel reduces, enabling more rigorous testing of different versions of the sum kernel; the approach follows [torchbenchmark/operators/fb/flash_attention/operator.py](https://www.internalfb.com/code/fbsource/[864a578ce44afdba619d50a352c8ca3b783e05ef]/fbcode/pytorch/benchmark/torchbenchmark/operators/fb/flash_attention/operator.py?lines=84).
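
As a hedged sketch of how the new flag reaches the operator (the `mode` and `device` values below are illustrative assumptions, not taken from this diff): unrecognized runner flags arrive at the operator as `extra_args` and are parsed in `__init__`.

    from torchbenchmark.operators.sum.operator import Operator

    # Hypothetical invocation; mode="fwd" and device="cuda" are assumptions.
    # Extra command-line flags are forwarded to the operator as `extra_args`.
    op = Operator(mode="fwd", device="cuda", extra_args=["--reduce-dim", "0"])
    assert op.reduce_dim == [0]  # parsed into a list with one int (see __init__ below)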

Extend the `__init__` method of the parent class `BenchmarkOperator` (calling it via `super().__init__`) in order to hook into command-line argument parsing.

Change `dim` type to `list` to avoid type issues resulting from `tl.constexpr`.

Modify equality checks in kernel and operator to satisfy type requirements for `dim`.
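
A minimal runnable sketch of the last two points (illustrative only; the multi-dimensional kernels themselves are not part of this diff): with `dim` carried as a list end to end, equality checks compare list literals instead of bare ints.

    import torch

    reduce_dim = [0]  # parsed from `--reduce-dim 0`; a list, not a bare int

    x = torch.randn(8, 16)
    if reduce_dim == [0]:  # was `dim == 0` before the type change
        expected = torch.sum(x, dim=0)  # reduce across rows -> shape (16,)
    elif reduce_dim == [1]:
        expected = torch.sum(x, dim=1)  # reduce across columns -> shape (8,)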

Reviewed By: xuzhao9

Differential Revision: D58212366

fbshipit-source-id: 5c88a7c3e8bf2f37408c6c5e3d302b7e9a473bd4
jananisriram authored and facebook-github-bot committed Jun 7, 2024
1 parent 2d8999b commit 9f2ab74
Showing 2 changed files with 49 additions and 14 deletions.
16 changes: 12 additions & 4 deletions torchbenchmark/operators/sum/kernels.py
@@ -14,18 +14,26 @@ def triton_sum_kernel_scalar(

     block_start = pid * BLOCK_SIZE_M
     # offsets have shape equal to input shape
-    offsets = block_start + tl.arange(0, BLOCK_SIZE_M)  # create 1D vector (input shape) ranging from beginning to end of this program's block
+    offsets = block_start + tl.arange(
+        0, BLOCK_SIZE_M
+    )  # create 1D vector (input shape) ranging from beginning to end of this program's block
 
     # mask has shape equal to input shape
     mask = offsets < M  # mask out offsets that are out of bounds for input
 
     # loaded pointers have shape equal to input shape
-    x = tl.load(input_ptr + offsets, mask=mask, other=mask)  # load input, where the loaded pointers are in the desired input shape
+    x = tl.load(
+        input_ptr + offsets, mask=mask, other=mask
+    )  # load input, where the loaded pointers are in the desired input shape
 
     output = tl.sum(x)
 
     # output_offsets have shape equal to output shape
-    output_offsets = tl.arange(0, 1)  # create offsets for scalar output pointer (output shape == (1,))
+    output_offsets = tl.arange(
+        0, 1
+    )  # create offsets for scalar output pointer (output shape == (1,))
 
     # stored pointers have shape equal to output shape
-    tl.store(output_ptr + output_offsets, output)  # store output, where the stored pointers are in the desired output shape
+    tl.store(
+        output_ptr + output_offsets, output
+    )  # store output, where the stored pointers are in the desired output shape
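
For context, a self-contained sketch of the kernel after this change; the `@triton.jit` decorator, the imports, and the `tl.program_id` line are inferred from the hunk above rather than shown in it.

    import triton
    import triton.language as tl


    @triton.jit
    def triton_sum_kernel_scalar(
        input_ptr, output_ptr, M, BLOCK_SIZE_M: tl.constexpr
    ):
        pid = tl.program_id(axis=0)  # inferred: one program per input block
        block_start = pid * BLOCK_SIZE_M
        offsets = block_start + tl.arange(0, BLOCK_SIZE_M)
        mask = offsets < M
        # out-of-bounds lanes fill with the mask value (False, i.e. 0)
        x = tl.load(input_ptr + offsets, mask=mask, other=mask)
        output = tl.sum(x)
        output_offsets = tl.arange(0, 1)  # scalar output: shape (1,)
        tl.store(output_ptr + output_offsets, output)
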
47 changes: 37 additions & 10 deletions torchbenchmark/operators/sum/operator.py
@@ -14,26 +14,47 @@
 from .kernels import triton_sum_kernel_scalar
 
 
+def parse_op_args(args: List[str]):
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--reduce-dim",
+        type=int,
+        nargs="*",
+        default=None,
+        help="[Optional] Dimension(s) on which kernel performs reduction; e.g. --reduce-dim 0, --reduce-dim 0 1",
+    )
+    return parser.parse_args(args)
+
+
 class Operator(BenchmarkOperator):
 
     DEFAULT_METRICS = ["latency", "accuracy"]
 
-    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
+    def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
+        args = parse_op_args(self.extra_args)
+        self.reduce_dim = (
+            args.reduce_dim if args.reduce_dim else None
+        )  # for 2D case, guaranteed to be a list with 1 integer
         self.sizes = range(1, 17)
 
     @register_benchmark()
     def triton_sum(self, x: torch.Tensor):
         x_1d = x.view(-1)
         M = x_1d.shape[0]
         grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]),)
-        BLOCK_SIZE_M = triton.next_power_of_2(M)  # race condition in cases where BLOCK_SIZE < n_elements^2
+        BLOCK_SIZE_M = triton.next_power_of_2(
+            M
+        )  # race condition in cases where BLOCK_SIZE < n_elements^2
 
         def _inner():
             output = torch.zeros(1, device=x.device, dtype=x.dtype)
 
             triton_sum_kernel_scalar[grid](
-                x_1d, output, M=M, BLOCK_SIZE_M=BLOCK_SIZE_M,
+                x_1d,
+                output,
+                M=M,
+                BLOCK_SIZE_M=BLOCK_SIZE_M,
             )
 
             return output
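
The comment on `BLOCK_SIZE_M` above flags a potential race when the block size is smaller than the element count: multiple programs would then store to the same scalar output. Rounding `BLOCK_SIZE_M` up to the next power of two at or above `M` collapses the grid to a single program. A quick arithmetic check (the value of `M` is arbitrary):

    import triton

    M = 1000
    BLOCK_SIZE_M = triton.next_power_of_2(M)  # 1024 >= M
    num_programs = triton.cdiv(M, BLOCK_SIZE_M)  # ceil(1000 / 1024) == 1
    assert num_programs == 1  # single program -> no cross-block accumulation
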
@@ -53,32 +74,38 @@ def get_x_vals(self) -> List[int]:

         x_vals.extend([2**n for n in self.sizes])
         x_vals.extend([(n - 1) * (n + 1) for n in self.sizes if n - 1 > 0])
 
         return x_vals
 
     def get_input_iter(self) -> Generator:
         # reduce to a scalar value
         for size in self.get_x_vals():  # 1D matrix
             input_1d = torch.randn(size, device=self.device, dtype=self.dtype)
-            yield (input_1d, )
+            yield (input_1d,)
 
         for size in self.get_x_vals():  # 2D matrix
             if size < pow(2, 8):  # ensure we don't exceed floating point limitations
-                input_2d = torch.randn((size, size), device=self.device, dtype=self.dtype)
-                yield (input_2d, )
+                input_2d = torch.randn(
+                    (size, size), device=self.device, dtype=self.dtype
+                )
+                yield (input_2d,)
 
         for size in self.get_x_vals():  # 3D matrix
             if size < pow(2, 4):  # ensure we don't exceed floating point limitations
-                input_2d = torch.randn((size, size, size), device=self.device, dtype=self.dtype)
-                yield (input_2d, )
+                input_2d = torch.randn(
+                    (size, size, size), device=self.device, dtype=self.dtype
+                )
+                yield (input_2d,)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
         baseline_output = baseline_fn()
         return torch.allclose(output, baseline_output, atol=1e-4)
 
     @register_metric(skip_baseline=True)
-    def input_dims(self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics):
+    def input_dims(
+        self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
+    ):
         return [ex.dim() for ex in example_inputs]
 
     @register_metric()
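
Since `parse_op_args` (first hunk above) is a thin `argparse` wrapper, its behavior is easy to check in isolation; a small demo of the flag's semantics:

    import argparse
    from typing import List


    def parse_op_args(args: List[str]):  # copied from the diff above
        parser = argparse.ArgumentParser()
        parser.add_argument("--reduce-dim", type=int, nargs="*", default=None)
        return parser.parse_args(args)


    print(parse_op_args(["--reduce-dim", "0"]).reduce_dim)       # [0]
    print(parse_op_args(["--reduce-dim", "0", "1"]).reduce_dim)  # [0, 1]
    print(parse_op_args([]).reduce_dim)                          # None

A bare `--reduce-dim` with no values yields `[]`, which the `if args.reduce_dim else None` guard in `__init__` also normalizes to `None`.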
