Skip to content

Commit

Permalink
Fix automated-compression imports
Browse files Browse the repository at this point in the history
To use automated compression you need to install several optional packages
which are not required for other use-cases.
This fix hides the import requirements for users who do not want to install
the extra packages.
  • Loading branch information
nzmora committed Feb 14, 2019
1 parent d7b5a50 commit ac9f61c
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
9 changes: 7 additions & 2 deletions examples/automated_deep_compression/ADC.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,12 @@
import numpy as np
import torch
import csv
import gym
try:
import gym
except ImportError as e:
print("WARNING: to use automated compression you will need to install extra packages")
print("See instructions in the header of examples/automated_deep_compression/ADC.py")
raise e
from gym import spaces
import distiller
from apputils import SummaryGraph
Expand Down Expand Up @@ -183,7 +188,7 @@ def amc_reward_fn(env, top1, top5, vloss, total_macs):
experimental_reward_fn = harmonic_mean_reward_fn


def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
dataset = args.dataset
arch = args.arch
perform_thinning = True # args.amc_thinning
Expand Down
5 changes: 4 additions & 1 deletion examples/automated_deep_compression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .automl_args import add_automl_args
from .ADC import do_adc

def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn):
from .ADC import do_adc_internal
do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
2 changes: 1 addition & 1 deletion examples/classifier_compression/compress_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args):

save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir)
optimizer_data = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay}
adc.ADC.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)
adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn)


def greedy(model, criterion, optimizer, loggers, args):
Expand Down

0 comments on commit ac9f61c

Please sign in to comment.