Skip to content

Commit

Permalink
CPU support: fix the case of loading a thinned GPU-model on the CPU
Browse files Browse the repository at this point in the history
This commit fixes (and adds a test for) the case where we wish to load
a thinned GPU checkpoint onto the CPU.
  • Loading branch information
nzmora committed Feb 12, 2019
1 parent 0ae0754 commit ba05f6c
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
4 changes: 2 additions & 2 deletions apputils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
raise IOError(ENOENT, 'Could not find a checkpoint file at', chkpt_file)

msglogger.info("=> loading checkpoint %s", chkpt_file)
checkpoint = torch.load(chkpt_file, map_location = lambda storage, loc: storage)
checkpoint = torch.load(chkpt_file, map_location=lambda storage, loc: storage)
msglogger.debug("\n\t".join(['Checkpoint keys:'] + list(checkpoint)))

if 'state_dict' not in checkpoint:
Expand Down Expand Up @@ -121,7 +121,7 @@ def load_checkpoint(model, chkpt_file, optimizer=None):
# Cache the recipes in case we need them later
model.thinning_recipes = checkpoint['thinning_recipes']
if normalize_dataparallel_keys:
model.thinning_recipes = {normalize_module_name(k): v for k, v in model.thinning_recipes.items()}
model.thinning_recipes = [distiller.get_normalized_recipe(recipe) for recipe in model.thinning_recipes]
distiller.execute_thinning_recipes_list(model,
compression_scheduler.zeros_mask_dict,
model.thinning_recipes)
Expand Down
8 changes: 7 additions & 1 deletion distiller/thinning.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
'ChannelRemover', 'remove_channels',
'FilterRemover', 'remove_filters',
'find_nonzero_channels', 'find_nonzero_channels_list',
'execute_thinning_recipes_list']
'execute_thinning_recipes_list', 'get_normalized_recipe']


def create_graph(dataset, arch):
Expand All @@ -77,6 +77,12 @@ def create_graph(dataset, arch):
return SummaryGraph(model, dummy_input)


def get_normalized_recipe(recipe):
    """Return a copy of *recipe* with normalized module/parameter names.

    Each key in the recipe's ``modules`` and ``parameters`` dictionaries is
    passed through ``normalize_module_name`` (which strips DataParallel-style
    name prefixes); the associated directives are kept unchanged.
    """
    normalized_modules = {}
    for name, directive in recipe.modules.items():
        normalized_modules[normalize_module_name(name)] = directive

    normalized_params = {}
    for name, directive in recipe.parameters.items():
        normalized_params[normalize_module_name(name)] = directive

    return ThinningRecipe(modules=normalized_modules, parameters=normalized_params)


def param_name_2_layer_name(param_name):
    """Convert a parameter name to the name of the layer that owns it.

    E.g. ``'module.layer1.0.conv1.weight'`` --> ``'module.layer1.0.conv1'``.

    The original code sliced off ``len('weights')`` characters, which only
    works by coincidence because ``len('weights') == len('.weight') == 7``.
    Slice off the suffix that is actually removed, to make intent explicit
    (behavior is unchanged).
    """
    return param_name[:-len('.weight')]

Expand Down
37 changes: 35 additions & 2 deletions tests/test_infra.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
sys.path.append(module_path)

import distiller
from apputils import load_checkpoint
from apputils import save_checkpoint, load_checkpoint
from models import create_model


Expand All @@ -39,6 +39,7 @@ def test_load():
assert compression_scheduler is not None
assert start_epoch == 180


def test_load_state_dict():
# prepare lean checkpoint
state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
Expand All @@ -52,6 +53,7 @@ def test_load_state_dict():
assert compression_scheduler is None
assert start_epoch == 0


def test_load_dumb_checkpoint():
# prepare lean checkpoint
state_dict_arrays = torch.load('../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar').get('state_dict')
Expand All @@ -62,19 +64,50 @@ def test_load_dumb_checkpoint():
with pytest.raises(ValueError):
model, compression_scheduler, start_epoch = load_checkpoint(model, tmpfile.name)


def test_load_negative():
    # Loading from a path that does not exist must surface a FileNotFoundError.
    with pytest.raises(FileNotFoundError):
        net = create_model(False, 'cifar10', 'resnet20_cifar')
        net, scheduler, epoch = load_checkpoint(net, 'THIS_IS_AN_ERROR/checkpoint_trained_dense.pth.tar')


def test_load_gpu_model_on_cpu():
    # Issue #148: a checkpoint saved from a GPU model must be loadable onto the CPU.
    # NOTE: the rendered diff contained both the old create_model() call (literal
    # device_ids=-1) and its replacement; only the intended final version is kept,
    # so the model is created exactly once.
    CPU_DEVICE_ID = -1
    model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    model, compression_scheduler, start_epoch = load_checkpoint(model,
        '../examples/ssl/checkpoints/checkpoint_trained_dense.pth.tar')
    assert compression_scheduler is not None
    assert start_epoch == 180
    assert distiller.model_device(model) == 'cpu'


def test_load_gpu_model_on_cpu_with_thinning():
    # Issue #148
    # 1. create a GPU model and remove 50% of the filters in one of the layers (thinning)
    # 2. save the thinned model in a checkpoint file
    # 3. load the checkpoint and place it on the CPU
    CPU_DEVICE_ID = -1
    gpu_model = create_model(False, 'cifar10', 'resnet20_cifar')
    conv_pname = "module.layer1.0.conv1.weight"
    conv_p = distiller.model_find_param(gpu_model, conv_pname)
    pruner = distiller.pruning.L1RankedStructureParameterPruner("test_pruner", group_type="Filters",
                                                                desired_sparsity=0.5, weights=conv_pname)
    zeros_mask_dict = distiller.create_model_masks_dict(gpu_model)
    pruner.set_param_mask(conv_p, conv_pname, zeros_mask_dict, meta=None)

    # Use the mask to prune, then physically remove the pruned filters (thinning)
    zeros_mask_dict[conv_pname].apply_mask(conv_p)
    distiller.remove_filters(gpu_model, zeros_mask_dict, 'resnet20_cifar', 'cifar10', optimizer=None)
    assert hasattr(gpu_model, 'thinning_recipes')
    scheduler = distiller.CompressionScheduler(gpu_model)
    save_checkpoint(epoch=0, arch='resnet20_cifar', model=gpu_model, scheduler=scheduler, optimizer=None)

    # Load the thinned checkpoint and place the model on the CPU.
    # (CPU_DEVICE_ID was assigned twice in the original; the duplicate is removed.)
    cpu_model = create_model(False, 'cifar10', 'resnet20_cifar', device_ids=CPU_DEVICE_ID)
    load_checkpoint(cpu_model, "checkpoint.pth.tar")
    assert distiller.model_device(cpu_model) == 'cpu'


# Allow running this file directly (outside pytest) for quick debugging;
# exercises only the GPU-checkpoint-on-CPU loading test.
if __name__ == '__main__':
    test_load_gpu_model_on_cpu()

0 comments on commit ba05f6c

Please sign in to comment.