
compress_classifier.py refactoring (#126)
* Support for multi-phase activations logging

Enable logging activations during both training and validation in
the same session.

* Refactoring: Move parser to its own file

* Move the parser from compress_classifier into its own file.
* Move the torch version check so it precedes the main() call.
* Move the main() definition to the top of the file.
* Make parser choices case-insensitive (see the parser.py sketch below).
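
For context, here is a minimal sketch of what the extracted parser module could look like. Only the getParser() entry point is confirmed by the new main() in the diff below; the case-insensitive handling and the nargs='+'/default=[] details are assumptions, the latter inferred from the new call site create_activation_stats_collectors(model, *args.activation_stats), which unpacks a list of phases.

# parser.py -- hypothetical sketch of the extracted module
import argparse

def getParser():
    parser = argparse.ArgumentParser(
        description='Distiller image classification model compression')
    parser.add_argument('data', metavar='DIR', help='path to dataset')
    # Case-insensitive choices: lower-case the value before argparse matches
    # it against choices, so '--act-stats TRAIN' is accepted as 'train'.
    parser.add_argument('--act-stats', dest='activation_stats',
                        type=lambda s: s.lower(), nargs='+',
                        choices=['train', 'valid', 'test'], default=[],
                        help='collect activation statistics '
                             '(WARNING: this slows down training)')
    # ... the remaining arguments move here unchanged ...
    return parser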
barrh authored and nzmora committed Jan 16, 2019
1 parent 4cc0e7d commit cfbc379
Showing 3 changed files with 207 additions and 179 deletions.
244 changes: 66 additions & 178 deletions examples/classifier_compression/compress_classifier.py
@@ -79,190 +79,21 @@
from distiller.data_loggers import *
import distiller.quantization as quantization
from models import ALL_MODEL_NAMES, create_model
import parser


# Logger handle
msglogger = None


def float_range(val_str):
    val = float(val_str)
    if val < 0 or val >= 1:
        raise argparse.ArgumentTypeError('Must be >= 0 and < 1 (received {0})'.format(val_str))
    return val
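
# Example (hypothetical values): float_range is used below as the argparse
# type of --validation-size, so '--validation-size 0.1' parses to 0.1 while
# '--validation-size 1.0' is rejected with the ArgumentTypeError above.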


parser = argparse.ArgumentParser(description='Distiller image classification model compression')
parser.add_argument('data', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=ALL_MODEL_NAMES,
                    help='model architecture: ' +
                    ' | '.join(ALL_MODEL_NAMES) +
                    ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--act-stats', dest='activation_stats', choices=["train", "valid", "test"], default=None,
                    help='collect activation statistics (WARNING: this slows down training)')
parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
                    help='print masks sparsity table at end of each epoch')
parser.add_argument('--param-hist', dest='log_params_histograms', action='store_true', default=False,
                    help='log the parameter tensors histograms to file (WARNING: this can use significant disk space)')
SUMMARY_CHOICES = ['sparsity', 'compute', 'model', 'modules', 'png', 'png_w_params', 'onnx']
parser.add_argument('--summary', type=str, choices=SUMMARY_CHOICES,
                    help='print a summary of the model, and exit - options: ' +
                    ' | '.join(SUMMARY_CHOICES))
parser.add_argument('--compress', dest='compress', type=str, nargs='?', action='store',
                    help='configuration file for pruning the model (default is to use hard-coded schedule)')
parser.add_argument('--sense', dest='sensitivity', choices=['element', 'filter', 'channel'],
                    help='test the sensitivity of layers to pruning')
parser.add_argument('--sense-range', dest='sensitivity_range', type=float, nargs=3, default=[0.0, 0.95, 0.05],
                    help='an optional parameter for sensitivity testing providing the range of sparsities to test.\n'
                         'This is equivalent to creating sensitivities = np.arange(start, stop, step)')
parser.add_argument('--extras', default=None, type=str,
                    help='file with extra configuration information')
parser.add_argument('--deterministic', '--det', action='store_true',
                    help='Ensure deterministic execution for reproducible results.')
parser.add_argument('--gpus', metavar='DEV_ID', default=None,
                    help='Comma-separated list of GPU device IDs to be used (default is to use all available devices)')
parser.add_argument('--cpu', action='store_true', default=False,
                    help='Use CPU only.\n'
                         'Flag not set => uses GPUs according to the --gpus flag value.\n'
                         'Flag set => overrides the --gpus flag')
parser.add_argument('--name', '-n', metavar='NAME', default=None, help='Experiment name')
parser.add_argument('--out-dir', '-o', dest='output_dir', default='logs', help='Path to dump logs and checkpoints')
parser.add_argument('--validation-size', '--vs', type=float_range, default=0.1,
                    help='Portion of training dataset to set aside for validation')
parser.add_argument('--adc', dest='ADC', action='store_true', help='temp HACK')
parser.add_argument('--adc-params', dest='ADC_params', default=None, help='temp HACK')
parser.add_argument('--confusion', dest='display_confusion', default=False, action='store_true',
                    help='Display the confusion matrix')
parser.add_argument('--earlyexit_lossweights', type=float, nargs='*', dest='earlyexit_lossweights', default=None,
                    help='List of loss weights for early exits (e.g. --earlyexit_lossweights 0.1 0.3)')
parser.add_argument('--earlyexit_thresholds', type=float, nargs='*', dest='earlyexit_thresholds', default=None,
                    help='List of EarlyExit thresholds (e.g. --earlyexit_thresholds 1.2 0.9)')
parser.add_argument('--num-best-scores', dest='num_best_scores', default=1, type=int,
                    help='number of best scores to track and report (default: 1)')
parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
                    help='Load a model without DataParallel wrapping it')

str_to_quant_mode_map = {'sym': quantization.LinearQuantMode.SYMMETRIC,
                         'asym_s': quantization.LinearQuantMode.ASYMMETRIC_SIGNED,
                         'asym_u': quantization.LinearQuantMode.ASYMMETRIC_UNSIGNED}


def linear_quant_mode_str(val_str):
    try:
        return str_to_quant_mode_map[val_str]
    except KeyError:
        raise argparse.ArgumentTypeError('Must be one of {0} (received {1})'.format(
            list(str_to_quant_mode_map.keys()), val_str))
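
# Example: linear_quant_mode_str('sym') returns
# quantization.LinearQuantMode.SYMMETRIC, while an unrecognized string such
# as 'foo' (hypothetical) is rejected with the ArgumentTypeError above.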


quant_group = parser.add_argument_group('Arguments controlling quantization at evaluation time '
                                        '("post-training quantization")')
quant_group.add_argument('--quantize-eval', '--qe', action='store_true',
                         help='Apply linear quantization to model before evaluation. Applicable only if '
                              '--evaluate is also set')
quant_group.add_argument('--qe-mode', '--qem', type=linear_quant_mode_str, default='sym',
                         help='Linear quantization mode. Choices: ' + ' | '.join(str_to_quant_mode_map.keys()))
quant_group.add_argument('--qe-bits-acts', '--qeba', type=int, default=8, metavar='NUM_BITS',
                         help='Number of bits for quantization of activations')
quant_group.add_argument('--qe-bits-wts', '--qebw', type=int, default=8, metavar='NUM_BITS',
                         help='Number of bits for quantization of weights')
quant_group.add_argument('--qe-bits-accum', type=int, default=32, metavar='NUM_BITS',
                         help='Number of bits for quantization of the accumulator')
quant_group.add_argument('--qe-clip-acts', '--qeca', action='store_true',
                         help='Enable clipping of activations using min/max values averaging over batch')
quant_group.add_argument('--qe-no-clip-layers', '--qencl', type=str, nargs='+', metavar='LAYER_NAME', default=[],
                         help='List of layer names for which not to clip activations. Applicable only if '
                              '--qe-clip-acts is also set')
quant_group.add_argument('--qe-per-channel', '--qepc', action='store_true',
                         help='Enable per-channel quantization of weights (per output channel)')
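
# Example invocation (hypothetical paths and layer names): evaluate with
# 8-bit post-training quantization, without clipping the final classifier --
#   python compress_classifier.py <path-to-imagenet> -a resnet18 --evaluate \
#       --quantize-eval --qe-mode asym_u --qe-clip-acts --qe-no-clip-layers fc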

distiller.knowledge_distillation.add_distillation_args(parser, ALL_MODEL_NAMES, True)
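
# Note: this commit moves this add_distillation_args() call from module scope
# into main(), where it is applied to the parser returned by parser.getParser()
# (see the new parsing flow in main() below).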


def check_pytorch_version():
    if torch.__version__ < '0.4.0':
        print("\nNOTICE:")
        print("The Distiller 'master' branch now requires at least PyTorch version 0.4.0 due to "
              "PyTorch API changes which are not backward-compatible.\n"
              "Please install PyTorch 0.4.0 or its derivative.\n"
              "If you are using a virtual environment, do not forget to update it:\n"
              "  1. Deactivate the old environment\n"
              "  2. Install the new environment\n"
              "  3. Activate the new environment")
        exit(1)


def create_activation_stats_collectors(model, collection_phase):
    """Create objects that collect activation statistics.
    This is a utility function that creates several collectors, including:
    1. Fine-grained sparsity levels of the activations
    2. L1-magnitude of each of the activation channels
    Args:
        model - the model on which we want to collect statistics
        collection_phase - the statistics collection phase: "train" (training),
            "valid" (validation), or "test" (testing)
    WARNING! Enabling activation statistics collection will significantly slow down training!
    """
    class missingdict(dict):
        """This is a little trick to prevent KeyError"""
        def __missing__(self, key):
            return None  # note, does *not* set self[key] - we don't want defaultdict's behavior

    distiller.utils.assign_layer_fq_names(model)

    activations_collectors = {"train": missingdict(), "valid": missingdict(), "test": missingdict()}
    if collection_phase is None:
        return activations_collectors
    collectors = missingdict({
        "sparsity": SummaryActivationStatsCollector(model, "sparsity",
                                                    lambda t: 100 * distiller.utils.sparsity(t)),
        "l1_channels": SummaryActivationStatsCollector(model, "l1_channels",
                                                       distiller.utils.activation_channels_l1),
        "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
                                                         distiller.utils.activation_channels_apoz),
        "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
    })
    activations_collectors[collection_phase] = collectors
    return activations_collectors


def save_collectors_data(collectors, directory):
    """Utility function that saves all activation statistics to Excel workbooks"""
    for name, collector in collectors.items():
        workbook = os.path.join(directory, name)
        msglogger.info("Generating {}".format(workbook))
        collector.to_xlsx(workbook)


def main():
    global msglogger
    check_pytorch_version()
    args = parser.parse_args()

    # Parse arguments
    prsr = parser.getParser()
    distiller.knowledge_distillation.add_distillation_args(prsr, ALL_MODEL_NAMES, True)
    args = prsr.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    msglogger = apputils.config_pylogger(os.path.join(script_dir, 'logging.conf'), args.name, args.output_dir)
@@ -369,7 +200,7 @@ def main():
    msglogger.info('Dataset sizes:\n\ttraining=%d\n\tvalidation=%d\n\ttest=%d',
                   len(train_loader.sampler), len(val_loader.sampler), len(test_loader.sampler))

    activations_collectors = create_activation_stats_collectors(model, collection_phase=args.activation_stats)
    activations_collectors = create_activation_stats_collectors(model, *args.activation_stats)

    if args.sensitivity is not None:
        sensitivities = np.arange(args.sensitivity_range[0], args.sensitivity_range[1], args.sensitivity_range[2])
@@ -784,7 +615,7 @@ def sensitivity_analysis(model, criterion, data_loader, loggers, args, sparsitie
    loggers = [loggers]
    test_fnc = partial(test, test_loader=data_loader, criterion=criterion,
                       loggers=loggers, args=args,
                       activations_collectors=create_activation_stats_collectors(model, None))
                       activations_collectors=create_activation_stats_collectors(model))
    which_params = [param_name for param_name, _ in model.named_parameters()]
    sensitivity = distiller.perform_sensitivity_analysis(model,
                                                         net_params=which_params,
@@ -823,8 +654,65 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):
    ADC.do_adc(model, args.dataset, args.arch, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)


def create_activation_stats_collectors(model, *phases):
    """Create objects that collect activation statistics.
    This is a utility function that creates several collectors, including:
    1. Fine-grained sparsity levels of the activations
    2. L1-magnitude of each of the activation channels
    Args:
        model - the model on which we want to collect statistics
        phases - the statistics collection phases: train, valid, and/or test
    WARNING! Enabling activation statistics collection will significantly slow down training!
    """
    class missingdict(dict):
        """This is a little trick to prevent KeyError"""
        def __missing__(self, key):
            return None  # note, does *not* set self[key] - we don't want defaultdict's behavior

    distiller.utils.assign_layer_fq_names(model)

    genCollectors = lambda: missingdict({
        "sparsity": SummaryActivationStatsCollector(model, "sparsity",
                                                    lambda t: 100 * distiller.utils.sparsity(t)),
        "l1_channels": SummaryActivationStatsCollector(model, "l1_channels",
                                                       distiller.utils.activation_channels_l1),
        "apoz_channels": SummaryActivationStatsCollector(model, "apoz_channels",
                                                         distiller.utils.activation_channels_apoz),
        "records": RecordsActivationStatsCollector(model, classes=[torch.nn.Conv2d])
    })

    return {k: (genCollectors() if k in phases else missingdict())
            for k in ('train', 'valid', 'test')}
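
# Usage sketch: collect statistics during training and validation, but not
# during testing --
#   collectors = create_activation_stats_collectors(model, 'train', 'valid')
#   collectors['train']['sparsity']   # a SummaryActivationStatsCollector
#   collectors['test']['sparsity']    # None (via missingdict), not a KeyError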


def save_collectors_data(collectors, directory):
    """Utility function that saves all activation statistics to Excel workbooks"""
    for name, collector in collectors.items():
        workbook = os.path.join(directory, name)
        msglogger.info("Generating {}".format(workbook))
        collector.to_xlsx(workbook)
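
# Example (hypothetical call; any writable directory works): write one .xlsx
# workbook per collector ('sparsity', 'l1_channels', 'apoz_channels',
# 'records') into the experiment's output directory --
#   save_collectors_data(collectors['valid'], args.output_dir)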


def check_pytorch_version():
    if torch.__version__ < '0.4.0':
        print("\nNOTICE:")
        print("The Distiller 'master' branch now requires at least PyTorch version 0.4.0 due to "
              "PyTorch API changes which are not backward-compatible.\n"
              "Please install PyTorch 0.4.0 or its derivative.\n"
              "If you are using a virtual environment, do not forget to update it:\n"
              "  1. Deactivate the old environment\n"
              "  2. Install the new environment\n"
              "  3. Activate the new environment")
        exit(1)
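
# Caveat (not part of this commit): lexicographic comparison of version
# strings is fragile -- e.g. '0.10.0' < '0.4.0' evaluates to True. A sturdier
# check could parse the versions first, for instance:
#   from pkg_resources import parse_version
#   if parse_version(torch.__version__) < parse_version('0.4.0'): ...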


if __name__ == '__main__':
    try:
        check_pytorch_version()
        main()
    except KeyboardInterrupt:
        print("\n-- KeyboardInterrupt --")