Skip to content

Commit

Permalink
Add plots variable along 1 input parameter to jagged_sum operator in …
Browse files Browse the repository at this point in the history
…TritonBench

Summary:
Add plotting functionality to the `jagged_sum` operator in TritonBench, enabling the creation of line plots for any set of benchmarks, varying along one of the following input parameters: `B`, `M`, `seqlen`, or `sparsity`. This diff makes it easier to visualize the differences in `latency` among the different benchmarks in the `jagged_sum` operator.

Add a command-line argument to toggle the benchmarks displayed on the plots, making it easier to visualize just the 2 Triton benchmarks or just the 2 PyTorch benchmarks, if necessary. This modification helps more clearly visualize the `latency` differences between the simple fused and variable-length loop Triton kernels as well as the unpadded and padded PyTorch benchmarks.

Note that for plots testing multiple values of `B` or `M`, the x-axis is on a log scale to more accurately depict trends in latency; whether or not the x-axis is on a log scale is noted in the plot name in the Test Plan.

Reviewed By: davidberard98

Differential Revision: D59034792

fbshipit-source-id: d27e8e59856d24328c25192069d0afbc655a62e7
  • Loading branch information
jananisriram authored and facebook-github-bot committed Jun 27, 2024
1 parent ec0909d commit 847d38e
Showing 1 changed file with 94 additions and 1 deletion.
95 changes: 94 additions & 1 deletion torchbenchmark/operators/jagged_sum/operator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import itertools
import math
import os
import random
from typing import Callable, Generator, List, Optional, Tuple

Expand Down Expand Up @@ -61,6 +62,12 @@ def parse_op_args(args: List[str]):
default=0,
help="[Optional] For Triton kernels, determines whether to sum individual blocks then add to a buffer or add to a buffer then sum; 1: sum then buffer, 0: buffer then sum; default 0",
)
parser.add_argument(
"--plot-benchmarks",
type=str,
default="all",
help="[Optional] Determines which benchmarks to plot: all, torch, triton",
)
return parser.parse_args(args)


Expand Down Expand Up @@ -131,6 +138,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
self.seqlen = args.seqlen
self.sparsity = args.sparsity
self.sum_then_buffer = args.sum_then_buffer
self.plot_benchmarks = args.plot_benchmarks

@register_benchmark(baseline=True)
def torch_jagged_sum_no_pad(
Expand Down Expand Up @@ -176,7 +184,14 @@ def _inner():
return _inner

def get_x_val(self, example_inputs):
    """Return the value of the single varying input parameter for this input.

    Exactly one of ``self.B``, ``self.M``, ``self.seqlen``, ``self.sparsity``
    is expected to be ``None`` (the axis being swept); the corresponding
    element of ``example_inputs`` is the x-value for plotting.

    Args:
        example_inputs: benchmark input tuple laid out as
            (tensor, B, M, seqlen, sparsity) — assumed from the indices used
            here; TODO confirm against get_input_iter.

    Returns:
        The swept parameter's value for this example input.

    Raises:
        ValueError: if none of the four parameters was left unset, instead of
            silently returning ``None`` as the fall-through previously did.
    """
    if self.B is None:
        return example_inputs[1]
    if self.M is None:
        return example_inputs[2]
    if self.seqlen is None:
        return example_inputs[3]
    if self.sparsity is None:
        return example_inputs[4]
    raise ValueError(
        "Expected exactly one of B, M, seqlen, sparsity to be unset (None)"
    )

def get_x_vals(self) -> Tuple[List[int], List[int], List[int], List[float]]:
B_vals, M_vals, seqlen_vals, sparsity_vals = [], [], [], []
Expand Down Expand Up @@ -316,3 +331,81 @@ def best_config(
return dump_autotuner_best_config(
triton_jagged_sum_kernel_variable_length_loop_buffer_then_sum
)

def plot(self):
    """Plot latency for the active benchmarks against the one varying input.

    Whichever of ``B``, ``M``, ``seqlen``, ``sparsity`` was left unset on the
    command line becomes the x-axis; the other three are baked into the plot
    name. ``B`` and ``M`` sweeps use a log-scale x-axis since their values
    span orders of magnitude. Plots are saved under
    ``<cwd>/pytorch/benchmark/torchbenchmark/operators/jagged_sum/jagged_sum_performance/``.
    """
    str_B, str_M, str_seqlen, str_sparsity = (
        f"-B-{self.B}",
        f"-M-{self.M}",
        f"-seqlen-{self.seqlen}",
        f"-sparsity-{self.sparsity}",
    )
    # Pick the swept parameter as the x-axis; encode the fixed ones in `params`.
    if self.B is None:
        x_axis = "B"
        x_log = True
        params = str_M + str_seqlen + str_sparsity
    elif self.M is None:
        x_axis = "M"
        x_log = True
        params = str_B + str_seqlen + str_sparsity
    elif self.seqlen is None:
        x_axis = "seqlen"
        x_log = False
        params = str_B + str_M + str_sparsity
    else:
        x_axis = "sparsity"
        x_log = False
        params = str_B + str_M + str_seqlen

    # First two entries are the PyTorch benchmarks, last two are Triton;
    # --plot-benchmarks selects which slice to draw.
    line_vals_all = [
        "torch_jagged_sum_no_pad",
        "torch_jagged_sum_pad",
        "triton_jagged_sum_no_pad_simple_fused",
        "triton_jagged_sum_no_pad_variable_length_loop",
    ]
    line_names_all = [
        "PyTorch jagged sum, no padding",
        "PyTorch jagged sum, padding",
        "Triton kernel jagged sum, simple fused",
        "Triton kernel jagged sum, variable length loop",
    ]
    styles_all = [
        ("blue", "-"),
        ("red", "-"),
        ("green", "-"),
        ("yellow", "-"),
    ]
    if self.plot_benchmarks == "all":
        line_vals, line_names, styles = line_vals_all, line_names_all, styles_all
    elif self.plot_benchmarks == "torch":
        line_vals = line_vals_all[:2]
        line_names = line_names_all[:2]
        styles = styles_all[:2]
    else:
        line_vals = line_vals_all[2:]
        line_names = line_names_all[2:]
        styles = styles_all[2:]

    plot_name = f"jagged-sum-perf-var-{x_axis}-xlog-{x_log}" + params

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["x_axis"],
            x_vals=self.output.x_vals,
            line_arg="provider",
            line_vals=line_vals,
            line_names=line_names,
            styles=styles,
            xlabel=x_axis,
            ylabel="latency",
            x_log=x_log,
            plot_name=plot_name,
            args={},
        )
    )
    def _plot(x_axis, provider):
        # Look up the measured latency for this (x-value, benchmark) pair.
        return self.output.get_y_vals(x_axis, provider, "latency")

    save_path = (
        os.getcwd()
        + f"/pytorch/benchmark/torchbenchmark/operators/jagged_sum/jagged_sum_performance/{plot_name}"
    )

    # makedirs with exist_ok avoids the exists()/mkdir() race and, unlike
    # os.mkdir, also creates any missing intermediate directories in the
    # hard-coded path above.
    os.makedirs(save_path, exist_ok=True)

    _plot.run(show_plots=True, print_data=True, save_path=save_path)

0 comments on commit 847d38e

Please sign in to comment.