Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
Modify parser choices to case-insensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
barrh committed Jan 14, 2019
1 parent 9a1e28a commit 0c42a48
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions examples/classifier_compression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
def getParser():
parser = argparse.ArgumentParser(description='Distiller image classification model compression')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18', type=lambda s: s.lower(),
choices=models.ALL_MODEL_NAMES,
help='model architecture: ' +
' | '.join(models.ALL_MODEL_NAMES) +
Expand Down Expand Up @@ -58,12 +58,12 @@ def getParser():
help='print masks sparsity table at end of each epoch')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
parser.add_argument('--summary', type=lambda s: s.lower(), choices=SUMMARY_CHOICES,
help='print a summary of the model, and exit - options: ' +
' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'], type=lambda s: s.lower(),
help='test the sensitivity of layers to pruning')
parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
Expand Down
2 changes: 1 addition & 1 deletion models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
if name.islower() and not name.startswith("__")
and callable(cifar10_models.__dict__[name]))

ALL_MODEL_NAMES = sorted(set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES))
ALL_MODEL_NAMES = sorted(map(lambda s: s.lower(), set(IMAGENET_MODEL_NAMES + CIFAR10_MODEL_NAMES)))


def create_model(pretrained, dataset, arch, parallel=True, device_ids=None):
Expand Down

0 comments on commit 0c42a48

Please sign in to comment.