Skip to content

Commit

Permalink
Add bias into SphereConv2D
Browse files Browse the repository at this point in the history
  • Loading branch information
sunset1995 committed Sep 22, 2018
1 parent 897ac4f commit 29a1358
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
15 changes: 11 additions & 4 deletions deform_conv/deform_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def __init__(self,
stride=1,
padding=0,
dilation=1,
num_deformable_groups=1):
num_deformable_groups=1,
bias=True):
super(ConvOffset2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
Expand All @@ -129,6 +130,9 @@ def __init__(self,

self.weight = nn.Parameter(
torch.Tensor(out_channels, in_channels, *self.kernel_size))
self.bias = None
if bias:
self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))

self.reset_parameters()

Expand All @@ -140,6 +144,9 @@ def reset_parameters(self):
self.weight.data.uniform_(-stdv, stdv)

def forward(self, input, offset):
return conv_offset2d(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
output = conv_offset2d(input, offset, self.weight, self.stride,
self.padding, self.dilation,
self.num_deformable_groups)
if self.bias is not None:
output = output + self.bias
return output
24 changes: 12 additions & 12 deletions sphere_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from deform_conv import ConvOffset2d


# Calculate kernels of SphereCNN
# Calculate kernels of SphereCNN
@lru_cache(None)
def get_xy(delta_phi, delta_theta):
return np.array([
Expand All @@ -32,7 +32,7 @@ def get_xy(delta_phi, delta_theta):

@lru_cache(None)
def cal_index(h, w, img_r, img_c):
'''
'''
Calculate Kernel Sampling Pattern
only support 3x3 filter
return 9 locations: (3, 3, 2)
Expand Down Expand Up @@ -105,19 +105,19 @@ def map_coordinates(input, coordinates, mode='bilinear', pad='wrap', slice_mode=
coordinates.requires_grad = False
h = input.shape[2]
w = input.shape[3]

def _coordinates_pad_wrap(h, w, coordinates):
coordinates[0] = coordinates[0] % h
coordinates[1] = coordinates[1] % w
return coordinates

def _coordinates_pad_zero(h, w, coordinates):
out_of_bound_h = (coordinates[0] < 0) | (coordinates[0] > (h-1))
out_of_bound_w = (coordinates[1] < 0) | (coordinates[1] > (w-1))
coordinates[0, out_of_bound_h] = h
coordinates[1, out_of_bound_w] = w
return coordinates

if mode == 'nearest':
coordinates = torch.round(coordinates).long()
if pad == 'wrap':
Expand Down Expand Up @@ -154,12 +154,12 @@ class SphereConv2D(nn.Module):
''' SphereConv2D
Note that this layer only support 3x3 filter
'''
def __init__(self, in_c, out_c, stride=1):
def __init__(self, in_c, out_c, stride=1, bias=True):
super(SphereConv2D, self).__init__()
self.stride = stride
self.conv = ConvOffset2d(in_c, out_c, kernel_size=3, padding=1, num_deformable_groups=1)
self.conv = ConvOffset2d(in_c, out_c, kernel_size=3, padding=1, num_deformable_groups=1, bias=bias)
self.offset = None

def forward(self, x):
# x: (B, C, H, W)
if self.offset is None or self.offset.shape[0] != x.shape[0] or\
Expand All @@ -186,7 +186,7 @@ def __init__(self, stride=1, mode='bilinear'):
super(SphereMaxPool2D, self).__init__()
self.mode = mode
self.stride = stride

def forward(self, x):
# x: (B, C, H, W)
with torch.no_grad():
Expand All @@ -205,11 +205,11 @@ def forward(self, x):
ridx = ridx[None, None, :, None]
cidx = cidx[None, None, None, :]
max_coord = coordinates[:, ridx, cidx, maxi]

return map_coordinates(x, max_coord, mode=self.mode, slice_mode='points')

if __name__ == '__main__':

if __name__ == '__main__':
# test cnn
cnn = SphereConv2D(3, 5, 1)
out = cnn(torch.randn(2, 3, 10, 10))
Expand Down

0 comments on commit 29a1358

Please sign in to comment.