Misc improvements to long-running benchmarks
Summary:
I've been dealing with very long autotuning times when benchmarking the triton
MQA kernels, to the extent of having benchmarks run overnight, and I found
these changes to be useful quality-of-life improvements.

1. Added a log message to show how long autotuning was taking
2. Increased the dcgm disable timeout to 100k seconds, or about a day
3. Replaced references to "/tmp" with `gettempdir()`, which lets the user override
   the temp dir via `TMPDIR=` and point it at a location that is not in danger of
   being removed by tmpcleaner (see the sketch below)

Reviewed By: chenyang78

Differential Revision: D58972675

fbshipit-source-id: 897f454680ca1f604dca128d670f234def10dd21
int3 authored and facebook-github-bot committed Jun 25, 2024
1 parent a4ad760 commit 6de6dd2
Showing 2 changed files with 46 additions and 16 deletions.
59 changes: 44 additions & 15 deletions torchbenchmark/util/triton_op.py
@@ -3,15 +3,19 @@
import functools
import gc
import json
import logging
import os
import random
import shlex
import tempfile
import time
import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, fields, make_dataclass
from enum import Enum
from numbers import Number
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import numpy
import tabulate
@@ -27,6 +31,8 @@
except ImportError:
    tqdm = None

logger = logging.getLogger(__name__)

DEFAULT_WARMUP = 25
DEFAULT_RUN_ITERS = 100
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
@@ -65,16 +71,31 @@ class Mode(Enum):
    FWD_NO_GRAD = "fwd_no_grad"


class TimerContext:
    def __init__(self, enabled=True):
        self.enabled = enabled
        self.elapsed_ms = None

    def __enter__(self):
        if self.enabled:
            self._start_time = time.perf_counter()
        return self

    def __exit__(self, *args, **kwargs):
        if self.enabled:
            end_time = time.perf_counter()
            self.elapsed_ms = (end_time - self._start_time) * 1e3


def do_bench_walltime(fn, warmup=25, rep=100):
    fn()
    torch.cuda.synchronize()

    start_time = time.perf_counter()
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    estimate_ms = (end_time - start_time) * 1e3 / 5
    with TimerContext() as timer:
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
    estimate_ms = timer.elapsed_ms / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
@@ -177,6 +198,8 @@ def _table(self):
        table = []
        # generate headers
        headers = [REGISTERED_X_VALS[self.op_name]]
        if len(self.result) == 0:
            return headers, table
        y_val = self.result[0][1]
        y_val_keys = list(y_val.keys())
        # move the baseline benchmark to the front of the list if exists
@@ -482,10 +505,13 @@ def _get_bm_func(self, bm_func_name: str):
f"Could not find benchmark {bm_func_name} registered in {self.name}. "
f"Available benchmarks: {REGISTERED_BENCHMARKS[self.name].keys()}. "
)
if isinstance(self.example_inputs, dict):
fwd_fn = fwd_fn_lambda(**self.example_inputs)
else:
fwd_fn = fwd_fn_lambda(*self.example_inputs)
with TimerContext(enabled=logger.level <= logging.INFO) as timer:
if isinstance(self.example_inputs, dict):
fwd_fn = fwd_fn_lambda(**self.example_inputs)
else:
fwd_fn = fwd_fn_lambda(*self.example_inputs)
logger.info("Took %.02fms to get benchmark function for %s", timer.elapsed_ms, bm_func_name)

if self.mode == Mode.FWD:
setattr(fwd_fn, "_name", bm_func_name)
return fwd_fn
@@ -687,6 +713,9 @@ def get_example_inputs(self):
        except StopIteration:
            return None

    def get_temp_path(self, path: Union[str, Path]) -> Path:
        return Path(tempfile.gettempdir()) / "tritonbench" / self.name / Path(path)

    def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
        output = fn()
        baseline_output = baseline_fn()
@@ -866,7 +895,6 @@ def ncu_trace(self, input_id: int, fn_name: str, replay: bool=False) -> str:
        # collect the ncu trace
        import sys
        import subprocess
        from pathlib import Path

        op_task_args = copy.deepcopy(sys.argv)
        for override_option in ["--only", "--input-id", "--num-inputs", "--metrics"]:
@@ -892,14 +920,14 @@ def ncu_trace(self, input_id: int, fn_name: str, replay: bool=False) -> str:
"dyno",
"dcgm_profiling",
"--mute=true",
"--duration=1000_s",
"--duration=100000_s",
]
subprocess.run(disable_dcgm, check=True)
except subprocess.SubprocessError:
warnings.warn(
"Cannot find dyno to disable DCGM. Proceed to collect NCU Trace."
)
ncu_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn_name}_{input_id}")
ncu_output_dir = self.get_temp_path("ncu_traces/{fn_name}_{input_id}")
        ncu_output_dir.mkdir(parents=True, exist_ok=True)
        ext = ".csv" if not replay else ".ncu-rep"
        ncu_output_file = ncu_output_dir.joinpath(f"ncu_output{ext}").resolve()
@@ -929,14 +957,15 @@ def ncu_trace(self, input_id: int, fn_name: str, replay: bool=False) -> str:
            str(ncu_output_file.resolve()),
        ])
        ncu_args.extend(op_task_args)
        logger.info("Running NCU: %s", shlex.join(ncu_args))
        subprocess.check_call(ncu_args)
        return str(ncu_output_file.resolve())

    def kineto_trace(self, input_id: int, fn: Callable) -> str:
        from pathlib import Path
        from torchbenchmark._components.kineto import do_bench_kineto

        kineto_output_dir = Path(f"/tmp/tritonbench_{self.name}_{fn._name}_{input_id}")
        kineto_output_dir = self.get_temp_path(f"kineto_traces/{fn._name}_{input_id}")
        kineto_output_dir.mkdir(parents=True, exist_ok=True)
        return do_bench_kineto(
            fn=fn,
3 changes: 2 additions & 1 deletion userbenchmark/triton/run.py
@@ -1,6 +1,7 @@
import argparse
import os
import sys
import tempfile
from typing import List
from torch import version as torch_version
from torchbenchmark.operators import load_opbench_by_name
@@ -11,7 +12,7 @@
    DEFAULT_WARMUP,
)

TRITON_BENCH_CSV_DUMP_PATH = "/tmp/triton_bench/"
TRITON_BENCH_CSV_DUMP_PATH = tempfile.gettempdir() + "/tritonbench/"


def parse_args(args):
