
Commit

Merge pull request #900 from f0k/grouped-conv
Support grouped convolutions
f0k committed Feb 28, 2018
2 parents 8978b1d + ae33f85 commit 7b4d1e2
Showing 6 changed files with 136 additions and 30 deletions.
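For context, the `num_groups` argument added here is exposed on all of the convolution layers changed below. A minimal usage sketch follows (my own illustration, not part of the commit; it assumes this branch is installed and Theano 0.10 or later, and the layer sizes are made up):

```python
import numpy as np
import theano
import lasagne

# A 2D convolution with 3 channel groups: the 6 input channels and 9 filters
# are split into 3 independent groups of 2 input / 3 output channels each.
l_in = lasagne.layers.InputLayer((None, 6, 32, 32))
l_conv = lasagne.layers.Conv2DLayer(l_in, num_filters=9, filter_size=(3, 3),
                                    num_groups=3)

# With groups, each filter only spans 6 // 3 = 2 input channels:
print(l_conv.W.get_value().shape)   # (9, 2, 3, 3)

x = np.random.rand(4, 6, 32, 32).astype(theano.config.floatX)
y = lasagne.layers.get_output(l_conv).eval({l_in.input_var: x})
print(y.shape)                       # (4, 9, 30, 30)
```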
63 changes: 56 additions & 7 deletions lasagne/layers/conv.py
@@ -2,7 +2,7 @@

from .. import init
from .. import nonlinearities
from ..utils import as_tuple
from ..utils import as_tuple, inspect_kwargs
from ..theano_extensions import conv

from .base import Layer
@@ -246,6 +246,12 @@ class BaseConvLayer(Layer):
Lasagne, flipping incurs an overhead and is disabled by default --
check the documentation when using learned weights from another layer.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
n : int or None
The dimensionality of the convolution (i.e., the number of spatial
dimensions of each feature map and each convolutional filter). If
@@ -266,7 +272,7 @@ def __init__(self, incoming, num_filters, filter_size, stride=1, pad=0,
untie_biases=False,
W=init.GlorotUniform(), b=init.Constant(0.),
nonlinearity=nonlinearities.rectify, flip_filters=True,
n=None, **kwargs):
num_groups=1, n=None, **kwargs):
super(BaseConvLayer, self).__init__(incoming, **kwargs)
if nonlinearity is None:
self.nonlinearity = nonlinearities.identity
@@ -298,6 +304,19 @@ def __init__(self, incoming, num_filters, filter_size, stride=1, pad=0,
else:
self.pad = as_tuple(pad, n, int)

if (num_groups <= 0 or
self.num_filters % num_groups != 0 or
self.input_shape[1] % num_groups != 0):
raise ValueError(
"num_groups (here: %d) must be positive and evenly divide the "
"number of input and output channels (here: %d and %d)" %
(num_groups, self.input_shape[1], self.num_filters))
elif (num_groups > 1 and
"num_groups" not in inspect_kwargs(T.nnet.conv2d)):
raise RuntimeError("num_groups > 1 requires "
"Theano 0.10 or later") # pragma: no cover
self.num_groups = num_groups

self.W = self.add_param(W, self.get_W_shape(), name="W")
if b is None:
self.b = None
@@ -317,7 +336,7 @@ def get_W_shape(self):
tuple of int
The shape of the weight matrix.
"""
num_input_channels = self.input_shape[1]
num_input_channels = self.input_shape[1] // self.num_groups
return (self.num_filters, num_input_channels) + self.filter_size

def get_output_shape_for(self, input_shape):
@@ -448,13 +467,19 @@ class Conv1DLayer(BaseConvLayer):
Lasagne, flipping incurs an overhead and is disabled by default --
check the documentation when using learned weights from another layer.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
convolution : callable
The convolution implementation to use. The
`lasagne.theano_extensions.conv` module provides some alternative
implementations for 1D convolutions, because the Theano API only
features a 2D convolution implementation. Usually it should be fine
to leave this at the default value. Note that not all implementations
support all settings for `pad` and `subsample`.
support all settings for `pad`, `subsample` and `num_groups`.
**kwargs
Any additional keyword arguments are passed to the `Layer` superclass.
@@ -480,11 +505,15 @@ def __init__(self, incoming, num_filters, filter_size, stride=1,

def convolve(self, input, **kwargs):
border_mode = 'half' if self.pad == 'same' else self.pad
extra_kwargs = {}
if self.num_groups > 1: # pragma: no cover
extra_kwargs['num_groups'] = self.num_groups
conved = self.convolution(input, self.W,
self.input_shape, self.get_W_shape(),
subsample=self.stride,
border_mode=border_mode,
filter_flip=self.flip_filters)
filter_flip=self.flip_filters,
**extra_kwargs)
return conved


@@ -576,6 +605,12 @@ class Conv2DLayer(BaseConvLayer):
Lasagne, flipping incurs an overhead and is disabled by default --
check the documentation when using learned weights from another layer.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
convolution : callable
The convolution implementation to use. Usually it should be fine to
leave this at the default value.
@@ -604,11 +639,15 @@ def __init__(self, incoming, num_filters, filter_size, stride=(1, 1),

def convolve(self, input, **kwargs):
border_mode = 'half' if self.pad == 'same' else self.pad
extra_kwargs = {}
if self.num_groups > 1: # pragma: no cover
extra_kwargs['num_groups'] = self.num_groups
conved = self.convolution(input, self.W,
self.input_shape, self.get_W_shape(),
subsample=self.stride,
border_mode=border_mode,
filter_flip=self.flip_filters)
filter_flip=self.flip_filters,
**extra_kwargs)
return conved


@@ -699,6 +738,12 @@ class Conv3DLayer(BaseConvLayer):  # pragma: no cover
Lasagne, flipping incurs an overhead and is disabled by default --
check the documentation when using learned weights from another layer.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
convolution : callable
The convolution implementation to use. Usually it should be fine to
leave this at the default value.
@@ -729,11 +774,15 @@ def __init__(self, incoming, num_filters, filter_size, stride=(1, 1, 1),

def convolve(self, input, **kwargs):
border_mode = 'half' if self.pad == 'same' else self.pad
extra_kwargs = {}
if self.num_groups > 1: # pragma: no cover
extra_kwargs['num_groups'] = self.num_groups
conved = self.convolution(input, self.W,
self.input_shape, self.get_W_shape(),
subsample=self.stride,
border_mode=border_mode,
filter_flip=self.flip_filters)
filter_flip=self.flip_filters,
**extra_kwargs)
return conved


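The Theano requirement above is enforced by feature detection rather than a version check: `inspect_kwargs` (imported from `lasagne.utils` in the first hunk) lists the keyword arguments of `theano.tensor.nnet.conv2d`, and grouping is only permitted when `num_groups` is among them. Below is a rough standalone equivalent of that test, my own sketch rather than the helper's actual implementation:

```python
import inspect

import theano.tensor as T


def accepts_kwarg(func, name):
    """Return True if `func` accepts a keyword argument called `name`."""
    try:
        params = inspect.signature(func).parameters   # Python 3
    except AttributeError:
        params = inspect.getargspec(func).args        # Python 2 fallback
    return name in params


# Grouped convolutions need a conv2d that understands num_groups,
# which is present in Theano 0.10 and later:
if not accepts_kwarg(T.nnet.conv2d, 'num_groups'):
    raise RuntimeError("num_groups > 1 requires Theano 0.10 or later")
```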
16 changes: 12 additions & 4 deletions lasagne/layers/corrmm.py
@@ -136,6 +136,12 @@ class Conv2DMMLayer(BaseConvLayer):
be set to ``True`` if weights are loaded into it that were learnt using
a regular :class:`lasagne.layers.Conv2DLayer`, for example.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
**kwargs
Any additional keyword arguments are passed to the `Layer` superclass.
@@ -150,14 +156,16 @@ class Conv2DMMLayer(BaseConvLayer):
def __init__(self, incoming, num_filters, filter_size, stride=(1, 1),
pad=0, untie_biases=False, W=init.GlorotUniform(),
b=init.Constant(0.), nonlinearity=nonlinearities.rectify,
flip_filters=False, **kwargs):
flip_filters=False, num_groups=1, **kwargs):
super(Conv2DMMLayer, self).__init__(incoming, num_filters, filter_size,
stride, pad, untie_biases, W, b,
nonlinearity, flip_filters, n=2,
**kwargs)
nonlinearity, flip_filters,
num_groups, n=2, **kwargs)
border_mode = 'half' if self.pad == 'same' else self.pad
extra_kwargs = {'num_groups': num_groups} if num_groups > 1 else {}
self.corr_mm_op = GpuCorrMM(subsample=self.stride,
border_mode=border_mode)
border_mode=border_mode,
**extra_kwargs)

def convolve(self, input, **kwargs):
filters = self.W
36 changes: 28 additions & 8 deletions lasagne/layers/dnn.py
@@ -375,6 +375,12 @@ class Conv2DDNNLayer(BaseConvLayer):
be set to ``True`` if weights are loaded into it that were learnt using
a regular :class:`lasagne.layers.Conv2DLayer`, for example.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
**kwargs
Any additional keyword arguments are passed to the `Layer` superclass.
@@ -389,25 +395,29 @@ class Conv2DDNNLayer(BaseConvLayer):
def __init__(self, incoming, num_filters, filter_size, stride=(1, 1),
pad=0, untie_biases=False, W=init.GlorotUniform(),
b=init.Constant(0.), nonlinearity=nonlinearities.rectify,
flip_filters=False, **kwargs):
flip_filters=False, num_groups=1, **kwargs):
super(Conv2DDNNLayer, self).__init__(incoming, num_filters,
filter_size, stride, pad,
untie_biases, W, b, nonlinearity,
flip_filters, n=2, **kwargs)
flip_filters, num_groups, n=2,
**kwargs)

def convolve(self, input, **kwargs):
# by default we assume 'cross', consistent with corrmm.
conv_mode = 'conv' if self.flip_filters else 'cross'
border_mode = self.pad
if border_mode == 'same':
border_mode = tuple(s // 2 for s in self.filter_size)
extra_kwargs = {}
if self.num_groups > 1: # pragma: no cover
extra_kwargs = {'num_groups': self.num_groups}

conved = dnn.dnn_conv(img=input,
kerns=self.W,
subsample=self.stride,
border_mode=border_mode,
conv_mode=conv_mode
)
conv_mode=conv_mode,
**extra_kwargs)
return conved


@@ -500,6 +510,12 @@ class Conv3DDNNLayer(BaseConvLayer):
anyway because the filters are learned, but if you want to compute
predictions with pre-trained weights, take care if they need flipping.
num_groups : int (default: 1)
The number of groups to split the input channels and output channels
into, such that data does not cross the group boundaries. Requires the
number of channels to be divisible by the number of groups, and
requires Theano 0.10 or later for more than one group.
**kwargs
Any additional keyword arguments are passed to the `Layer` superclass.
@@ -514,25 +530,29 @@ class Conv3DDNNLayer(BaseConvLayer):
def __init__(self, incoming, num_filters, filter_size, stride=(1, 1, 1),
pad=0, untie_biases=False, W=init.GlorotUniform(),
b=init.Constant(0.), nonlinearity=nonlinearities.rectify,
flip_filters=False, **kwargs):
flip_filters=False, num_groups=1, **kwargs):
super(Conv3DDNNLayer, self).__init__(incoming, num_filters,
filter_size, stride, pad,
untie_biases, W, b, nonlinearity,
flip_filters, n=3, **kwargs)
flip_filters, num_groups, n=3,
**kwargs)

def convolve(self, input, **kwargs):
# by default we assume 'cross', consistent with corrmm.
conv_mode = 'conv' if self.flip_filters else 'cross'
border_mode = self.pad
if border_mode == 'same':
border_mode = tuple(s // 2 for s in self.filter_size)
extra_kwargs = {}
if self.num_groups > 1:
extra_kwargs = {'num_groups': self.num_groups}

conved = dnn.dnn_conv3d(img=input,
kerns=self.W,
subsample=self.stride,
border_mode=border_mode,
conv_mode=conv_mode
)
conv_mode=conv_mode,
**extra_kwargs)
return conved


37 changes: 31 additions & 6 deletions lasagne/tests/layers/test_conv.py
@@ -25,7 +25,7 @@
theano_backend = "cpu"


def convNd(input, kernel, pad, stride=1, n=None):
def convNd(input, kernel, pad, stride=1, groups=1, n=None):
"""Execute a batch of a stack of N-dimensional convolutions.
Parameters
@@ -34,12 +34,20 @@ def convNd(input, kernel, pad, stride=1, n=None):
kernel : numpy array
pad : {0, 'valid', 'same', 'full'}, int or tuple of int
stride : int or tuple of int
groups: int
n : int
Returns
-------
numpy array
"""
if groups > 1:
input = input.reshape(input.shape[0], groups, -1, *input.shape[2:])
kernel = kernel.reshape(groups, -1, *kernel.shape[1:])
return np.concatenate([convNd(input[:, g], kernel[g], pad, stride,
groups=1, n=n)
for g in range(groups)], axis=1)

if n is None:
n = input.ndim - 2
if pad not in ['valid', 'same', 'full']:
@@ -171,6 +179,14 @@ def _convert(input, kernel, output, kwargs):
output = convNd(input, kernel[flip], pad='valid')
yield _convert(input, kernel, output, {'flip_filters': False})

# num_groups=3 case
input_shape = (2, 6) + extra_shape[-n:]
input = np.random.random(input_shape)
kernel = np.random.random((9, 2) + (3,) * n)
output = convNd(input, kernel, pad='valid', groups=3)
yield _convert(input, kernel, output, {'num_groups': 3,
'flip_filters': True})


def conv3d_test_sets():
return convNd_test_sets(3)
@@ -374,6 +390,15 @@ def test_fail_on_mismatching_dimensionality(self):
BaseConvLayer((10, 20, 30, 40), 1, 3, n=1)
assert "Expected 3 input dimensions" in exc.value.args[0]

def test_fail_on_mismatching_groups(self):
from lasagne.layers.conv import BaseConvLayer
with pytest.raises(ValueError) as exc:
BaseConvLayer((2, 3, 4), 1, 3, num_groups=2)
assert "evenly divide" in exc.value.args[0]
with pytest.raises(ValueError) as exc:
BaseConvLayer((2, 3, 4), 1, 3, num_groups=-3)
assert "must be positive" in exc.value.args[0]


class TestConv1DLayer:

@@ -397,7 +422,7 @@ def test_defaults(self, DummyInputLayer,
assert actual.shape == layer.output_shape
assert np.allclose(actual, output)

except NotImplementedError:
except (NotImplementedError, RuntimeError):
pass

def test_init_none_nonlinearity_bias(self, DummyInputLayer):
@@ -460,7 +485,7 @@ def test_defaults(self, Conv2DImpl, DummyInputLayer,
assert actual.shape == layer.output_shape
assert np.allclose(actual, output)

except NotImplementedError:
except (NotImplementedError, RuntimeError):
pytest.skip()

@pytest.mark.parametrize(
@@ -488,7 +513,7 @@ def test_with_nones(self, Conv2DImpl, DummyInputLayer,
assert actual.shape == output.shape
assert np.allclose(actual, output)

except NotImplementedError:
except (NotImplementedError, RuntimeError):
pytest.skip()

def test_init_none_nonlinearity_bias(self, Conv2DImpl, DummyInputLayer):
@@ -557,7 +582,7 @@ def test_defaults(self, Conv3DImpl, DummyInputLayer,
assert actual.shape == layer.output_shape
assert np.allclose(actual, output)

except NotImplementedError:
except (NotImplementedError, RuntimeError):
pytest.skip()

@pytest.mark.parametrize(
@@ -586,7 +611,7 @@ def test_with_nones(self, Conv3DImpl, DummyInputLayer,
assert actual.shape == output.shape
assert np.allclose(actual, output)

except NotImplementedError:
except (NotImplementedError, RuntimeError):
pytest.skip()

def test_init_none_nonlinearity_bias(self, Conv3DImpl, DummyInputLayer):
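The new `num_groups=3` test case above checks the grouped layers against a reference that slices the input channels and filters into groups, convolves each group separately, and concatenates the results. The same equivalence can be sketched directly with Lasagne layers (my own illustration; the shapes are arbitrary and the grouped layer needs Theano 0.10 or later):

```python
import numpy as np
import theano
import lasagne
from lasagne.layers import InputLayer, Conv2DLayer, get_output

floatX = theano.config.floatX
x = np.random.rand(2, 6, 8, 8).astype(floatX)

# Grouped layer: 6 input channels and 4 filters split into 2 groups.
l_in = InputLayer((None, 6, 8, 8))
l_grouped = Conv2DLayer(l_in, num_filters=4, filter_size=(3, 3),
                        num_groups=2, nonlinearity=None, b=None)
W = l_grouped.W.get_value()               # shape (4, 3, 3, 3)

# Reference: run each half of the channels through its own ungrouped layer,
# reusing the corresponding half of the grouped layer's filters.
outs = []
for g in range(2):
    l_part = InputLayer((None, 3, 8, 8))
    l_conv = Conv2DLayer(l_part, num_filters=2, filter_size=(3, 3),
                         W=W[2 * g:2 * g + 2], nonlinearity=None, b=None)
    outs.append(get_output(l_conv).eval(
        {l_part.input_var: x[:, 3 * g:3 * g + 3]}))
reference = np.concatenate(outs, axis=1)

grouped = get_output(l_grouped).eval({l_in.input_var: x})
print(np.allclose(grouped, reference))    # True
```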
(The diffs for the remaining changed files did not load and are not shown here.)
