Skip to content

Commit

Permalink
cudnn: create general tensor, set 4D tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed Jan 15, 2015
1 parent 25e3748 commit 079ff55
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 33 deletions.
10 changes: 5 additions & 5 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,20 @@ template<> class dataType<double> {
};

template <typename Dtype>
inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
inline void createTensorDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
Expand All @@ -95,7 +95,7 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {

template <typename Dtype>
inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
Expand Down
16 changes: 8 additions & 8 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(

// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
cudnnTensor4dDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensorDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
cudnnTensor4dDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
cudnnTensorDescriptor_t top_desc;
cudnn::createTensorDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
cudnn::createConvolutionDesc<Dtype>(&conv_desc);
Expand All @@ -56,7 +56,7 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(

// Tensor descriptor for bias.
if (this->bias_term_) {
cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
cudnn::createTensorDesc<Dtype>(&bias_desc_);
}
}

Expand Down Expand Up @@ -99,12 +99,12 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
template <typename Dtype>
CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
for (int i = 0; i < bottom_descs_.size(); i++) {
cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
cudnnDestroyTensor4dDescriptor(top_descs_[i]);
cudnnDestroyTensorDescriptor(bottom_descs_[i]);
cudnnDestroyTensorDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
cudnnDestroyTensor4dDescriptor(bias_desc_);
cudnnDestroyTensorDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);

Expand Down
8 changes: 4 additions & 4 deletions src/caffe/layers/cudnn_pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
CHECK_EQ(this->pad_h_, 0);
CHECK_EQ(this->pad_w_, 0);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
Expand All @@ -36,8 +36,8 @@ void CuDNNPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,

template <typename Dtype>
CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroyTensorDescriptor(bottom_desc_);
cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
Expand Down
8 changes: 4 additions & 4 deletions src/caffe/layers/cudnn_relu_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
ReLULayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
}

template <typename Dtype>
Expand All @@ -31,8 +31,8 @@ void CuDNNReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,

template <typename Dtype>
CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}

Expand Down
8 changes: 4 additions & 4 deletions src/caffe/layers/cudnn_sigmoid_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
SigmoidLayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
}

template <typename Dtype>
Expand All @@ -31,8 +31,8 @@ void CuDNNSigmoidLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,

template <typename Dtype>
CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}

Expand Down
8 changes: 4 additions & 4 deletions src/caffe/layers/cudnn_softmax_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ void CuDNNSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
SoftmaxLayer<Dtype>::LayerSetUp(bottom, top);
// Initialize CUDNN.
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
}

template <typename Dtype>
Expand All @@ -35,8 +35,8 @@ void CuDNNSoftmaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,

template <typename Dtype>
CuDNNSoftmaxLayer<Dtype>::~CuDNNSoftmaxLayer() {
cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroyTensorDescriptor(bottom_desc_);
cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroy(handle_);
}

Expand Down
8 changes: 4 additions & 4 deletions src/caffe/layers/cudnn_tanh_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
TanHLayer<Dtype>::LayerSetUp(bottom, top);
// initialize cuDNN
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createTensorDesc<Dtype>(&bottom_desc_);
cudnn::createTensorDesc<Dtype>(&top_desc_);
}

template <typename Dtype>
Expand All @@ -31,8 +31,8 @@ void CuDNNTanHLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,

template <typename Dtype>
CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}

Expand Down

0 comments on commit 079ff55

Please sign in to comment.