-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
04da88f
commit 7fc819e
Showing
11 changed files
with
1,128 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .deform_conv import ConvOffset2d |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import os | ||
import torch | ||
from torch.utils.ffi import create_extension | ||
|
||
this_file = os.path.dirname(__file__) | ||
|
||
sources = ['src/deform_conv.c'] | ||
headers = ['src/deform_conv.h'] | ||
defines = [] | ||
with_cuda = False | ||
|
||
if torch.cuda.is_available(): | ||
print('Including CUDA code.') | ||
sources += ['src/deform_conv_cuda.c'] | ||
headers += ['src/deform_conv_cuda.h'] | ||
defines += [('WITH_CUDA', None)] | ||
with_cuda = True | ||
|
||
this_file = os.path.dirname(os.path.realpath(__file__)) | ||
print(this_file) | ||
extra_objects = ['src/deform_conv_cuda_kernel.cu.o'] | ||
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] | ||
|
||
ffi = create_extension( | ||
'_ext.deform_conv', | ||
headers=headers, | ||
sources=sources, | ||
define_macros=defines, | ||
relative_to=__file__, | ||
with_cuda=with_cuda, | ||
extra_objects=extra_objects | ||
) | ||
|
||
if __name__ == '__main__': | ||
assert torch.cuda.is_available(), 'Please install CUDA for GPU support.' | ||
ffi.build() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import math | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Function | ||
from torch.nn.modules.module import Module | ||
from torch.nn.modules.utils import _pair | ||
|
||
from ._ext import deform_conv | ||
|
||
|
||
def conv_offset2d(input, | ||
offset, | ||
weight, | ||
stride=1, | ||
padding=0, | ||
dilation=1, | ||
deform_groups=1): | ||
|
||
if input is not None and input.dim() != 4: | ||
raise ValueError( | ||
"Expected 4D tensor as input, got {}D tensor instead.".format( | ||
input.dim())) | ||
|
||
f = ConvOffset2dFunction( | ||
_pair(stride), _pair(padding), _pair(dilation), deform_groups) | ||
return f(input, offset, weight) | ||
|
||
|
||
class ConvOffset2dFunction(Function): | ||
def __init__(self, stride, padding, dilation, deformable_groups=1): | ||
super(ConvOffset2dFunction, self).__init__() | ||
self.stride = stride | ||
self.padding = padding | ||
self.dilation = dilation | ||
self.deformable_groups = deformable_groups | ||
|
||
def forward(self, input, offset, weight): | ||
self.save_for_backward(input, offset, weight) | ||
|
||
output = input.new(*self._output_size(input, weight)) | ||
|
||
self.bufs_ = [input.new(), input.new()] # columns, ones | ||
|
||
if not input.is_cuda: | ||
raise NotImplementedError | ||
else: | ||
if isinstance(input, torch.autograd.Variable): | ||
if not isinstance(input.data, torch.cuda.FloatTensor): | ||
raise NotImplementedError | ||
else: | ||
if not isinstance(input, torch.cuda.FloatTensor): | ||
raise NotImplementedError | ||
deform_conv.deform_conv_forward_cuda( | ||
input, weight, offset, output, self.bufs_[0], self.bufs_[1], | ||
weight.size(3), weight.size(2), self.stride[1], self.stride[0], | ||
self.padding[1], self.padding[0], self.dilation[1], | ||
self.dilation[0], self.deformable_groups) | ||
return output | ||
|
||
def backward(self, grad_output): | ||
input, offset, weight = self.saved_tensors | ||
|
||
grad_input = grad_offset = grad_weight = None | ||
|
||
if not grad_output.is_cuda: | ||
raise NotImplementedError | ||
else: | ||
if isinstance(grad_output, torch.autograd.Variable): | ||
if not isinstance(grad_output.data, torch.cuda.FloatTensor): | ||
raise NotImplementedError | ||
else: | ||
if not isinstance(grad_output, torch.cuda.FloatTensor): | ||
raise NotImplementedError | ||
if self.needs_input_grad[0] or self.needs_input_grad[1]: | ||
grad_input = input.new(*input.size()).zero_() | ||
grad_offset = offset.new(*offset.size()).zero_() | ||
deform_conv.deform_conv_backward_input_cuda( | ||
input, offset, grad_output, grad_input, | ||
grad_offset, weight, self.bufs_[0], weight.size(3), | ||
weight.size(2), self.stride[1], self.stride[0], | ||
self.padding[1], self.padding[0], self.dilation[1], | ||
self.dilation[0], self.deformable_groups) | ||
|
||
if self.needs_input_grad[2]: | ||
grad_weight = weight.new(*weight.size()).zero_() | ||
deform_conv.deform_conv_backward_parameters_cuda( | ||
input, offset, grad_output, | ||
grad_weight, self.bufs_[0], self.bufs_[1], weight.size(3), | ||
weight.size(2), self.stride[1], self.stride[0], | ||
self.padding[1], self.padding[0], self.dilation[1], | ||
self.dilation[0], self.deformable_groups, 1) | ||
|
||
return grad_input, grad_offset, grad_weight | ||
|
||
def _output_size(self, input, weight): | ||
channels = weight.size(0) | ||
|
||
output_size = (input.size(0), channels) | ||
for d in range(input.dim() - 2): | ||
in_size = input.size(d + 2) | ||
pad = self.padding[d] | ||
kernel = self.dilation[d] * (weight.size(d + 2) - 1) + 1 | ||
stride = self.stride[d] | ||
output_size += ((in_size + (2 * pad) - kernel) // stride + 1, ) | ||
if not all(map(lambda s: s > 0, output_size)): | ||
raise ValueError( | ||
"convolution input is too small (output would be {})".format( | ||
'x'.join(map(str, output_size)))) | ||
return output_size | ||
|
||
|
||
class ConvOffset2d(Module): | ||
def __init__(self, | ||
in_channels, | ||
out_channels, | ||
kernel_size, | ||
stride=1, | ||
padding=0, | ||
dilation=1, | ||
num_deformable_groups=1): | ||
super(ConvOffset2d, self).__init__() | ||
self.in_channels = in_channels | ||
self.out_channels = out_channels | ||
self.kernel_size = _pair(kernel_size) | ||
self.stride = _pair(stride) | ||
self.padding = _pair(padding) | ||
self.dilation = _pair(dilation) | ||
self.num_deformable_groups = num_deformable_groups | ||
|
||
self.weight = nn.Parameter( | ||
torch.Tensor(out_channels, in_channels, *self.kernel_size)) | ||
|
||
self.reset_parameters() | ||
|
||
def reset_parameters(self): | ||
n = self.in_channels | ||
for k in self.kernel_size: | ||
n *= k | ||
stdv = 1. / math.sqrt(n) | ||
self.weight.data.uniform_(-stdv, stdv) | ||
|
||
def forward(self, input, offset): | ||
return conv_offset2d(input, offset, self.weight, self.stride, | ||
self.padding, self.dilation, | ||
self.num_deformable_groups) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
cd src | ||
nvcc -c -o deform_conv_cuda_kernel.cu.o deform_conv_cuda_kernel.cu -x cu -Xcompiler -fPIC -std=c++11 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#include <TH/TH.h> | ||
|
||
int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, | ||
THFloatTensor *output) | ||
{ | ||
// if (!THFloatTensor_isSameSizeAs(input1, input2)) | ||
// return 0; | ||
// THFloatTensor_resizeAs(output, input); | ||
// THFloatTensor_cadd(output, input1, 1.0, input2); | ||
return 1; | ||
} | ||
|
||
int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, | ||
THFloatTensor *grad_offset) | ||
{ | ||
// THFloatTensor_resizeAs(grad_input, grad_output); | ||
// THFloatTensor_fill(grad_input, 1); | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
int deform_conv_forward(THFloatTensor *input, THFloatTensor *offset, | ||
THFloatTensor *output); | ||
int deform_conv_backward(THFloatTensor *grad_output, THFloatTensor *grad_input, | ||
THFloatTensor *grad_offset); |
Oops, something went wrong.