Skip to content

Commit

Permalink
Add --skip to install.py (#2314)
Browse files Browse the repository at this point in the history
Summary:
Use `--skip` to skip the install of certain models.

Fixes #2308

Pull Request resolved: #2314

Test Plan:
```
$ python install.py hf_Bert hf_Bart yolov3 --skip hf_Bert yolov3
checking packages torch, torchvision, torchaudio are installed...OK
running setup for /Users/xzhao9/git/benchmark/torchbenchmark/models/hf_Bart...OK
```

Reviewed By: eellison

Differential Revision: D58767785

Pulled By: xuzhao9

fbshipit-source-id: 5fcf5d539dda322e86bc6c2ae86b37e288ec71ad
  • Loading branch information
xuzhao9 authored and facebook-github-bot committed Jun 20, 2024
1 parent 8451c1f commit 5254910
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
23 changes: 21 additions & 2 deletions install.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@ def pip_install_requirements(requirements_txt="requirements.txt"):
action="store_true",
help="Run in test mode and check package versions",
)
parser.add_argument(
"--skip",
nargs="*",
default=[],
help="Skip models to install."
)
parser.add_argument(
"--torch",
action="store_true",
help="Only require torch to be installed, ignore torchvision and torchaudio."
)
parser.add_argument(
"--numpy",
action="store_true",
help="Only require numpy to be installed, ignore torch, torchvision and torchaudio."
)
parser.add_argument("--canary", action="store_true", help="Install canary model.")
parser.add_argument("--continue_on_fail", action="store_true")
parser.add_argument("--verbose", "-v", action="store_true")
Expand All @@ -50,13 +66,15 @@ def pip_install_requirements(requirements_txt="requirements.txt"):

os.chdir(os.path.realpath(os.path.dirname(__file__)))

if args.torch or args.userbenchmark:
TORCH_DEPS = ["numpy", "torch"]
if args.numpy:
TORCH_DEPS = ["numpy"]
print(
f"checking packages {', '.join(TORCH_DEPS)} are installed...",
end="",
flush=True,
)
if args.userbenchmark:
TORCH_DEPS = ["torch"]
try:
versions = get_pkg_versions(TORCH_DEPS)
except ModuleNotFoundError as e:
Expand Down Expand Up @@ -88,6 +106,7 @@ def pip_install_requirements(requirements_txt="requirements.txt"):

success &= setup(
models=args.models,
skip_models=args.skip,
verbose=args.verbose,
continue_on_fail=args.continue_on_fail,
test_mode=args.test_mode,
Expand Down
5 changes: 4 additions & 1 deletion torchbenchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ def _is_canary_model(model_name: str) -> bool:


def setup(
models: List[str] = [],
models: Optional[List[str]] = None,
skip_models: Optional[List[str]] = None,
verbose: bool = True,
continue_on_fail: bool = False,
test_mode: bool = False,
Expand All @@ -175,6 +176,8 @@ def setup(
)
model_paths = list(model_paths)
model_paths.extend(canary_model_paths)
skip_models = [] if not skip_models else skip_models
model_paths = [ x for x in model_paths if os.path.basename(x) not in skip_models ]
for model_path in model_paths:
print(f"running setup for {model_path}...", end="", flush=True)
if test_mode:
Expand Down
2 changes: 1 addition & 1 deletion utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from typing import Dict, List

TORCH_DEPS = ["torch", "torchvision", "torchaudio"]
TORCH_DEPS = ["numpy", "torch", "torchvision", "torchaudio"]


class add_path:
Expand Down

0 comments on commit 5254910

Please sign in to comment.