Add variable seqlen and sparsity parameters to jagged_sum benchmark (#2324)

Summary:
Pull Request resolved: #2324

Modify the existing `jagged_sum` operator benchmark to optionally accept any of the following parameters: `B` (dimension 0 of the nested tensor), `M` (dimension 2 of the nested tensor), `seqlen` (maximum sequence length on the ragged dimension), or `sparsity` (average sparsity on the ragged dimension). This diff holds any parameters provided on the command line fixed and varies all remaining parameters above, enabling testing of all combinations of the unfixed parameters in parallel.
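As a condensed, standalone sketch of that sweep logic (the `seqlen` and `sparsity` ranges are copied from `get_x_vals` in the diff below; the `build_grid` wrapper and its default `dim_sweep` values are illustrative only, not the benchmark's API):

```python
import itertools

# Candidate sweeps used when a parameter is not pinned on the command line
# (ranges copied from get_x_vals in the diff below).
SEQLEN_SWEEP = list(range(100, 1000, 100)) + list(range(1000, 10000, 1000))
SPARSITY_SWEEP = [n / 10 for n in range(1, 10)]


def build_grid(B=None, M=None, seqlen=None, sparsity=None, dim_sweep=(128, 256, 512)):
    """Pin any parameter the user supplied; sweep all the others."""
    B_vals = [B] if B is not None else list(dim_sweep)
    M_vals = [M] if M is not None else list(dim_sweep)
    seqlen_vals = [seqlen] if seqlen is not None else SEQLEN_SWEEP
    sparsity_vals = [sparsity] if sparsity is not None else SPARSITY_SWEEP
    # Cartesian product: every combination of the unpinned parameters.
    return itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)


# `--B 1024 --M 1024 --sparsity 0.3` pins three parameters, so only
# seqlen varies: 9 + 9 = 18 benchmark configurations.
assert len(list(build_grid(B=1024, M=1024, sparsity=0.3))) == 18
```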

The following errors persist with sufficiently large inputs (a rough size estimate follows the list):
- `RuntimeError: numel needs to be smaller than int32_t max; otherwise, please use packed_accessor64` (when running command `buck2 run mode/{opt,inplace} //pytorch/benchmark:triton -- --op jagged_sum --B 1024 --M 1024 --sparsity 0.3`)
- `torch.OutOfMemoryError: CUDA out of memory.`
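A back-of-envelope check for the first error, assuming the overflow comes from the dense intermediate of shape `(B, seqlen, M)` that the padding path materializes via `_jagged_to_padded_dense_forward`: with `B = M = 1024`, the element count crosses the 32-bit indexing limit once `seqlen` exceeds 2048, i.e. for every swept value from 3000 upward.

```python
# Rough size estimate for the int32 overflow above (assumes the padded
# dense intermediate of shape (B, seqlen, M) is the tensor that overflows).
INT32_MAX = 2**31 - 1  # 2,147,483,647

B, M = 1024, 1024
for seqlen in list(range(100, 1000, 100)) + list(range(1000, 10000, 1000)):
    numel = B * seqlen * M  # elements in the padded dense tensor
    if numel > INT32_MAX:
        print(f"seqlen={seqlen}: numel={numel:,} > int32 max")  # 3000 and up
```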

Reviewed By: davidberard98

Differential Revision: D58772201

fbshipit-source-id: 365bc1735652736a562d350c7938c1a565a630fe
jananisriram authored and facebook-github-bot committed Jun 21, 2024
1 parent 53faa0a commit 1425f68
Showing 2 changed files with 83 additions and 44 deletions.
8 changes: 6 additions & 2 deletions torchbenchmark/operators/jagged_sum/kernels.py
@@ -59,7 +59,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end

@@ -132,7 +134,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     for block_pos in range(
         0, MAX_SEQLEN, BLOCK_SIZE_RAGGED
     ):  # loop over ragged dimension, ranging until maximum seqlen
-        block_start_ragged = ragged_start + block_pos  # offset block position by start of current program
+        block_start_ragged = (
+            ragged_start + block_pos
+        )  # offset block position by start of current program
         offsets_ragged = block_start_ragged + tl.arange(0, BLOCK_SIZE_RAGGED)
         mask_ragged = offsets_ragged < ragged_end

119 changes: 77 additions & 42 deletions torchbenchmark/operators/jagged_sum/operator.py
@@ -31,17 +31,25 @@
 
 def parse_op_args(args: List[str]):
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--B",
+        type=int,
+        help="[Optional] Size of dimension 0 in shape (B, *, M) (integer)",
+    )
+    parser.add_argument(
+        "--M",
+        type=int,
+        help="[Optional] Size of dimension 2 in shape (B, *, M) (integer)",
+    )
     parser.add_argument(
         "--seqlen",
         type=int,
-        default=500,
-        help="Maximum sequence length on ragged dimension (integer)",
+        help="[Optional] Maximum sequence length on ragged dimension (integer)",
     )
     parser.add_argument(
         "--sparsity",
         type=float,
-        default=0.5,
-        help="Average sparsity for nested tensor (float, (0.0-1.0))",
+        help="[Optional] Average sparsity for nested tensor (float, (0.0-1.0))",
     )
     parser.add_argument(
         "--sum-then-buffer",
@@ -91,12 +99,16 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
         )  # bias towards larger sizes, which are more representative of real-world shapes
 
         args = parse_op_args(self.extra_args)
-        self.seqlen = args.seqlen
-        self.sparsity = args.sparsity
+        self.B = args.B if args.B is not None else None
+        self.M = args.M if args.M is not None else None
+        self.seqlen = args.seqlen if args.seqlen is not None else None
+        self.sparsity = args.sparsity if args.sparsity is not None else None
         self.sum_then_buffer = args.sum_then_buffer
 
     @register_benchmark(baseline=True)
-    def torch_jagged_sum_no_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.tensor(
             [
                 torch.sum(t, dim=0).tolist() for t in x.unbind()
@@ -106,66 +118,87 @@ def torch_jagged_sum_no_pad(self, x: torch.Tensor):
         )
 
     @register_benchmark()
-    def torch_jagged_sum_pad(self, x: torch.Tensor):
+    def torch_jagged_sum_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         return lambda: torch.sum(
             torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
                 [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[self.seqlen],  # max length of ragged dimension
+                max_lengths=[seqlen],  # max length of ragged dimension
             ),
             dim=1,
         )  # sum along ragged dimension (dim == 1)
 
     @register_benchmark()
-    def triton_jagged_sum_no_pad(self, x: torch.Tensor):
+    def triton_jagged_sum_no_pad(
+        self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+    ):
         def _inner():
-            return execute_kernel_simple_fused(x, self.seqlen, self.sum_then_buffer)
+            return execute_kernel_simple_fused(x, seqlen, self.sum_then_buffer)
 
         return _inner
 
     def get_x_val(self, example_inputs):
         return len(example_inputs[0])
 
-    def get_x_vals(self) -> Tuple[List[int], List[int]]:
-        B_vals, M_vals = [], []
-
-        B_vals.extend([2**n for n in self.sizes])
-        B_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in B_vals
-            ]
-        )
+    def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
+        B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
+
+        def get_dim_vals():
+            vals = []
+            vals.extend([2**n for n in self.sizes])
+            vals.extend(
+                [
+                    (n - 1) * (n + 1)
+                    for n in self.sizes
+                    if n - 1 > 0 and (n - 1) * (n + 1) not in vals
+                ]
+            )
+            return vals
+
+        if self.B is None:
+            B_vals.extend(get_dim_vals())
+        else:
+            B_vals.extend([self.B])
+
+        if self.M is None:
+            M_vals.extend(get_dim_vals())
+        else:
+            M_vals.extend([self.M])
+
+        if self.seqlen is None:
+            seqlen_vals.extend(
+                list(range(100, 1000, 100))
+                + list(range(1000, 10000, 1000))
+            )
+        else:
+            seqlen_vals.extend([self.seqlen])
 
-        M_vals.extend([2**n for n in self.sizes])
-        M_vals.extend(
-            [
-                (n - 1) * (n + 1)
-                for n in self.sizes
-                if n - 1 > 0 and (n - 1) * (n + 1) not in M_vals
-            ]
-        )
+        if self.sparsity is None:
+            sparsity_vals.extend([n / 10 for n in range(1, 10)])
+        else:
+            sparsity_vals.extend([self.sparsity])
 
-        return B_vals, M_vals
+        return B_vals, M_vals, seqlen_vals, sparsity_vals
 
     def get_input_iter(self) -> Generator:
         """
         Generate random nested tensors of shape (B, *, M), where * is the ragged dimension
         """
 
-        B_vals, M_vals = self.get_x_vals()
-        B_M_vals = itertools.product(B_vals, M_vals)
+        B_vals, M_vals, seqlen_vals, sparsity_vals = self.get_x_vals()
+        vals = itertools.product(B_vals, M_vals, seqlen_vals, sparsity_vals)
 
-        for B, M in B_M_vals:
+        for B, M, seqlen, sparsity in vals:
             tensors = []
 
             # greater sparsity --> shorter sequence lengths on ragged dimension
             seqlen_avg = math.floor(
-                self.seqlen * (1 - self.sparsity)
+                seqlen * (1 - sparsity)
             )  # average sequence length across all tensors in nested tensor
             seqlen_margin = math.floor(
-                self.seqlen * RANDOM_CHOICE_MARGIN
+                seqlen * RANDOM_CHOICE_MARGIN
            )  # use margin to constrain sequence lengths to range [seqlen_avg - seqlen_margin, seqlen_avg + seqlen_margin] to approximate an average sequence length, which correlates with sparsity
 
             for _ in range(B):
@@ -174,7 +207,7 @@ def get_input_iter(self) -> Generator:
                     seqlen_avg - seqlen_margin, 1
                 ),  # seqlen_randint must be at least 1
                 min(
-                    seqlen_avg + seqlen_margin, self.seqlen
+                    seqlen_avg + seqlen_margin, seqlen
                 ),  # seqlen_randint must not exceed self.seqlen
             )
             tensor_2d = torch.randn(
@@ -189,7 +222,7 @@ def get_input_iter(self) -> Generator:
                 dtype=self.dtype,
             )
 
-            yield (nt,)
+            yield (nt, B, M, seqlen, sparsity)
 
     def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
         output = fn()
@@ -205,15 +238,17 @@ def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
            * GIGABYTES_PER_BYTE
        )
 
-    @register_metric(x_only=True)
+    @register_metric(x_only=True)  # TODO modify!!!!
     def input_shape(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
     ):
         return (
-            example_inputs[0].shape[0],
+            f"B: {example_inputs[1]}",  # B
             "*",
-            example_inputs[0].shape[2],
-        )  # return (B, '*', M) for each example input
+            f"M: {example_inputs[2]}",  # M
+            f"max seqlen: {example_inputs[3]}",  # seqlen
+            f"sparsity: {example_inputs[4]}",  # sparsity
+        )  # return (B, '*', M, max seqlen, sparsity) for each example input
 
     @register_metric(skip_baseline=True)
     def best_config(
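For reference, a minimal sketch of the per-dimension candidate values `get_x_vals` now produces when `B` and `M` are not pinned (the `sizes` list here is a hypothetical stand-in; the real one is set in the operator's `__init__`, which this diff does not touch):

```python
# Hypothetical stand-in for self.sizes; the benchmark biases toward larger n.
sizes = [4, 6, 8, 10]


def get_dim_vals(sizes):
    # Powers of two, plus nearby non-powers (n - 1) * (n + 1) == n**2 - 1
    # to also cover shapes that are not well-aligned.
    vals = [2**n for n in sizes]
    vals.extend(
        [
            (n - 1) * (n + 1)
            for n in sizes
            if n - 1 > 0 and (n - 1) * (n + 1) not in vals
        ]
    )
    return vals


print(get_dim_vals(sizes))
# [16, 64, 256, 1024, 15, 35, 63, 99] -- B and M each draw from this list,
# while seqlen sweeps 100..900 and 1000..9000, and sparsity sweeps 0.1..0.9.
```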
