-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
C extension for roi pooling forward on cpu
- Loading branch information
Showing
12 changed files
with
238 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -92,4 +92,4 @@ ENV/ | |
.ropeproject | ||
|
||
.idea | ||
|
||
extension-ffi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
from torch.utils.ffi import _wrap_function | ||
from ._roi_pooling import lib as _lib, ffi as _ffi | ||
|
||
__all__ = [] | ||
def _import_symbols(locals): | ||
for symbol in dir(_lib): | ||
fn = getattr(_lib, symbol) | ||
locals[symbol] = _wrap_function(fn, _ffi) | ||
__all__.append(symbol) | ||
|
||
_import_symbols(locals()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
import os | ||
import torch | ||
from torch.utils.ffi import create_extension | ||
|
||
this_file = os.path.dirname(__file__) | ||
|
||
sources = ['src/roi_pooling.c'] | ||
headers = ['src/roi_pooling.h'] | ||
defines = [] | ||
with_cuda = False | ||
|
||
if torch.cuda.is_available() and False: | ||
print('Including CUDA code.') | ||
sources += ['src/my_lib_cuda.c'] | ||
headers += ['src/my_lib_cuda.h'] | ||
defines += [('WITH_CUDA', None)] | ||
with_cuda = True | ||
|
||
ffi = create_extension( | ||
'_ext.roi_pooling', | ||
headers=headers, | ||
sources=sources, | ||
define_macros=defines, | ||
relative_to=__file__, | ||
with_cuda=with_cuda | ||
) | ||
|
||
if __name__ == '__main__': | ||
ffi.build() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
from torch.autograd import Function | ||
from .._ext import roi_pooling | ||
|
||
|
||
class RoIPoolFunction(Function): | ||
def __init__(self, pooled_height, pooled_width, spatial_scale): | ||
self.pooled_width = int(pooled_width) | ||
self.pooled_height = int(pooled_height) | ||
self.spatial_scale = float(spatial_scale) | ||
|
||
def forward(self, features, rois): | ||
batch_size, num_channels, data_height, data_width = features.size() | ||
num_rois = rois.size()[0] | ||
output = torch.zeros(num_rois, num_channels, self.pooled_height, self.pooled_width) | ||
_features = features.permute(0, 2, 3, 1) | ||
if not features.is_cuda: | ||
roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale, | ||
_features, rois, output) | ||
else: | ||
# TODO: cuda | ||
roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale, | ||
_features.cpu(), rois.cpu(), output) | ||
output = output.cuda() | ||
|
||
return output | ||
|
||
def backward(self, grad_output): | ||
# TODO: roi_pooling backward | ||
# grad_input = grad_output.new() | ||
# if not grad_output.is_cuda: | ||
# my_lib.my_lib_add_backward(grad_output, grad_input) | ||
# else: | ||
# my_lib.my_lib_add_backward_cuda(grad_output, grad_input) | ||
# return grad_input | ||
return None |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from torch.nn.modules.module import Module | ||
# from functions.roi_pool import RoIPoolFunction | ||
from ..functions.roi_pool import RoIPoolFunction | ||
|
||
|
||
class RoIPool(Module): | ||
def __init__(self, pooled_height, pooled_width, spatial_scale): | ||
super(RoIPool, self).__init__() | ||
self.roi_pool = RoIPoolFunction(pooled_height, pooled_width, spatial_scale) | ||
|
||
def forward(self, features, rois): | ||
return self.roi_pool(features, rois) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
#include <TH/TH.h> | ||
#include <math.h> | ||
|
||
int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, | ||
THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output) | ||
{ | ||
// Grab the input tensor | ||
float * data_flat = THFloatTensor_data(features); | ||
float * rois_flat = THFloatTensor_data(rois); | ||
|
||
float * output_flat = THFloatTensor_data(output); | ||
|
||
// Number of ROIs | ||
int num_rois = THFloatTensor_size(rois, 0); | ||
int size_rois = THFloatTensor_size(rois, 1); | ||
// batch size | ||
int batch_size = THFloatTensor_size(features, 0); | ||
if(batch_size != 1) | ||
{ | ||
return 0; | ||
} | ||
// data height | ||
int data_height = THFloatTensor_size(features, 1); | ||
// data width | ||
int data_width = THFloatTensor_size(features, 2); | ||
// Number of channels | ||
int num_channels = THFloatTensor_size(features, 3); | ||
|
||
// Set all element of the output tensor to -inf. | ||
THFloatStorage_fill(THFloatTensor_storage(output), -1); | ||
|
||
// For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R | ||
int index_roi = 0; | ||
int index_output = 0; | ||
int n; | ||
for (n = 0; n < num_rois; ++n) | ||
{ | ||
int roi_batch_ind = rois_flat[index_roi + 0]; | ||
int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale); | ||
int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale); | ||
int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale); | ||
int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale); | ||
// CHECK_GE(roi_batch_ind, 0); | ||
// CHECK_LT(roi_batch_ind, batch_size); | ||
|
||
int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1); | ||
int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1); | ||
float bin_size_h = (float)(roi_height) / (float)(pooled_height); | ||
float bin_size_w = (float)(roi_width) / (float)(pooled_width); | ||
|
||
int index_data = roi_batch_ind * data_height * data_width * num_channels; | ||
const int output_area = pooled_width * pooled_height; | ||
|
||
int c, ph, pw; | ||
for (ph = 0; ph < pooled_height; ++ph) | ||
{ | ||
for (pw = 0; pw < pooled_width; ++pw) | ||
{ | ||
int hstart = (floor((float)(ph) * bin_size_h)); | ||
int wstart = (floor((float)(pw) * bin_size_w)); | ||
int hend = (ceil((float)(ph + 1) * bin_size_h)); | ||
int wend = (ceil((float)(pw + 1) * bin_size_w)); | ||
|
||
hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height); | ||
hend = fminf(fmaxf(hend + roi_start_h, 0), data_height); | ||
wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width); | ||
wend = fminf(fmaxf(wend + roi_start_w, 0), data_width); | ||
|
||
const int pool_index = index_output + (ph * pooled_width + pw); | ||
int is_empty = (hend <= hstart) || (wend <= wstart); | ||
if (is_empty) | ||
{ | ||
for (c = 0; c < num_channels * output_area; c += output_area) | ||
{ | ||
output_flat[pool_index + c] = 0; | ||
} | ||
} | ||
else | ||
{ | ||
int h, w, c; | ||
for (h = hstart; h < hend; ++h) | ||
{ | ||
for (w = wstart; w < wend; ++w) | ||
{ | ||
for (c = 0; c < num_channels; ++c) | ||
{ | ||
const int index = (h * data_width + w) * num_channels + c; | ||
if (data_flat[index_data + index] > output_flat[pool_index + c * output_area]) | ||
{ | ||
output_flat[pool_index + c * output_area] = data_flat[index_data + index]; | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
|
||
// int ph; | ||
// for (ph = 0; ph < pooled_height; ++ph) | ||
// { | ||
// int pw; | ||
// for (pw = 0; pw < pooled_width; ++pw) | ||
// { | ||
// // Compute pooling region for this output unit: | ||
// // start (included) = floor(ph * roi_height / pooled_height_) | ||
// // end (excluded) = ceil((ph + 1) * roi_height / pooled_height_) | ||
// | ||
// | ||
// const int pool_index = index_output + (ph * pooled_width + pw) * num_channels; | ||
// if (is_empty) | ||
// { | ||
// int c; | ||
// for (c = 0; c < num_channels; ++c) | ||
// { | ||
// output_flat[pool_index + c] = 0; | ||
// } | ||
// } | ||
// | ||
// int h, w, c; | ||
// for (h = hstart; h < hend; ++h) | ||
// { | ||
// for (w = wstart; w < wend; ++w) | ||
// { | ||
// for (c = 0; c < num_channels; ++c) | ||
// { | ||
// const int index = (h * data_width + w) * num_channels + c; | ||
//// const int index = c * (data_width * data_height) + (h * data_width + w); | ||
// if (data_flat[index_data + index] > output_flat[pool_index + c]) | ||
// { | ||
// output_flat[pool_index + c] = data_flat[index_data + index]; | ||
// } | ||
// } | ||
// } | ||
// } | ||
// } | ||
// } | ||
// Increment ROI index | ||
index_roi += size_rois; | ||
index_output += pooled_height * pooled_width * num_channels; | ||
} | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale, | ||
THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output); |