Skip to content

Commit

Permalink
compress_classifier: add command-line option to "thinnify" a model
Browse files Browse the repository at this point in the history
  • Loading branch information
nzmora committed Jan 31, 2019
1 parent d81927d commit 2650f8f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
13 changes: 10 additions & 3 deletions examples/classifier_compression/compress_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
"""

import math
import argparse
import time
import os
import sys
Expand Down Expand Up @@ -171,8 +170,7 @@ def main():

# We can optionally resume from a checkpoint
if args.resume:
model, compression_scheduler, start_epoch = apputils.load_checkpoint(
model, chkpt_file=args.resume)
model, compression_scheduler, start_epoch = apputils.load_checkpoint(model, chkpt_file=args.resume)
model.to(args.device)

# Define loss function (criterion) and optimizer
Expand Down Expand Up @@ -218,6 +216,15 @@ def main():
elif compression_scheduler is None:
compression_scheduler = distiller.CompressionScheduler(model)

if args.thinnify:
#zeros_mask_dict = distiller.create_model_masks_dict(model)
assert args.resume is not None, "You must use --resume to provide a checkpoint file to thinnify"
distiller.remove_filters(model, compression_scheduler.zeros_mask_dict, args.arch, args.dataset, optimizer=None)
apputils.save_checkpoint(0, args.arch, model, optimizer=None, scheduler=compression_scheduler,
name="{}_thinned".format(args.resume.replace(".pth.tar", "")), dir=msglogger.logdir)
print("Note: your model may have collapsed to random inference, so you may want to fine-tune")
return

args.kd_policy = None
if args.kd_teacher:
teacher = create_model(args.kd_pretrained, args.dataset, args.kd_teacher, device_ids=args.gpus)
Expand Down
3 changes: 2 additions & 1 deletion examples/classifier_compression/parser.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def getParser():
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--activation-stats', '--act-stats', nargs='+', metavar='PHASE', default=list(),
# choices=["train", "valid", "test"]
help='collect activation statistics on phases: train, valid, and/or test'
' (WARNING: this slows down training)')
parser.add_argument('--masks-sparsity', dest='masks_sparsity', action='store_true', default=False,
Expand Down Expand Up @@ -94,6 +93,8 @@ def getParser():
help='number of best scores to track and report (default: 1)')
parser.add_argument('--load-serialized', dest='load_serialized', action='store_true', default=False,
help='Load a model without DataParallel wrapping it')
parser.add_argument('--thinnify', dest='thinnify', action='store_true', default=False,
help='physically remove zero-filters and create a smaller model')

str_to_quant_mode_map = {
'sym': distiller.quantization.LinearQuantMode.SYMMETRIC,
Expand Down

0 comments on commit 2650f8f

Please sign in to comment.