From b2b4158d3278d39e5575143e1d6954a9dc01e396 Mon Sep 17 00:00:00 2001
From: Jez Ng
Date: Wed, 19 Jun 2024 16:08:51 -0700
Subject: [PATCH] Use NVTX filtering to limit NCU profile collection

Summary:
Previously, we used `--replay-mode range`, but that did not give us
per-kernel metrics, so it was changed to `--replay-mode kernel` (the
default). However, that can cause us to profile many kernels outside
the ones in the desired benchmark. It appears we can instead use NVTX
filtering to solve this problem. Relevant docs:
https://docs.nvidia.com/nsight-compute/NsightComputeCli/index.html#nvtx-filtering

I also tacked on a minor change to the ncu invocation, adding
`--import-source yes`. This makes it easier to analyze the traces on a
different machine from the one doing the profiling.

Reviewed By: chenyang78

Differential Revision: D58711358

fbshipit-source-id: 28aec4f71a736c7427b1886335297ece4a2a54a8
---
 torchbenchmark/_components/ncu/__init__.py | 17 +++++----
 torchbenchmark/util/triton_op.py           | 40 +++++++++++++---------
 2 files changed, 34 insertions(+), 23 deletions(-)

diff --git a/torchbenchmark/_components/ncu/__init__.py b/torchbenchmark/_components/ncu/__init__.py
index 2b65a9ac40..ff669cc08b 100644
--- a/torchbenchmark/_components/ncu/__init__.py
+++ b/torchbenchmark/_components/ncu/__init__.py
@@ -1,7 +1,14 @@
-
 from typing import Callable
 
-def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=True, output_dir=None) -> None:
+
+def do_bench_ncu_in_task(
+    fn: Callable,
+    warmup=25,
+    grad_to_none=None,
+    fast_flush=True,
+    output_dir=None,
+    range_name: str = "",
+) -> None:
     """
     Benchmark the runtime of the provided function. By default, return the median runtime
     of :code:`fn` along with the 20-th and 80-th performance percentile.
@@ -46,8 +53,6 @@ def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=
     # Warm-up
     for _ in range(n_warmup):
         fn()
-    # Start ncu profiling
-    torch.cuda.cudart().cudaProfilerStart()
     # we don't want `fn` to accumulate gradient values
     # if it contains a backward pass. So we clear the
     # provided gradients
@@ -56,5 +61,5 @@ def do_bench_ncu_in_task(fn: Callable, warmup=25, grad_to_none=None, fast_flush=
             x.grad = None
         # we clear the L2 cache before run
         cache.zero_()
-    fn()
-    torch.cuda.cudart().cudaProfilerStop()
+    with torch.cuda.nvtx.range(range_name):
+        fn()
diff --git a/torchbenchmark/util/triton_op.py b/torchbenchmark/util/triton_op.py
index 3a67c1024f..fdcd14729a 100644
--- a/torchbenchmark/util/triton_op.py
+++ b/torchbenchmark/util/triton_op.py
@@ -414,6 +414,10 @@ def __call__(cls, *args, **kwargs):
         obj.__post__init__()
         return obj
 
+
+_RANGE_NAME = "tritonbench_range"
+
+
 class BenchmarkOperator(metaclass=PostInitProcessor):
     mode: Mode = Mode.FWD
     test: str = "eval"
@@ -827,6 +831,7 @@ def _init_extra_metrics() -> Dict[str, Any]:
                     fn=fn,
                     warmup=warmup,
                     grad_to_none=self.get_grad_to_none(self.example_inputs),
+                    range_name=_RANGE_NAME,
                 )
                 metrics.extra_metrics["_ncu_trace_in_task"] = "success"
             # generate customized metrics
@@ -901,26 +906,27 @@ def ncu_trace(self, input_id: int, fn_name: str, replay: bool=False) -> str:
             "ncu",
             "--set",
             "full",
-            "--replay-mode",
-            "kernel",
+            "--nvtx",
+            "--nvtx-include",
+            f"{_RANGE_NAME}/",
             "--target-processes",
             "all",
-            "--csv",
-            "-f",
-            "--log-file",
-            str(ncu_output_file.resolve()),
-        ] if not replay else [
-            "ncu",
-            "--set",
-            "full",
-            "--replay-mode",
-            "kernel",
-            "--target-processes",
-            "all",
-            "-f",
-            "-o",
-            str(ncu_output_file.resolve()),
+            "--import-source",
+            "yes",
         ]
+        if replay:
+            ncu_args.extend([
+                "-f",
+                "-o",
+                str(ncu_output_file.resolve()),
+            ])
+        else:
+            ncu_args.extend([
+                "--csv",
+                "-f",
+                "--log-file",
+                str(ncu_output_file.resolve()),
+            ])
         ncu_args.extend(op_task_args)
         subprocess.check_call(ncu_args)
         return str(ncu_output_file.resolve())
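
Note (not part of the patch): a minimal standalone sketch of the
NVTX-filtering pattern this change adopts. The script name and the
`run_workload` helper are hypothetical illustrations; the range name matches
the `_RANGE_NAME` constant introduced above.

    # nvtx_filter_demo.py -- hypothetical example.
    # Profile with the flags this patch passes to ncu:
    #   ncu --set full --nvtx --nvtx-include "tritonbench_range/" \
    #       --target-processes all --import-source yes \
    #       python nvtx_filter_demo.py
    import torch

    RANGE_NAME = "tritonbench_range"  # must match the --nvtx-include filter


    def run_workload() -> None:
        a = torch.randn(1024, 1024, device="cuda")
        b = torch.randn(1024, 1024, device="cuda")
        # Kernels launched outside the NVTX range (e.g. these warm-up
        # matmuls) are skipped by ncu when NVTX filtering is enabled.
        for _ in range(3):
            torch.mm(a, b)
        # Only kernels launched inside this named push/pop range are profiled.
        with torch.cuda.nvtx.range(RANGE_NAME):
            torch.mm(a, b)
        torch.cuda.synchronize()


    if __name__ == "__main__":
        run_workload()

The trailing `/` in the `--nvtx-include` expression selects a push/pop range
(the kind `torch.cuda.nvtx.range` emits) rather than a start/end range; see
the NVTX filtering docs linked in the summary.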