From ac9f61c06ac5299a7a00baf582d362570279fe15 Mon Sep 17 00:00:00 2001 From: Neta Zmora Date: Thu, 14 Feb 2019 17:06:22 +0200 Subject: [PATCH] Fix automated-compression imports To use automated compression you need to install several optional packages which are not required for other use-cases. This fix hides the import requirements for users who do not want to install the extra packages. --- examples/automated_deep_compression/ADC.py | 9 +++++++-- examples/automated_deep_compression/__init__.py | 5 ++++- examples/classifier_compression/compress_classifier.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/automated_deep_compression/ADC.py b/examples/automated_deep_compression/ADC.py index a04781150..590acf4c6 100755 --- a/examples/automated_deep_compression/ADC.py +++ b/examples/automated_deep_compression/ADC.py @@ -65,7 +65,12 @@ import numpy as np import torch import csv -import gym +try: + import gym +except ImportError as e: + print("WARNING: to use automated compression you will need to install extra packages") + print("See instructions in the header of examples/automated_deep_compression/ADC.py") + raise e from gym import spaces import distiller from apputils import SummaryGraph @@ -183,7 +188,7 @@ def amc_reward_fn(env, top1, top5, vloss, total_macs): experimental_reward_fn = harmonic_mean_reward_fn -def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): +def do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): dataset = args.dataset arch = args.arch perform_thinning = True # args.amc_thinning diff --git a/examples/automated_deep_compression/__init__.py b/examples/automated_deep_compression/__init__.py index a05f8a23b..5b2e97cc9 100755 --- a/examples/automated_deep_compression/__init__.py +++ b/examples/automated_deep_compression/__init__.py @@ -1,2 +1,5 @@ from .automl_args import add_automl_args -from .ADC import do_adc + +def do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn): + from .ADC import do_adc_internal + do_adc_internal(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) diff --git a/examples/classifier_compression/compress_classifier.py b/examples/classifier_compression/compress_classifier.py index 6f18978bc..b3d951311 100755 --- a/examples/classifier_compression/compress_classifier.py +++ b/examples/classifier_compression/compress_classifier.py @@ -657,7 +657,7 @@ def automated_deep_compression(model, criterion, optimizer, loggers, args): save_checkpoint_fn = partial(apputils.save_checkpoint, arch=args.arch, dir=msglogger.logdir) optimizer_data = {'lr': args.lr, 'momentum': args.momentum, 'weight_decay': args.weight_decay} - adc.ADC.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) + adc.do_adc(model, args, optimizer_data, validate_fn, save_checkpoint_fn, train_fn) def greedy(model, criterion, optimizer, loggers, args):