Skip to content

Commit

Permalink
Make CUDA graph benchmarking overridable on a per-op basis
Browse files Browse the repository at this point in the history
Summary: some operators need to do GPU-CPU syncs, which are not supported under CUDA graph capture

Reviewed By: davidberard98

Differential Revision: D58680076

fbshipit-source-id: 7c86c484990445512723ebdda25ef4af8cfffde5
  • Loading branch information
int3 authored and facebook-github-bot committed Jun 17, 2024
1 parent d5f0a12 commit 51eca7b
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torchbenchmark/util/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
_input_iter: Optional[Generator] = None
extra_args: List[str] = []
example_inputs: Any = None
use_cuda_graphs: bool = True

# By default, only collect latency metrics
# Each operator can override to define their own default metrics
Expand Down Expand Up @@ -743,9 +744,18 @@ def _init_extra_metrics() -> Dict[str, Any]:
if set(["latency", "tflops", "speedup", "compile_time"]) & set(
self.required_metrics
):
with torch.cuda.stream(torch.cuda.Stream()):
metrics.latency = triton.testing.do_bench_cudagraph(
if self.use_cuda_graphs:
with torch.cuda.stream(torch.cuda.Stream()):
metrics.latency = triton.testing.do_bench_cudagraph(
fn,
rep=rep,
return_mode="median",
grad_to_none=self.get_grad_to_none(self.example_inputs),
)
else:
metrics.latency = triton.testing.do_bench(
fn,
warmup=warmup,
rep=rep,
return_mode="median",
grad_to_none=self.get_grad_to_none(self.example_inputs),
Expand Down

0 comments on commit 51eca7b

Please sign in to comment.