-
Notifications
You must be signed in to change notification settings - Fork 18.7k
/
cudnn_ndpooling_layer.cpp
112 lines (95 loc) · 3.81 KB
/
cudnn_ndpooling_layer.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#ifdef USE_CUDNN
#include <vector>
#include "caffe/filler.hpp"
#include "caffe/layer.hpp"
#include "caffe/util/im2col.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"
namespace caffe {
template <typename Dtype>
void CudnnNdPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // One-time setup: validate the N-d pooling configuration, cache the
  // kernel/pad/stride shapes, and create the cuDNN handle and descriptors.
  PoolingParameter pool_param = this->layer_param_.pooling_param();
  // All three spatial shapes are mandatory and must describe the same
  // number of spatial dimensions.
  CHECK(pool_param.has_kernel_shape()
      && pool_param.has_pad_shape()
      && pool_param.has_stride_shape())
      << "Kernel, Pad and Stride shape required.";
  CHECK_EQ(pool_param.kernel_shape().dim_size(), pool_param.pad_shape().dim_size())
      << "Kernel and Pad shape don't match !";
  CHECK_EQ(pool_param.kernel_shape().dim_size(), pool_param.stride_shape().dim_size())
      << "Kernel and Stride shape don't match !";
  global_pooling_ = pool_param.global_pooling();
  if (global_pooling_) {
    // Global pooling: the window spans every spatial axis of bottom[0]
    // (all axes after num and channels).
    kernel_shape_.assign(bottom[0]->shape().begin() + 2, bottom[0]->shape().end());
  } else {
    // Explicit window: take each kernel extent from the proto, rejecting
    // non-positive sizes.
    for (int d = 0; d < pool_param.kernel_shape().dim_size(); ++d) {
      const int extent = pool_param.kernel_shape().dim(d);
      CHECK_GT(extent, 0) << "Filter dimensions cannot be zero.";
      kernel_shape_.push_back(extent);
    }
  }
  // Pad and stride come from the proto in both modes.
  for (int d = 0; d < pool_param.kernel_shape().dim_size(); ++d) {
    pad_shape_.push_back(pool_param.pad_shape().dim(d));
    stride_shape_.push_back(pool_param.stride_shape().dim(d));
  }
  // Create the cuDNN context plus input/output tensor descriptors and the
  // N-d pooling descriptor; the tensor descriptors are sized in Reshape().
  CUDNN_CHECK(cudnnCreate(&handle_));
  cudnn::createTensorDesc<Dtype>(&bottom_desc_);
  cudnn::createTensorDesc<Dtype>(&top_desc_);
  cudnn::createNdPoolingDesc<Dtype>(&pooling_desc_,
      pool_param.pool(), &mode_,
      kernel_shape_, pad_shape_, stride_shape_);
  // Remember that creation succeeded so the destructor knows to clean up.
  handles_setup_ = true;
}
template <typename Dtype>
void CudnnNdPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
    const vector<Blob<Dtype>*>& top) {
  // Recompute the pooled output shape from the current bottom shape, resize
  // top[0] (and the max/stochastic index blobs when needed), and update the
  // cuDNN input/output tensor descriptors to match.
  channels_ = bottom[0]->channels();
  input_shape_ = bottom[0]->shape();
  // Global pooling must track the bottom's spatial extent, which may have
  // changed since the last Reshape.
  if(global_pooling_) {
    kernel_shape_ = vector<int>(bottom[0]->shape().begin()+2, bottom[0]->shape().end());
  }
  // Start from the input shape: axes 0..1 (num, channels) carry over
  // unchanged; each spatial axis i is rewritten in place below to
  //   out = (in + 2*pad - kernel) / stride + 1
  // using integer (floor) division. NOTE(review): this is floor, not the
  // ceil rounding Caffe's CPU PoolingLayer uses — presumably chosen to match
  // cuDNN's own output-dimension computation; confirm against
  // cudnnGetPoolingNdForwardOutputDim.
  pooled_shape_ = input_shape_;
  for(int i = 2; i < pooled_shape_.size(); ++i) {
    pooled_shape_[i] += 2 * pad_shape_[i-2] - kernel_shape_[i-2];
    pooled_shape_[i] /= stride_shape_[i-2];
    ++pooled_shape_[i];
    if(pad_shape_[i-2] > 0) {
      // With padding, ensure the last pooling window starts strictly inside
      // the input plus its left padding; otherwise drop it (it would read
      // padding only).
      if ((pooled_shape_[i] - 1) * stride_shape_[i-2] >= input_shape_[i] + pad_shape_[i-2]) {
        --pooled_shape_[i];
      }
      // After the clip this must hold, or the configuration is invalid.
      CHECK_LT((pooled_shape_[i] - 1) * stride_shape_[i-2], input_shape_[i] + pad_shape_[i-2]);
    }
  }
  top[0]->Reshape(pooled_shape_);
  // If max pooling, we will initialize the vector index part.
  if (this->layer_param_.pooling_param().pool() ==
      PoolingParameter_PoolMethod_MAX && top.size() == 1) {
    max_idx_.Reshape(pooled_shape_);
  }
  // If stochastic pooling, we will initialize the random index part.
  if (this->layer_param_.pooling_param().pool() ==
      PoolingParameter_PoolMethod_STOCHASTIC) {
    rand_idx_.Reshape(pooled_shape_);
  }
  // Keep the cuDNN descriptors in sync with the (possibly new) shapes.
  cudnn::setTensorNdDesc<Dtype>(&bottom_desc_, input_shape_);
  cudnn::setTensorNdDesc<Dtype>(&top_desc_, pooled_shape_);
}
template <typename Dtype>
// CPU forward pass is intentionally unimplemented: this layer only provides
// a cuDNN (GPU) implementation, so the CPU path aborts via NOT_IMPLEMENTED.
void CudnnNdPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>* >& bottom, const vector<Blob<Dtype>* >& top) {
  NOT_IMPLEMENTED;
}
template <typename Dtype>
// CPU backward pass is intentionally unimplemented: this layer only provides
// a cuDNN (GPU) implementation, so the CPU path aborts via NOT_IMPLEMENTED.
void CudnnNdPoolingLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>* >& bottom, const vector<bool>& propagate_down, const vector<Blob<Dtype>* >& top) {
  NOT_IMPLEMENTED;
}
template <typename Dtype>
CudnnNdPoolingLayer<Dtype>::~CudnnNdPoolingLayer() {
  // Release cuDNN resources only if LayerSetUp actually created them;
  // otherwise the descriptors and handle are uninitialized and must not be
  // passed to the cuDNN destroy calls.
  if (handles_setup_) {
    cudnnDestroyTensorDescriptor(bottom_desc_);
    cudnnDestroyTensorDescriptor(top_desc_);
    cudnnDestroyPoolingDescriptor(pooling_desc_);
    // Destroy the context last, after every descriptor tied to it.
    cudnnDestroy(handle_);
  }
}
INSTANTIATE_CLASS(CudnnNdPoolingLayer);
} // namespace caffe
#endif