-
Notifications
You must be signed in to change notification settings - Fork 273
/
conftest.py
57 lines (48 loc) · 2.12 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import pytest
import torch
from torchbenchmark.util.machine_config import get_machine_config, check_machine_configured
def pytest_addoption(parser):
    """Register this benchmark suite's custom command-line options with pytest.

    Options added:
        --fuser: name of the JIT fuser to configure for the run.
        --ignore_machine_config: boolean flag that disables the machine
            configuration checks/assertions used for stable benchmarks.
    """
    option_specs = (
        ("--fuser", {"help": "fuser to use for benchmarks"}),
        ("--ignore_machine_config", {
            "action": 'store_true',
            "help": "Disable checks/assertions for machine configuration for stable benchmarks",
        }),
    )
    for flag, settings in option_specs:
        parser.addoption(flag, **settings)
def set_fuser(fuser):
    """Configure TorchScript JIT executor/fusion settings for the benchmark run.

    Args:
        fuser: which configuration to apply.
            ``"old"``: profiling executor and profiling mode off, GPU fusion
            override on, TensorExpr fuser off.
            ``"te"``: profiling executor and profiling mode on, bailout
            depth 20, 2 profiled runs, CPU fusion override off, GPU fusion
            override on, TensorExpr fuser on.
            ``None`` (``--fuser`` not given): JIT settings are left untouched.

    Raises:
        ValueError: if ``fuser`` is any other value. An unrecognised name
            (e.g. a command-line typo) used to be silently ignored, so a run
            could benchmark the default settings while claiming another fuser.
    """
    if fuser is None:
        # No --fuser option given: keep whatever JIT defaults torch ships with.
        return
    if fuser == "old":
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(False)
    elif fuser == "te":
        torch._C._jit_set_profiling_executor(True)
        torch._C._jit_set_profiling_mode(True)
        # NOTE(review): 20 / 2 are tuning constants chosen by the original
        # author; their rationale is not visible here.
        torch._C._jit_set_bailout_depth(20)
        torch._C._jit_set_num_profiled_runs(2)
        torch._C._jit_override_can_fuse_on_cpu(False)
        torch._C._jit_override_can_fuse_on_gpu(True)
        torch._C._jit_set_texpr_fuser_enabled(True)
    else:
        raise ValueError(
            f"unknown fuser {fuser!r}; expected 'old' or 'te' (or omit --fuser)"
        )
def pytest_sessionstart(session):
    """Verify the host machine is configured for stable benchmarking.

    Runs ``check_machine_configured()`` at session start unless the user
    passed ``--ignore_machine_config``.
    """
    if session.config.getoption('ignore_machine_config'):
        return
    check_machine_configured()
def pytest_configure(config):
    """Apply the JIT fuser selected via ``--fuser`` once pytest has parsed options."""
    chosen_fuser = config.getoption("fuser")
    set_fuser(chosen_fuser)
def pytest_benchmark_update_machine_info(config, machine_info):
    """pytest-benchmark hook: add PyTorch-ecosystem and CI details to ``machine_info``.

    Records, in order:
        - the torch version;
        - the torchtext and torchvision versions, or ``'*not-installed*'``
          when the package cannot be imported;
        - the CircleCI build number and project name from the environment
          (``None`` when the variables are unset);
        - the torchbench machine configuration.

    Args:
        config: pytest config object; only ``ignore_machine_config`` is read.
        machine_info: dict that is updated in place.

    Raises:
        Any exception from ``get_machine_config()``, unless
        ``--ignore_machine_config`` was given. In that case the
        ``torchbench_machine_config`` entry is simply omitted.
    """
    machine_info['pytorch_version'] = torch.__version__

    optional_packages = (
        ('torchtext', 'torchtext_version'),
        ('torchvision', 'torchvision_version'),
    )
    for package_name, version_key in optional_packages:
        try:
            package = __import__(package_name)
            machine_info[version_key] = package.__version__
        except ImportError:
            machine_info[version_key] = '*not-installed*'

    env = os.environ
    machine_info['circle_build_num'] = env.get("CIRCLE_BUILD_NUM")
    machine_info['circle_project_name'] = env.get("CIRCLE_PROJECT_REPONAME")

    # get_machine_config may fail on an unexpected machine/OS; tolerate that
    # only when the user explicitly opted out of machine-config checks.
    try:
        machine_config = get_machine_config()
    except Exception:
        if config.getoption('ignore_machine_config'):
            return
        raise
    machine_info['torchbench_machine_config'] = machine_config