roi_pooling forward with cuda
longcw committed Feb 5, 2017
1 parent 106127f commit a61a53a
Showing 5 changed files with 20 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -13,7 +13,7 @@ by Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun.
 - [x] forward pass for detecting
 - [x] using models trained by Tensorflow
 - [x] roi pooling layer implemented by python and pytorch
-- [ ] roi pooling layer with C extensions for pytorch
+- [x] roi pooling layer with C extensions for pytorch
 - [ ] backward pass for training

### Installation and demo
5 changes: 0 additions & 5 deletions faster_rcnn/faster_rcnn.py
@@ -120,11 +120,6 @@ def forward(self, image):
        roi_pooling_t = t.toc()
        print('roi pooling spend: {}s'.format(roi_pooling_t))

-        print torch.max(pooled_features)
-        print torch.min(pooled_features)
-        print pooled_features.size()
-        print torch.max(pooled_features, 1)[0][0]
-
        x = pooled_features.view(pooled_features.size()[0], -1)
        x = self.fc6(x)
        x = self.fc7(x)
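The surviving lines flatten each ROI's pooled feature map into a single vector before the fc6/fc7 head. A minimal sketch of that reshaping, assuming VGG16-style sizes (300 ROIs, 512 channels, 7x7 pooling) that this diff does not itself state:

```python
import torch
import torch.nn as nn

# Assumed sizes: 300 ROIs, 512 channels, 7x7 ROI pooling (VGG16-style).
pooled_features = torch.zeros(300, 512, 7, 7)
fc6 = nn.Linear(512 * 7 * 7, 4096)
fc7 = nn.Linear(4096, 4096)

# Same flattening as the diff: one row per ROI.
x = pooled_features.view(pooled_features.size()[0], -1)  # shape (300, 25088)
x = fc7(fc6(x))                                          # shape (300, 4096)
```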
16 changes: 7 additions & 9 deletions faster_rcnn/roi_pooling/functions/roi_pool.py
@@ -8,30 +8,28 @@ def __init__(self, pooled_height, pooled_width, spatial_scale):
        self.pooled_width = int(pooled_width)
        self.pooled_height = int(pooled_height)
        self.spatial_scale = float(spatial_scale)
+        self.output = None
+        self.argmax = None

    def forward(self, features, rois):
        batch_size, num_channels, data_height, data_width = features.size()
        num_rois = rois.size()[0]
        output = torch.zeros(num_rois, num_channels, self.pooled_height, self.pooled_width)
        argmax = torch.IntTensor(num_rois, num_channels, self.pooled_height, self.pooled_width).zero_()
-        # output, argmax = output.permute(0, 2, 3, 1), argmax.permute(0, 2, 3, 1)
-        _features = features.permute(0, 2, 3, 1)

        if not features.is_cuda:
+            _features = features.permute(0, 2, 3, 1)
            roi_pooling.roi_pooling_forward(self.pooled_height, self.pooled_width, self.spatial_scale,
-                                            _features.cpu(), rois.cpu(), output.cpu())
+                                            _features, rois, output)
-            output = output.cuda()
        else:
-            # TODO: cuda
-            print('cuda:')
            output = output.cuda()
            argmax = argmax.cuda()
            roi_pooling.roi_pooling_forward_cuda(self.pooled_height, self.pooled_width, self.spatial_scale,
-                                                 _features, rois, output, argmax)
+                                                 features, rois, output, argmax)
+            self.output = output
+            self.argmax = argmax

-        print argmax.cpu().numpy()[0, 25]
-        # print output
        # output = output.permute(0, 3, 1, 2)
        return output

    def backward(self, grad_output):
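For reference, the C and CUDA calls above compute the same max pooling over ROI bins as the pure-Python layer mentioned in the README. A rough sketch of that computation (the helper name and the [batch_idx, x1, y1, x2, y2] ROI format are assumptions based on the usual Faster R-CNN convention):

```python
import math
import torch

def roi_pool_forward_py(features, rois, pooled_height, pooled_width, spatial_scale):
    """Reference ROI max pooling.

    features: (N, C, H, W); rois: (R, 5) rows of [batch_idx, x1, y1, x2, y2]
    in image coordinates (assumed convention).
    """
    num_rois = rois.size(0)
    num_channels, height, width = features.size(1), features.size(2), features.size(3)
    output = torch.zeros(num_rois, num_channels, pooled_height, pooled_width)
    for i in range(num_rois):
        batch_ind = int(rois[i, 0])
        # Scale ROI corners from image coordinates to feature-map coordinates.
        x1 = int(round(float(rois[i, 1]) * spatial_scale))
        y1 = int(round(float(rois[i, 2]) * spatial_scale))
        x2 = int(round(float(rois[i, 3]) * spatial_scale))
        y2 = int(round(float(rois[i, 4]) * spatial_scale))
        bin_h = max(y2 - y1 + 1, 1) / float(pooled_height)
        bin_w = max(x2 - x1 + 1, 1) / float(pooled_width)
        for ph in range(pooled_height):
            for pw in range(pooled_width):
                # Each output bin takes the max over its patch of the feature map.
                hstart = min(max(y1 + int(math.floor(ph * bin_h)), 0), height)
                hend = min(max(y1 + int(math.ceil((ph + 1) * bin_h)), 0), height)
                wstart = min(max(x1 + int(math.floor(pw * bin_w)), 0), width)
                wend = min(max(x1 + int(math.ceil((pw + 1) * bin_w)), 0), width)
                if hend > hstart and wend > wstart:
                    patch = features[batch_ind, :, hstart:hend, wstart:wend]
                    output[i, :, ph, pw] = patch.reshape(num_channels, -1).max(dim=1)[0]
    return output
```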
24 changes: 9 additions & 15 deletions faster_rcnn/roi_pooling/src/cuda/roi_pooling_kernel.cu
@@ -1,6 +1,3 @@
-#if GOOGLE_CUDA
-
-
#ifdef __cplusplus
extern "C" {
#endif
@@ -71,17 +68,18 @@ __global__ void ROIPoolForward(const int nthreads, const float* bottom_data,
    int maxidx = -1;
    bottom_data += roi_batch_ind * channels * height * width;
    for (int h = hstart; h < hend; ++h) {
-     for (int w = wstart; w < wend; ++w) {
-       int bottom_index = (h * width + w) * channels + c;
-       if (bottom_data[bottom_index] > maxval) {
-         maxval = bottom_data[bottom_index];
-         maxidx = bottom_index;
+     for (int w = wstart; w < wend; ++w) {
+       // int bottom_index = (h * width + w) * channels + c;
+       int bottom_index = (c * height + h) * width + w;
+       if (bottom_data[bottom_index] > maxval) {
+         maxval = bottom_data[bottom_index];
+         maxidx = bottom_index;
+       }
      }
    }
    top_data[index] = maxval;
    if (argmax_data != NULL)
-       argmax_data[index] = maxidx;
+      argmax_data[index] = maxidx;
  }
}
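The substantive fix in this hunk is the flat index. The old expression (h * width + w) * channels + c addresses an NHWC buffer (the layout of the Tensorflow models this repo converts), but PyTorch feature maps are contiguous in NCHW order, where element (c, h, w) lives at offset (c * height + h) * width + w. A quick check of the two formulas, with toy sizes chosen here for illustration:

```python
import torch

channels, height, width = 3, 4, 5   # toy sizes
x = torch.arange(channels * height * width).reshape(channels, height, width)
flat = x.reshape(-1)                # NCHW-contiguous buffer

c, h, w = 1, 2, 3
nchw_index = (c * height + h) * width + w    # new formula: matches PyTorch layout
nhwc_index = (h * width + w) * channels + c  # old formula: Tensorflow-style NHWC

assert flat[nchw_index] == x[c, h, w]        # always true for NCHW storage
print(bool(flat[nhwc_index] == x[c, h, w]))  # False here: reads the wrong element
```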

@@ -97,8 +95,7 @@ int ROIPoolForwardLaucher(
const int output_size = num_rois * pooled_height * pooled_width * channels;
cudaError_t err;

-  ROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
-                   kThreadsPerBlock>>>(
+  ROIPoolForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock>>>(
output_size, bottom_data, spatial_scale, height, width, channels, pooled_height,
pooled_width, bottom_rois, top_data, argmax_data);
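The launch-configuration change is cosmetic, but the idiom is worth noting: one CUDA thread is launched per output element, and (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock is integer ceiling division, rounding the grid size up so every element is covered. In Python, assuming kThreadsPerBlock = 1024 (the constant's value is not shown in this hunk):

```python
kThreadsPerBlock = 1024  # assumed value; defined elsewhere in the .cu file
num_rois, channels, pooled_height, pooled_width = 300, 512, 7, 7  # toy sizes

output_size = num_rois * pooled_height * pooled_width * channels
num_blocks = (output_size + kThreadsPerBlock - 1) // kThreadsPerBlock  # ceil division

assert num_blocks * kThreadsPerBlock >= output_size  # grid covers every output
print(output_size, num_blocks)  # 7526400 elements -> 7350 blocks
```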

@@ -222,6 +219,3 @@ int ROIPoolForwardLaucher(
#endif


-
-#endif // GOOGLE_CUDA
-
6 changes: 3 additions & 3 deletions faster_rcnn/roi_pooling/src/roi_pooling_cuda.c
@@ -29,11 +29,11 @@ int roi_pooling_forward_cuda(int pooled_height, int pooled_width, float spatial_
return 0;
}
// data height
-int data_height = THCudaTensor_size(state, features, 1);
+int data_height = THCudaTensor_size(state, features, 2);
// data width
-int data_width = THCudaTensor_size(state, features, 2);
+int data_width = THCudaTensor_size(state, features, 3);
// Number of channels
-int num_channels = THCudaTensor_size(state, features, 3);
+int num_channels = THCudaTensor_size(state, features, 1);

ROIPoolForwardLaucher(
data_flat, spatial_scale, num_rois, data_height,
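This mirrors the kernel fix: for an NCHW tensor, dimension 0 is batch, 1 is channels, 2 is height, and 3 is width; the old code read the sizes in NHWC order, so every extent passed to the launcher was mislabeled. A sketch of the corrected bookkeeping, with toy sizes assumed for illustration:

```python
import torch

features = torch.zeros(1, 512, 38, 50)  # (batch, channels, height, width)

# NCHW ordering, as the corrected THCudaTensor_size calls now assume.
num_channels = features.size(1)  # 512
data_height = features.size(2)   # 38
data_width = features.size(3)    # 50

# The old code read size(1) as height, size(2) as width, and size(3) as
# channels (the NHWC convention), mislabeling every dimension.
print(num_channels, data_height, data_width)
```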
