Skip to content

Commit

Permalink
no grad (#181)
Browse files Browse the repository at this point in the history
  • Loading branch information
Krovatkin committed Feb 5, 2021
1 parent 1a41710 commit f4ee431
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
2 changes: 2 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ def pytest_addoption(parser):
parser.addoption("--ignore_machine_config",
action='store_true',
help="Disable checks/assertions for machine configuration for stable benchmarks")
parser.addoption("--disable_nograd", action='store_true',
help="Disable no_grad for eval() runs")

def set_fuser(fuser):
if fuser == "old":
Expand Down
11 changes: 7 additions & 4 deletions test_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import torch
from torchbenchmark import list_models
from torchbenchmark.util.machine_config import get_machine_state
from torchbenchmark.util.model import no_grad

def pytest_generate_tests(metafunc):
# This is where the list of models to test can be configured
Expand Down Expand Up @@ -72,10 +73,12 @@ def test_train(self, hub_model, benchmark):
except NotImplementedError:
print('Method train is not implemented, skipping...')

def test_eval(self, hub_model, benchmark, pytestconfig):
    """Benchmark the model's eval() path.

    Inference runs execute under disabled autograd unless either the
    model opts out via eval_in_nograd() (some meta-learning models,
    e.g. maml, train a target model during inference runs) or the user
    passes --disable_nograd on the command line.
    """
    try:
        # no_grad(True) turns autograd off for the benchmarked region;
        # honor both the per-model opt-out and the global CLI switch.
        ng_flag = hub_model.eval_in_nograd() and not pytestconfig.getoption("disable_nograd")
        with no_grad(ng_flag):
            hub_model.set_eval()
            benchmark(hub_model.eval)
            benchmark.extra_info['machine_state'] = get_machine_state()
    except NotImplementedError:
        # Models without an eval path are skipped, not failed.
        print('Method eval is not implemented, skipping...')
2 changes: 2 additions & 0 deletions torchbenchmark/models/maml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def train(self, niter=1):
for _ in range(niter):
self.module(*self.example_inputs)

def eval_in_nograd(self):
    # maml needs to train a target (inner) model even during inference
    # runs, so autograd must stay enabled: opt out of no_grad for eval.
    return False

if __name__ == '__main__':
m = Model(device='cpu', jit=False)
Expand Down
17 changes: 17 additions & 0 deletions torchbenchmark/util/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,20 @@
import typing
from collections.abc import Iterable
import torch
from contextlib import contextmanager


@contextmanager
def no_grad(val):
    """Disable torch autograd within the block when *val* is truthy.

    When *val* is falsy, gradients stay enabled — some meta-learning
    models (e.g. maml) may need to train a target (another) model in
    inference runs. The previous grad mode is always restored on exit.
    """
    prev_mode = torch.is_grad_enabled()
    try:
        torch.set_grad_enabled(not val)
        yield
    finally:
        # Restore the caller's grad mode even if the block raised.
        torch.set_grad_enabled(prev_mode)

class BenchmarkModel():
"""
Expand All @@ -25,6 +39,9 @@ def set_eval(self):
def set_train(self):
    # Put the wrapped module into training mode (delegates to _set_mode).
    self._set_mode(True)

def eval_in_nograd(self):
    # Default: run eval() under no_grad. Subclasses that need autograd
    # during inference runs (e.g. maml) override this to return False.
    return True

def _set_mode(self, train):
    """Switch the wrapped torch module between train and eval mode."""
    module, _example_inputs = self.get_module()
    module.train(train)
Expand Down

0 comments on commit f4ee431

Please sign in to comment.