Skip to content

Commit

Permalink
Implement with grid sample
Browse files Browse the repository at this point in the history
  • Loading branch information
sunset1995 committed Sep 30, 2018
1 parent 2fb8e98 commit 9562065
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions sphere_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,24 +157,21 @@ class SphereConv2D(nn.Module):
def __init__(self, in_c, out_c, stride=1, bias=True):
super(SphereConv2D, self).__init__()
self.stride = stride
self.conv = ConvOffset2d(in_c, out_c, kernel_size=3, padding=1, num_deformable_groups=1, bias=bias)
self.offset = None
self.conv = nn.Conv2d(in_c, out_c, kernel_size=3, stride=3)

def forward(self, x):
# x: (B, C, H, W)
if self.offset is None or self.offset.shape[0] != x.shape[0] or\
self.offset.shape[-2:] != x.shape[-2:]:
coordinates = gen_filters_coordinates(x.shape[2], x.shape[3], self.stride)
oriidx = np.stack(np.meshgrid(range(x.size(3)), range(x.size(2)))[::-1], 0)
oriidx = oriidx[..., None, None]
offset = coordinates - oriidx
offset = offset.transpose(0, 3, 4, 1, 2)
offset = offset.reshape(1, 2 * 3 * 3, x.size(2), x.size(3))
offset = offset.repeat(x.size(0), axis=0)
self.offset = torch.tensor(offset, dtype=x.dtype, device=x.device)
self.offset.requires_grad = False

return self.conv(x, self.offset)
coordinates = gen_filters_coordinates(x.shape[2], x.shape[3], self.stride).copy()
coordinates[0] = (coordinates[0] * 2 / x.shape[2]) - 1
coordinates[1] = (coordinates[1] * 2 / x.shape[3]) - 1
coordinates = coordinates[::-1]
coordinates = coordinates.transpose(1, 3, 2, 4, 0)
coordinates = coordinates.reshape(x.shape[2]*3, x.shape[3]*3, 2)
grid = torch.FloatTensor(coordinates).to(x.device)
grid = grid.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)

x = nn.functional.grid_sample(x, grid)
return self.conv(x)


class SphereMaxPool2D(nn.Module):
Expand Down

0 comments on commit 9562065

Please sign in to comment.