Filter ranking: add support for ranking by L2 magnitude
nzmora committed Feb 6, 2019
1 parent 6d7288a commit 2179ec5
Showing 2 changed files with 70 additions and 35 deletions.
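The commit parameterizes Distiller's ranked-structure pruners over the magnitude function used for ranking: `l1_magnitude` and `l2_magnitude` become module-level helpers built with `functools.partial(torch.norm, p=...)`, every `rank_and_prune_*` method gains a `magnitude_fn` argument defaulting to `l1_magnitude`, and thin `L1RankedStructureParameterPruner`/`L2RankedStructureParameterPruner` subclasses select the norm. A minimal sketch, not part of the commit, of why the choice of norm can change a ranking:

```python
# Standalone illustration of the ranking primitive this commit generalizes.
# The two helpers mirror the l1_magnitude/l2_magnitude definitions in the diff;
# the filter values are made up.
from functools import partial

import torch

l1_magnitude = partial(torch.norm, p=1)
l2_magnitude = partial(torch.norm, p=2)

# Two filters flattened to rows: L1 favors many moderate weights, while
# L2 favors a few large ones, so the two criteria can disagree.
filters = torch.tensor([[0.5, 0.5, 0.5, 0.5],    # L1 = 2.0, L2 = 1.0
                        [1.9, 0.0, 0.0, 0.0]])   # L1 = 1.9, L2 = 1.9
print(l1_magnitude(filters, dim=1))  # tensor([2.0000, 1.9000]) -> row 0 ranks higher
print(l2_magnitude(filters, dim=1))  # tensor([1.0000, 1.9000]) -> row 1 ranks higher
```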
95 changes: 63 additions & 32 deletions distiller/pruning/ranked_structures_pruner.py
@@ -78,17 +78,25 @@ def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, mod
         raise NotImplementedError


-class L1RankedStructureParameterPruner(RankedStructureParameterPruner):
+l1_magnitude = partial(torch.norm, p=1)
+l2_magnitude = partial(torch.norm, p=2)
+
+
+class LpRankedStructureParameterPruner(RankedStructureParameterPruner):
     """Uses mean Lp-norm (L1 or L2) to rank and prune structures.
     This class prunes to a prescribed percentage of structured-sparsity (level pruning).
     """
-    def __init__(self, name, group_type, desired_sparsity, weights, group_dependency=None, kwargs=None):
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None, magnitude_fn=None):
         super().__init__(name, group_type, desired_sparsity, weights, group_dependency)
         if group_type not in ['3D', 'Filters', 'Channels', 'Rows', 'Blocks']:
             raise ValueError("Structure {} was requested but "
                              "currently ranking of this shape is not supported".
                              format(group_type))
+        assert magnitude_fn is not None
+        self.magnitude_fn = magnitude_fn
+
         if group_type == 'Blocks':
             try:
                 self.block_shape = kwargs['block_shape']
@@ -101,18 +109,19 @@ def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, mod
         if self.group_type in ['3D', 'Filters']:
             group_pruning_fn = self.rank_and_prune_filters
         elif self.group_type == 'Channels':
-            group_pruning_fn = self.rank_and_prune_channels
+            group_pruning_fn = partial(self.rank_and_prune_channels)
         elif self.group_type == 'Rows':
             group_pruning_fn = self.rank_and_prune_rows
         elif self.group_type == 'Blocks':
             group_pruning_fn = partial(self.rank_and_prune_blocks, block_shape=self.block_shape)

-        binary_map = group_pruning_fn(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = group_pruning_fn(fraction_to_prune, param, param_name,
+                                      zeros_mask_dict, model, binary_map, self.magnitude_fn)
         return binary_map

     @staticmethod
     def rank_and_prune_channels(fraction_to_prune, param, param_name=None,
-                                zeros_mask_dict=None, model=None, binary_map=None):
+                                zeros_mask_dict=None, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         def rank_channels(fraction_to_prune, param):
             num_filters = param.size(0)
             num_channels = param.size(1)
@@ -122,9 +131,9 @@ def rank_channels(fraction_to_prune, param):
             # tensor, is now a row in the 2D tensor.
             view_2d = param.view(-1, kernel_size)
             # Next, compute the magnitudes of each kernel
-            kernel_sums = view_2d.abs().sum(dim=1)
+            kernel_mags = magnitude_fn(view_2d, dim=1)
             # Now group by channels
-            k_sums_mat = kernel_sums.view(num_filters, num_channels).t()
+            k_sums_mat = kernel_mags.view(num_filters, num_channels).t()
             channel_mags = k_sums_mat.mean(dim=1)
             k = int(fraction_to_prune * channel_mags.size(0))
             if k == 0:
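For orientation, a self-contained re-run of the channel-ranking arithmetic above (my sketch, not the commit's code), using the L2 helper as `magnitude_fn`:

```python
# Per-kernel magnitudes are grouped by input channel and averaged, yielding
# one score per channel; shapes follow the rank_channels code above.
import torch

num_filters, num_channels, kernel_size = 2, 3, 4
param = torch.randn(num_filters, num_channels, 2, 2)

view_2d = param.view(-1, kernel_size)          # one kernel per row
kernel_mags = view_2d.norm(p=2, dim=1)         # magnitude_fn == l2_magnitude
k_mags_mat = kernel_mags.view(num_filters, num_channels).t()
channel_mags = k_mags_mat.mean(dim=1)          # one mean magnitude per channel
assert channel_mags.shape == (num_channels,)
```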
@@ -160,14 +169,14 @@ def binary_map_to_mask(binary_map, param):

     @staticmethod
     def rank_and_prune_filters(fraction_to_prune, param, param_name,
-                               zeros_mask_dict, model=None, binary_map=None):
+                               zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"

         threshold = None
         if binary_map is None:
             # First we rank the filters
             view_filters = param.view(param.size(0), -1)
-            filter_mags = view_filters.data.abs().mean(dim=1)
+            filter_mags = magnitude_fn(view_filters, dim=1)
             topk_filters = int(fraction_to_prune * filter_mags.size(0))
             if topk_filters == 0:
                 msglogger.info("Too few filters - can't prune %.1f%% filters", 100*fraction_to_prune)
@@ -178,7 +187,8 @@ def rank_and_prune_filters(fraction_to_prune, param, param_name,
                            param_name,
                            topk_filters, filter_mags.size(0))
         # Then we threshold
-        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, 'Mean_Abs', binary_map)
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        mask, binary_map = distiller.group_threshold_mask(param, 'Filters', threshold, threshold_type, binary_map)
         if zeros_mask_dict is not None:
             zeros_mask_dict[param_name].mask = mask
         msglogger.info("LpRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f",
@@ -189,7 +199,7 @@ def rank_and_prune_filters(fraction_to_prune, param, param_name,

     @staticmethod
     def rank_and_prune_rows(fraction_to_prune, param, param_name,
-                            zeros_mask_dict, model=None, binary_map=None):
+                            zeros_mask_dict, model=None, binary_map=None, magnitude_fn=l1_magnitude):
        """Prune the rows of a matrix, based on ranked Lp-norms of the matrix rows.

        PyTorch stores the weights matrices in a transposed format. I.e. before performing GEMM, a matrix is
@@ -203,21 +213,23 @@ def rank_and_prune_rows(fraction_to_prune, param, param_name,
         assert param.dim() == 2, "This thresholding is only supported for 2D weights"
         ROWS_DIM = 0
         THRESHOLD_DIM = 'Cols'
-        rows_mags = param.abs().mean(dim=ROWS_DIM)
+        rows_mags = magnitude_fn(param, dim=ROWS_DIM)
         num_rows_to_prune = int(fraction_to_prune * rows_mags.size(0))
         if num_rows_to_prune == 0:
             msglogger.info("Too few rows - can't prune %.1f%% rows", 100*fraction_to_prune)
             return
         bottomk_rows, _ = torch.topk(rows_mags, num_rows_to_prune, largest=False, sorted=True)
         threshold = bottomk_rows[-1]
-        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM, threshold, 'Mean_Abs')
+        threshold_type = 'L1' if magnitude_fn == l1_magnitude else 'L2'
+        zeros_mask_dict[param_name].mask = distiller.group_threshold_mask(param, THRESHOLD_DIM,
+                                                                          threshold, threshold_type)
         msglogger.info("LpRankedStructureParameterPruner - param: %s pruned=%.3f goal=%.3f (%d/%d)", param_name,
                        distiller.sparsity(zeros_mask_dict[param_name].mask),
                        fraction_to_prune, num_rows_to_prune, rows_mags.size(0))

     @staticmethod
-    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None,
-                              zeros_mask_dict=None, model=None, binary_map=None, block_shape=None):
+    def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_dict=None,
+                              model=None, binary_map=None, block_shape=None, magnitude_fn=l1_magnitude):
         """Block-wise pruning for 4D tensors.

         The block shape is specified using a tuple: [block_repetitions, block_depth, block_height, block_width].
@@ -251,26 +263,20 @@ def rank_and_prune_blocks(fraction_to_prune, param, param_name=None, zeros_mask_
             kernel_size = param.size(2) * param.size(3)

         if block_depth > 1:
-            view_dims = (
-                num_filters*num_channels//(block_repetitions*block_depth),
-                block_repetitions*block_depth,
-                kernel_size,
-            )
+            view_dims = (num_filters*num_channels//(block_repetitions*block_depth),
+                         block_repetitions*block_depth,
+                         kernel_size,)
         else:
-            view_dims = (
-                num_filters // block_repetitions,
-                block_repetitions,
-                -1,
-            )
+            view_dims = (num_filters // block_repetitions,
+                         block_repetitions,
+                         -1,)

         def rank_blocks(fraction_to_prune, param):
             # Create a view where each block is a column
             view1 = param.view(*view_dims)
-            # Next, compute the sums of each column (block)
-            block_sums = view1.abs().sum(dim=1)
-
-            # Now group by channels
-            block_mags = block_sums.view(-1)  # flatten
+            block_mags = magnitude_fn(view1, dim=1)
+            block_mags = block_mags.view(-1)  # flatten
             k = int(fraction_to_prune * block_mags.size(0))
             if k == 0:
                 msglogger.info("Too few blocks (%d) - can't prune %.1f%% blocks",
@@ -302,6 +308,28 @@ def binary_map_to_mask(binary_map, param):
             return binary_map


+class L1RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L1-norm to rank and prune structures.
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l1_magnitude)
+
+
+class L2RankedStructureParameterPruner(LpRankedStructureParameterPruner):
+    """Uses mean L2-norm to rank and prune structures.
+    This class prunes to a prescribed percentage of structured-sparsity (level pruning).
+    """
+    def __init__(self, name, group_type, desired_sparsity, weights,
+                 group_dependency=None, kwargs=None):
+        super().__init__(name, group_type, desired_sparsity, weights,
+                         group_dependency, kwargs, magnitude_fn=l2_magnitude)
+
+
 def mask_from_filter_order(filters_ordered_by_criterion, param, num_filters, binary_map):
     if binary_map is None:
         binary_map = torch.zeros(num_filters).cuda()
@@ -324,7 +352,8 @@ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map

     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -370,7 +399,8 @@ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map

     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
@@ -402,7 +432,8 @@ def __init__(self, name, group_type, desired_sparsity, weights, group_dependency
     def prune_group(self, fraction_to_prune, param, param_name, zeros_mask_dict, model=None, binary_map=None):
         if fraction_to_prune == 0:
             return
-        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map)
+        binary_map = self.rank_and_prune_filters(fraction_to_prune, param, param_name,
+                                                 zeros_mask_dict, model, binary_map)
         return binary_map

     def rank_and_prune_filters(self, fraction_to_prune, param, param_name, zeros_mask_dict, model, binary_map=None):
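Taken together, the changes in this file let callers pick the ranking norm per pruner. A hedged construction sketch, assuming only what this diff shows (the module path is the file above; the pruner and layer names are hypothetical placeholders):

```python
# Build an L2-ranked filter pruner directly from the module in this diff.
# 'conv1_pruner' and 'module.conv1.weight' are illustrative names only.
from distiller.pruning.ranked_structures_pruner import (
    L2RankedStructureParameterPruner)

pruner = L2RankedStructureParameterPruner(name='conv1_pruner',
                                          group_type='Filters',
                                          desired_sparsity=0.5,
                                          weights=['module.conv1.weight'])
# prune_group() then forwards self.magnitude_fn (l2_magnitude here) into
# rank_and_prune_filters(), as the diff above shows.
```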
10 changes: 7 additions & 3 deletions distiller/thresholding.py
@@ -151,7 +151,7 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
         if binary_map is None:
             binary_map = group_threshold_binary_map(param, group_type, threshold, threshold_criteria)
         a = binary_map.expand(param.size(1) * param.size(2) * param.size(3), param.size(0)).t()
-        return a.view(param.size(0), param.size(1), param.size(2), param.size(3)), binary_map
+        return a.view(*param.shape), binary_map

     elif group_type == '4D':
         assert param.dim() == 4, "This thresholding is only supported for 4D weights"
@@ -181,10 +181,14 @@ def group_threshold_mask(param, group_type, threshold, threshold_criteria, binar
 def threshold_policy(weights, thresholds, threshold_criteria, dim=1):
     """
     """
-    if threshold_criteria == 'Mean_Abs':
-        return weights.data.abs().mean(dim=dim).gt(thresholds).type(weights.type())
+    if threshold_criteria in ['Mean_Abs', 'Mean_L1']:
+        return weights.data.norm(p=1, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
+    if threshold_criteria == 'Mean_L2':
+        return weights.data.norm(p=2, dim=dim).div(weights.size(dim)).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'L1':
         return weights.data.norm(p=1, dim=dim).gt(thresholds).type(weights.type())
+    elif threshold_criteria == 'L2':
+        return weights.data.norm(p=2, dim=dim).gt(thresholds).type(weights.type())
     elif threshold_criteria == 'Max':
         maxv, _ = weights.data.abs().max(dim=dim)
         return maxv.gt(thresholds).type(weights.type())
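On the thresholding side, 'Mean_L1' aliases the old 'Mean_Abs' (the mean of absolute values is the L1 norm divided by the group size), and 'L1'/'L2' compare the raw group norm against the threshold. A small sketch with invented values showing how the 'L1' and 'Mean_L1' branches above behave:

```python
# Mirrors the threshold_policy branches added above, outside Distiller.
import torch

w = torch.tensor([[1.0, -2.0,  3.0],
                  [0.1,  0.2, -0.1]])
threshold = 2.0

# 'L1': raw row norms [6.0, 0.4] against the threshold
print(w.norm(p=1, dim=1).gt(threshold))                 # tensor([ True, False])
# 'Mean_L1' (== old 'Mean_Abs'): normalized norms [2.0, 0.1333]
print(w.norm(p=1, dim=1).div(w.size(1)).gt(threshold))  # tensor([False, False])
```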
