Commit b91caad

Use CUDA graphs for benchmarking
Summary:
Per https://fb.workplace.com/groups/420659799592399/posts/807860500872325/, CUDA graph benchmarking is a lot more accurate than regular non-CUDA-graph benchmarking.

I had to change a number of use sites of `metrics.latency` because `do_bench_cudagraph` does not support returning quantiles. We could certainly fix that upstream, but it would take more time, and quantiles don't really seem that useful in TritonBench anyway.
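For reference, the core of the change swaps `triton.testing.do_bench` for `triton.testing.do_bench_cudagraph`, run on a fresh side stream since CUDA graph capture is not allowed on the default stream. A minimal sketch of the new benchmarking path (the matmul workload, sizes, and rep count below are illustrative stand-ins, not TritonBench code):

import torch
import triton.testing

# Stand-in for an operator under test.
a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

def fn():
    return torch.mm(a, b)

# do_bench_cudagraph raises if the current stream is the default stream,
# so wrap the call in a fresh side stream, as triton_op.py now does.
with torch.cuda.stream(torch.cuda.Stream()):
    # No quantiles support here: the result is a single scalar,
    # the median latency in milliseconds.
    latency_ms = triton.testing.do_bench_cudagraph(
        fn, rep=100, return_mode="median"
    )
print(f"median latency: {latency_ms:.3f} ms")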

Reviewed By: xuzhao9, sijiac

Differential Revision: D58502780

fbshipit-source-id: 8c97b95097f49ece47ce9b1660af60afae8c25e8
int3 authored and facebook-github-bot committed Jun 14, 2024
1 parent 86904ca commit b91caad
Showing 12 changed files with 52 additions and 57 deletions.
7 changes: 3 additions & 4 deletions torchbenchmark/operators/addmm/operator.py
@@ -113,18 +113,17 @@ def gbps(
             + (torch.addmm(a, mat1, mat2).numel())
         )
         numel = numel * a.element_size() / 1e9
-        gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
-        return statistics.median(gbps)
+        return numel / metrics.latency * 1e3

     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         _, mat1, mat2 = example_inputs
         m, k = mat1.size()
         k, n = mat2.size()
         flops = m * k * 2 * n
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3

     @register_x_val(label="(M, N, K)")
     def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
4 changes: 2 additions & 2 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -278,7 +278,7 @@ def sdpa_flash_attention(q, k, v):
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         flops_per_matmul = (
             2.0 * self.BATCH * self.H * self.N_CTX * self.N_CTX * self.D_HEAD
         )
@@ -289,7 +289,7 @@ def tflops(
             tflops *= 2.5  # 2.0(bwd) + 0.5(recompute)
         elif self.mode == BenchmarkMode.FWD_BWD:
             tflops *= 3.5  # 1.0(fwd) + 2.0(bwd) + 0.5(recompute)
-        return list(map(lambda x: tflops / x * 1e-9, metrics.latency))
+        return tflops / metrics.latency * 1e-9

     def get_bwd_fn(self, fwd_fn: Callable) -> Callable:
         o = fwd_fn()
4 changes: 2 additions & 2 deletions torchbenchmark/operators/fp8_gemm/fp8_gemm.py
@@ -92,7 +92,7 @@ def nbytes(t):
         m, k = a.shape
         _, n = b.shape
         gb = (nbytes(a) + nbytes(b) + nbytes(c)) / 1e9
-        return list(map(lambda x: gb / x * 1e3, metrics.latency))
+        return gb / metrics.latency * 1e3

     @register_metric()
     def tflops(
@@ -102,7 +102,7 @@ def tflops(
         m, k = a.size()
         _, n = b.size()
         flops = 2 * m * n * k
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3

     def plot(self):
         @triton.testing.perf_report(
8 changes: 4 additions & 4 deletions torchbenchmark/operators/gather_gemv/operator.py
@@ -24,19 +24,19 @@
 from .triton_gather_gemv import triton_gemv_0 as triton_test_0
 from torch._dynamo.testing import rand_strided

+
 class Operator(BenchmarkOperator):

     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         arg0_1, arg1_1, arg2_1 = example_inputs
-        gbps = (
-            lambda ms: 2
+        return (
+            2
             * arg2_1.size(0) * arg2_1.size(0)
             * arg0_1.element_size()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency))

     def __init__(self, mode: str, device: str, extra_args: List[str] = []):
         super().__init__(mode=mode, device=device, extra_args=extra_args)
5 changes: 2 additions & 3 deletions torchbenchmark/operators/gemm/operator.py
@@ -177,8 +177,7 @@ def gbps(
         a, w, bias = example_inputs
         numel = a.numel() + w.numel() + (torch.mm(a, w).numel())
         numel = numel * a.element_size() / 1e9
-        gbps = list(map(lambda x: numel / x * 1e3, metrics.latency))
-        return statistics.median(gbps)
+        return numel / metrics.latency * 1e3

     @register_metric(skip_baseline=True)
     def best_config(
@@ -205,7 +204,7 @@ def tflops(
             flops = m * k * 2 * n + 2 * m * n
         else:
             flops = m * k * 2 * n
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3

     @staticmethod
     def _scaled_randn(*args, scale: float, **kwargs) -> torch.Tensor:
4 changes: 2 additions & 2 deletions torchbenchmark/operators/int4_gemm/int4_gemm.py
@@ -95,7 +95,7 @@ def nbytes(t):
         c = fn()

         gb = (sum(nbytes(t) for t in (x, scale_and_zero, c)) + nbytes(w) // 8) / 1e9
-        return list(map(lambda ms: gb / ms * 1e3, metrics.latency))
+        return gb / metrics.latency * 1e3

     @register_metric()
     def tflops(
@@ -106,7 +106,7 @@ def tflops(
         m = B * m
         _, n = b.size()
         flops = 2 * m * n * k
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency / 1e12 * 1e3

     def plot(self):
         @triton.testing.perf_report(
16 changes: 6 additions & 10 deletions torchbenchmark/operators/layer_norm/operator.py
@@ -67,16 +67,12 @@ def get_x_val(self, args):
     @register_metric()
     def gbps(self, fn_name, args, metrics: BenchmarkOperatorMetrics) -> float:
         x = args[0]
-
-        def gbps(ms):
-            base = x.numel() * x.element_size() / ms * 1e-6
-            return {
-                Mode.FWD: 2 * base,
-                Mode.BWD: 3 * base,
-                Mode.FWD_BWD: 5 * base,
-            }[self.mode]
-
-        return list(map(gbps, metrics.latency))
+        base = x.numel() * x.element_size() / metrics.latency * 1e-6
+        return {
+            Mode.FWD: 2 * base,
+            Mode.BWD: 3 * base,
+            Mode.FWD_BWD: 5 * base,
+        }[self.mode]

     def plot(self):
         @triton.testing.perf_report(
12 changes: 8 additions & 4 deletions torchbenchmark/operators/low_mem_dropout/operator.py
@@ -13,21 +13,25 @@

 from .kernels import _triton_dropout, _seeded_triton_dropout

+
 class Operator(BenchmarkOperator):
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: 3 * example_inputs[1].element_size() * example_inputs[1].numel() / ms * 1e-6
+        return (
+            3
+            * example_inputs[1].element_size()
+            * example_inputs[1].numel()
+            / metrics.latency
+            * 1e-6
         )
-        return list(map(gbps, metrics.latency))

     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs, metrics: BenchmarkOperatorMetrics
     ):
         p, a = example_inputs
         flops = 2 * len(a)
-        return [flops / x / 1e12 * 1e3 for x in metrics.latency]
+        return flops / metrics.latency

     @register_benchmark()
     def triton_dropout(self, p, x):
9 changes: 4 additions & 5 deletions torchbenchmark/operators/softmax/operator.py
@@ -113,15 +113,14 @@ def get_x_val(self, example_inputs) -> int:
     @register_metric()
     def gbps(
         self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
-        gbps = (
-            lambda ms: 2
+    ) -> float:
+        return (
+            2
             * example_inputs[0].nelement()
             * example_inputs[0].element_size()
             * 1e-9
-            / (ms * 1e-3)
+            / (metrics.latency * 1e-3)
         )
-        return list(map(gbps, metrics.latency))

     def plot(self):
         @triton.testing.perf_report(
7 changes: 3 additions & 4 deletions torchbenchmark/operators/sum/operator.py
@@ -187,13 +187,12 @@ def input_dims(

     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: example_inputs[0].element_size()
+        return (
+            example_inputs[0].element_size()
             * example_inputs[0].numel()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency if metrics.latency else [0]))

     @register_metric(skip_baseline=True)
     def best_config(
7 changes: 3 additions & 4 deletions torchbenchmark/operators/vector_add/operator.py
@@ -17,14 +17,13 @@ class Operator(BenchmarkOperator):

     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
-        gbps = (
-            lambda ms: 3
+        return (
+            3
             * example_inputs[0].element_size()
             * example_inputs[0].numel()
-            / ms
+            / metrics.latency
             * 1e-6
         )
-        return list(map(gbps, metrics.latency))

     @register_benchmark()
     def triton_add(self, x: torch.Tensor, y: torch.Tensor):
26 changes: 13 additions & 13 deletions torchbenchmark/util/triton_op.py
@@ -136,9 +136,9 @@ def dump_autotuner_best_config(kernel: triton.runtime.Autotuner) -> str:
 @dataclass
 class BenchmarkOperatorMetrics:
     # latency in ms
-    latency: Optional[List[float]]
+    latency: Optional[float]
     # tflops
-    tflops: Optional[List[float]]
+    tflops: Optional[float]
     # speedup over baseline
     speedup: Optional[float]
     # accuracy over baseline
@@ -735,13 +735,13 @@ def _init_extra_metrics() -> Dict[str, Any]:
         if set(["latency", "tflops", "speedup", "compile_time"]) & set(
             self.required_metrics
         ):
-            metrics.latency = triton.testing.do_bench(
-                fn,
-                warmup=warmup,
-                rep=rep,
-                quantiles=quantiles,
-                grad_to_none=self.get_grad_to_none(self.example_inputs),
-            )
+            with torch.cuda.stream(torch.cuda.Stream()):
+                metrics.latency = triton.testing.do_bench_cudagraph(
+                    fn,
+                    rep=rep,
+                    return_mode="median",
+                    grad_to_none=self.get_grad_to_none(self.example_inputs),
+                )
         if "walltime" in self.required_metrics:
             metrics.walltime = do_bench_walltime(
                 fn,
@@ -750,7 +750,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
             )
         if "speedup" in self.required_metrics:
             metrics.speedup = (
-                numpy.median(self.baseline_metrics.latency) / numpy.median(metrics.latency)
+                self.baseline_metrics.latency / metrics.latency
                 if self.baseline_metrics and self.baseline_metrics.latency
                 else None
             )
@@ -950,7 +950,7 @@ def compile_time(
         op_task.run()
         latency_with_compile = op_task.get_attribute("_latency_with_compile_in_task")
         del op_task
-        latency_without_compile = numpy.median(metrics.latency)
+        latency_without_compile = metrics.latency
         return latency_with_compile - latency_without_compile

     def hw_roofline(self) -> float:
@@ -984,7 +984,7 @@ def _compile_time_in_task(

     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> List[float]:
+    ) -> float:
         def _get_flops(self, func: Callable) -> float:
             """By default, use the torch.__dispatch__ based flops counter."""
             from torch.utils.flop_counter import FlopCounterMode
@@ -1010,4 +1010,4 @@ def work_func():
         if not fn in self._op_flops:
             self._op_flops[fn] = _get_flops(self, fn)
         op_flops = self._op_flops[fn]
-        return list(map(lambda x: op_flops / x / 1e12 * 1e3, metrics.latency))
+        return op_flops / metrics.latency / 1e12 * 1e3
