
Commit

compatibility to cudnn-v5
Christian Payer committed May 3, 2016
1 parent 46261f7 commit 6028022
Showing 11 changed files with 119 additions and 7 deletions.
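All of the version guards in this commit use Caffe's CUDNN_VERSION_MIN macro from include/caffe/util/cudnn.hpp. For reference, a sketch of that macro as defined in upstream Caffe (cuDNN encodes its version as major*1000 + minor*100 + patch):

    #define CUDNN_VERSION_MIN(major, minor, patch) \
        (CUDNN_VERSION >= (major * 1000 + minor * 100 + patch))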
3 changes: 3 additions & 0 deletions include/caffe/layers/cudnn_relu_layer.hpp
@@ -37,6 +37,9 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
   cudnnHandle_t handle_;
   cudnnTensorDescriptor_t bottom_desc_;
   cudnnTensorDescriptor_t top_desc_;
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnActivationDescriptor_t activation_desc_;
+#endif
 };
 #endif

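The new member exists because cuDNN v5 replaced the plain cudnnActivationMode_t argument of cudnnActivationForward/Backward with an opaque descriptor object; the same three-line guard is added to the sigmoid and tanh headers below. A sketch of the v5 descriptor API the member is built for (prototypes paraphrased from the cuDNN v5 documentation):

    cudnnStatus_t cudnnCreateActivationDescriptor(
        cudnnActivationDescriptor_t* activationDesc);
    cudnnStatus_t cudnnSetActivationDescriptor(
        cudnnActivationDescriptor_t activationDesc,
        cudnnActivationMode_t mode,        // CUDNN_ACTIVATION_RELU / _SIGMOID / _TANH
        cudnnNanPropagation_t reluNanOpt,  // CUDNN_PROPAGATE_NAN or CUDNN_NOT_PROPAGATE_NAN
        double reluCeiling);               // ReLU clipping threshold; 0 = no ceiling
    cudnnStatus_t cudnnDestroyActivationDescriptor(
        cudnnActivationDescriptor_t activationDesc);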
3 changes: 3 additions & 0 deletions include/caffe/layers/cudnn_sigmoid_layer.hpp
@@ -37,6 +37,9 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
   cudnnHandle_t handle_;
   cudnnTensorDescriptor_t bottom_desc_;
   cudnnTensorDescriptor_t top_desc_;
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnActivationDescriptor_t activation_desc_;
+#endif
 };
 #endif

3 changes: 3 additions & 0 deletions include/caffe/layers/cudnn_tanh_layer.hpp
@@ -37,6 +37,9 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
   cudnnHandle_t handle_;
   cudnnTensorDescriptor_t bottom_desc_;
   cudnnTensorDescriptor_t top_desc_;
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnActivationDescriptor_t activation_desc_;
+#endif
 };
 #endif

32 changes: 27 additions & 5 deletions include/caffe/util/cudnn.hpp
@@ -128,15 +128,20 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
     int n, int c, int h, int w) {
   CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
   CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
-    n, c, h, w));
+      n, c, h, w));
 }

 template <typename Dtype>
 inline void createNdFilterDesc(cudnnFilterDescriptor_t* desc,
     std::vector<int> shape) {
   CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
+#if CUDNN_VERSION_MIN(5, 0, 0)
   CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType<Dtype>::type,
-    shape.size(), shape.data()));
+      CUDNN_TENSOR_NCHW, shape.size(), shape.data()));
+#else
+  CUDNN_CHECK(cudnnSetFilterNdDescriptor(*desc, dataType<Dtype>::type,
+      shape.size(), shape.data()));
+#endif
 }

 template <typename Dtype>
@@ -149,7 +154,7 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
     cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
     int pad_h, int pad_w, int stride_h, int stride_w) {
   CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
-    pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
+      pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
 }

 template <typename Dtype>
@@ -159,16 +164,22 @@ inline void setNdConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
   int nbDims;
   std::vector<int> shape(pad.size() + 2);
   cudnnDataType_t cudnn_type;
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnTensorFormat_t tensor_format;
+  cudnnGetFilterNdDescriptor(filter,
+      shape.size(), &cudnn_type, &tensor_format, &nbDims, shape.data());
+#else
   cudnnGetFilterNdDescriptor(filter,
       shape.size(), &cudnn_type, &nbDims, shape.data());
+#endif
   CHECK_EQ(nbDims, pad.size() + 2)
       << "Dimensions of filters and pad don't match !";
   CHECK_EQ(nbDims, stride.size() + 2)
       << "Dimensions of filters and stride don't match !";
   std::vector<int> upscale(pad.size(), 1);
   CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
-    pad.size(), pad.data(), stride.data(), upscale.data(),
-    CUDNN_CROSS_CORRELATION, cudnn_type));
+      pad.size(), pad.data(), stride.data(), upscale.data(),
+      CUDNN_CROSS_CORRELATION, cudnn_type));
 }

 template <typename Dtype>
@@ -186,8 +197,13 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
     LOG(FATAL) << "Unknown pooling method.";
   }
   CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
+      CUDNN_PROPAGATE_NAN, h, w, pad_h, pad_w, stride_h, stride_w));
+#else
   CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
       pad_h, pad_w, stride_h, stride_w));
+#endif
 }

 template <typename Dtype>
@@ -210,8 +226,14 @@ inline void createNdPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
     LOG(FATAL) << "Unknown pooling method.";
   }
   CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode,
+      CUDNN_PROPAGATE_NAN, shape.size(), shape.data(), pad.data(),
+      stride.data()));
+#else
   CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode, shape.size(),
       shape.data(), pad.data(), stride.data()));
+#endif
 }

 }  // namespace cudnn
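Two v5 signature changes drive the guards above: filter descriptors now take an explicit memory-layout argument (cudnnTensorFormat_t, here always CUDNN_TENSOR_NCHW, Caffe's native layout), and pooling descriptors take a NaN-propagation policy (CUDNN_PROPAGATE_NAN lets NaNs inside a pooling window reach the output). A sketch of the v5 prototypes being targeted (paraphrased from the v5 headers):

    cudnnStatus_t cudnnSetFilterNdDescriptor(cudnnFilterDescriptor_t filterDesc,
        cudnnDataType_t dataType,
        cudnnTensorFormat_t format,             // new in v5
        int nbDims, const int filterDimA[]);
    cudnnStatus_t cudnnSetPooling2dDescriptor(cudnnPoolingDescriptor_t poolingDesc,
        cudnnPoolingMode_t mode,
        cudnnNanPropagation_t maxpoolingNanOpt, // new in v5
        int windowHeight, int windowWidth,
        int verticalPadding, int horizontalPadding,
        int verticalStride, int horizontalStride);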
4 changes: 2 additions & 2 deletions src/caffe/layers/cudnn_conv_layer.cu
@@ -82,7 +82,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       // Gradient w.r.t. weights.
       if (this->param_propagate_down_[0]) {
         const Dtype* bottom_data = bottom[i]->gpu_data();
-        CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
+        CUDNN_CHECK(cudnnConvolutionBackwardFilter(
             handle_[1*this->group_ + g],
             cudnn::dataType<Dtype>::one,
             bottom_descs_[i], bottom_data + bottom_offset_ * g,
@@ -100,7 +100,7 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           weight = this->blobs_[0]->gpu_data();
         }
         Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
-        CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
+        CUDNN_CHECK(cudnnConvolutionBackwardData(
             handle_[2*this->group_ + g],
             cudnn::dataType<Dtype>::one,
             filter_desc_, weight + this->weight_offset_ * g,
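cuDNN v3 introduced the algorithm- and workspace-aware backward convolution routines under _v3 suffixes; v5 promoted them to the plain names and dropped the suffixed aliases, so only the function names change here — the argument lists are identical. An alternative that avoids editing every call site would be a small shim next to CUDNN_VERSION_MIN (a sketch of the idea, not what this commit does):

    #if !CUDNN_VERSION_MIN(5, 0, 0)
    // On pre-v5 cuDNN, map the v5 names onto the suffixed entry points.
    #define cudnnConvolutionBackwardFilter cudnnConvolutionBackwardFilter_v3
    #define cudnnConvolutionBackwardData   cudnnConvolutionBackwardData_v3
    #endif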
8 changes: 8 additions & 0 deletions src/caffe/layers/cudnn_relu_layer.cpp
@@ -13,6 +13,11 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensorDesc<Dtype>(&bottom_desc_);
   cudnn::createTensorDesc<Dtype>(&top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnCreateActivationDescriptor(&activation_desc_);
+  cudnnSetActivationDescriptor(activation_desc_,
+      CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0);
+#endif
   handles_setup_ = true;
 }

@@ -31,6 +36,9 @@ CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {

   cudnnDestroyTensorDescriptor(this->bottom_desc_);
   cudnnDestroyTensorDescriptor(this->top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnDestroyActivationDescriptor(this->activation_desc_);
+#endif
   cudnnDestroy(this->handle_);
 }

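The descriptor follows the cuDNN handle's lifecycle: configured once in LayerSetUp (NaN-propagating, ceiling 0) and torn down in the destructor. Note that, unlike the neighboring calls, the new create/set/destroy calls are not wrapped in CUDNN_CHECK. A hypothetical RAII wrapper that would make the create/destroy pairing automatic — a sketch only, not part of this commit:

    class CuDNNActivationDesc {
     public:
      explicit CuDNNActivationDesc(cudnnActivationMode_t mode) {
        CUDNN_CHECK(cudnnCreateActivationDescriptor(&desc_));
        CUDNN_CHECK(cudnnSetActivationDescriptor(
            desc_, mode, CUDNN_PROPAGATE_NAN, 0.0));
      }
      ~CuDNNActivationDesc() { cudnnDestroyActivationDescriptor(desc_); }
      cudnnActivationDescriptor_t get() const { return desc_; }
     private:
      cudnnActivationDescriptor_t desc_;
    };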
19 changes: 19 additions & 0 deletions src/caffe/layers/cudnn_relu_layer.cu
@@ -15,12 +15,21 @@ void CuDNNReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,

   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationForward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->top_desc_, top_data));
+#else
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_RELU,
       cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
+#endif
 }

 template <typename Dtype>
@@ -40,13 +49,23 @@ void CuDNNReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationBackward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->bottom_desc_, bottom_diff));
+#else
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_RELU,
       cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
+#endif
 }

 INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);
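At the call sites only the second argument changes: v4 took the activation mode enum directly, while v5 takes the descriptor configured in LayerSetUp. Side by side, with argument names paraphrased from the respective docs:

    // cuDNN v4 and earlier:
    cudnnStatus_t cudnnActivationForward(cudnnHandle_t handle,
        cudnnActivationMode_t mode,
        const void* alpha, const cudnnTensorDescriptor_t srcDesc, const void* srcData,
        const void* beta, const cudnnTensorDescriptor_t destDesc, void* destData);

    // cuDNN v5:
    cudnnStatus_t cudnnActivationForward(cudnnHandle_t handle,
        cudnnActivationDescriptor_t activationDesc,
        const void* alpha, const cudnnTensorDescriptor_t xDesc, const void* x,
        const void* beta, const cudnnTensorDescriptor_t yDesc, void* y);

The sigmoid and tanh layers below apply the identical split, substituting CUDNN_ACTIVATION_SIGMOID and CUDNN_ACTIVATION_TANH in the pre-v5 branch.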
8 changes: 8 additions & 0 deletions src/caffe/layers/cudnn_sigmoid_layer.cpp
@@ -13,6 +13,11 @@ void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnCreateActivationDescriptor(&activation_desc_);
+  cudnnSetActivationDescriptor(activation_desc_,
+      CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0);
+#endif
   handles_setup_ = true;
 }

@@ -35,6 +40,9 @@ CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {

   cudnnDestroyTensorDescriptor(this->bottom_desc_);
   cudnnDestroyTensorDescriptor(this->top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnDestroyActivationDescriptor(this->activation_desc_);
+#endif
   cudnnDestroy(this->handle_);
 }

19 changes: 19 additions & 0 deletions src/caffe/layers/cudnn_sigmoid_layer.cu
@@ -10,12 +10,21 @@ void CuDNNSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationForward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->top_desc_, top_data));
+#else
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_SIGMOID,
       cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
+#endif
 }

 template <typename Dtype>
@@ -30,13 +39,23 @@ void CuDNNSigmoidLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationBackward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->bottom_desc_, bottom_diff));
+#else
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_SIGMOID,
       cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
+#endif
 }

 INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer);
8 changes: 8 additions & 0 deletions src/caffe/layers/cudnn_tanh_layer.cpp
@@ -13,6 +13,11 @@ void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnCreateActivationDescriptor(&activation_desc_);
+  cudnnSetActivationDescriptor(activation_desc_,
+      CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0);
+#endif
   handles_setup_ = true;
 }

@@ -35,6 +40,9 @@ CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {

   cudnnDestroyTensorDescriptor(this->bottom_desc_);
   cudnnDestroyTensorDescriptor(this->top_desc_);
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  cudnnDestroyActivationDescriptor(this->activation_desc_);
+#endif
   cudnnDestroy(this->handle_);
 }

19 changes: 19 additions & 0 deletions src/caffe/layers/cudnn_tanh_layer.cu
@@ -10,12 +10,21 @@ void CuDNNTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     const vector<Blob<Dtype>*>& top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = top[0]->mutable_gpu_data();
+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationForward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->top_desc_, top_data));
+#else
   CUDNN_CHECK(cudnnActivationForward(this->handle_,
       CUDNN_ACTIVATION_TANH,
       cudnn::dataType<Dtype>::one,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->top_desc_, top_data));
+#endif
 }

 template <typename Dtype>
@@ -31,13 +40,23 @@ void CuDNNTanHLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();

+#if CUDNN_VERSION_MIN(5, 0, 0)
+  CUDNN_CHECK(cudnnActivationBackward(this->handle_,
+      this->activation_desc_,
+      cudnn::dataType<Dtype>::one,
+      this->top_desc_, top_data, this->top_desc_, top_diff,
+      this->bottom_desc_, bottom_data,
+      cudnn::dataType<Dtype>::zero,
+      this->bottom_desc_, bottom_diff));
+#else
   CUDNN_CHECK(cudnnActivationBackward(this->handle_,
       CUDNN_ACTIVATION_TANH,
       cudnn::dataType<Dtype>::one,
       this->top_desc_, top_data, this->top_desc_, top_diff,
       this->bottom_desc_, bottom_data,
       cudnn::dataType<Dtype>::zero,
       this->bottom_desc_, bottom_diff));
+#endif
 }

 INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);
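A quick way to confirm which cuDNN a build actually picked up, since the guards above are all resolved at compile time — a hypothetical standalone check, not part of the commit:

    #include <cudnn.h>
    #include <cstdio>

    int main() {
      // CUDNN_VERSION is baked in from cudnn.h at compile time;
      // cudnnGetVersion() reports the library loaded at run time.
      std::printf("compiled against cuDNN %d, running cuDNN %zu\n",
          CUDNN_VERSION, cudnnGetVersion());
      return 0;
    }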
