Merge pull request BVLC#2038 from shelhamer/cudnn-r2
cuDNN R2

* shelhamer/cudnn-r2:
  cuDNN pooling can pad now
  replace cuDNN alphas and betas with coefficient values
  switch to cuDNN R2
myfavouritekk committed Mar 16, 2015
2 parents a6d8640 + 4beebcc commit c4837dd
Showing 18 changed files with 190 additions and 97 deletions.
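The unifying change across these files is cuDNN R2's move from accumulate flags to scaling coefficients: each routine now computes roughly y = alpha * op(x) + beta * y, with alpha and beta passed as pointers to host scalars of the tensor's data type. The dataType<Dtype> traits in include/caffe/util/cudnn.hpp (further down) gain one/zero pointers for exactly this purpose. The sketch below shows how those members could be defined; the real definitions live in a source file not shown in this view, so treat the exact form as an assumption.

namespace caffe {
namespace cudnn {

// Host-side scaling constants for cuDNN R2. Passing *zero as beta overwrites
// the output (old CUDNN_RESULT_NO_ACCUMULATE); passing *one accumulates into
// it (old CUDNN_RESULT_ACCUMULATE).
float dataType<float>::oneval = 1.0f;
float dataType<float>::zeroval = 0.0f;
const void* dataType<float>::one = static_cast<void*>(&dataType<float>::oneval);
const void* dataType<float>::zero = static_cast<void*>(&dataType<float>::zeroval);

double dataType<double>::oneval = 1.0;
double dataType<double>::zeroval = 0.0;
const void* dataType<double>::one = static_cast<void*>(&dataType<double>::oneval);
const void* dataType<double>::zero = static_cast<void*>(&dataType<double>::zeroval);

}  // namespace cudnn
}  // namespace caffe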
4 changes: 2 additions & 2 deletions include/caffe/common_layers.hpp
@@ -434,8 +434,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif
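
The only change in this header is the descriptor type: R2 drops the 4d-specific cudnnTensor4dDescriptor_t in favour of a generic cudnnTensorDescriptor_t. A minimal, self-contained sketch of the new lifecycle (function and variable names here are illustrative, not Caffe's):

#include <cudnn.h>

// Create, size, and destroy an NCHW float tensor descriptor the R2 way; the
// R1 equivalents were cudnnCreateTensor4dDescriptor and
// cudnnDestroyTensor4dDescriptor.
void describe_blob(int n, int c, int h, int w) {
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  // Fully packed NCHW strides, matching cudnn::setTensor4dDesc later in this diff.
  const int stride_w = 1;
  const int stride_h = w * stride_w;
  const int stride_c = h * stride_h;
  const int stride_n = c * stride_c;
  cudnnSetTensor4dDescriptorEx(desc, CUDNN_DATA_FLOAT,
      n, c, h, w, stride_n, stride_c, stride_h, stride_w);
  cudnnDestroyTensorDescriptor(desc);
}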

12 changes: 6 additions & 6 deletions include/caffe/neuron_layers.hpp
@@ -433,8 +433,8 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

@@ -516,8 +516,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

@@ -601,8 +601,8 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

34 changes: 19 additions & 15 deletions include/caffe/util/cudnn.hpp
@@ -50,41 +50,45 @@ template <typename Dtype> class dataType;
template<> class dataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
static float oneval, zeroval;
static const void *one, *zero;
};
template<> class dataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
static double oneval, zeroval;
static const void *one, *zero;
};

template <typename Dtype>
inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
setTensor4dDesc<Dtype>(desc, n, c, h, w,
stride_n, stride_c, stride_h, stride_w);
stride_n, stride_c, stride_h, stride_w);
}

template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType<Dtype>::type,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
}

@@ -95,29 +99,29 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {

template <typename Dtype>
inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}

template <typename Dtype>
inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
int h, int w, int stride_h, int stride_w) {
int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
switch (poolmethod) {
case PoolingParameter_PoolMethod_MAX:
*mode = CUDNN_POOLING_MAX;
break;
case PoolingParameter_PoolMethod_AVE:
*mode = CUDNN_POOLING_AVERAGE;
*mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv));
CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w,
stride_h, stride_w));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
}

} // namespace cudnn
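
For orientation, a hedged sketch of how a layer would drive the updated wrappers above; the shapes, pads, and strides are made-up values, and error checking is elided:

#include "caffe/util/cudnn.hpp"

using namespace caffe;

// Describe a 1x3x224x224 input, a 64x3x7x7 filter bank, and a convolution
// with pad 3 and stride 2, using the R2-facing helpers defined above.
void setup_conv_descriptors() {
  cudnnTensorDescriptor_t bottom_desc;
  cudnnFilterDescriptor_t filter_desc;
  cudnnConvolutionDescriptor_t conv_desc;

  cudnn::createTensor4dDesc<float>(&bottom_desc);
  cudnn::setTensor4dDesc<float>(&bottom_desc, 1, 3, 224, 224);
  cudnn::createFilterDesc<float>(&filter_desc, 64, 3, 7, 7);
  cudnn::createConvolutionDesc<float>(&conv_desc);
  // The underlying R2 call no longer takes the bottom/filter descriptors, but
  // the helper keeps them in its signature for source compatibility.
  cudnn::setConvolutionDesc<float>(&conv_desc, bottom_desc, filter_desc,
      3, 3, 2, 2);  // pad_h, pad_w, stride_h, stride_w
}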
8 changes: 5 additions & 3 deletions include/caffe/vision_layers.hpp
@@ -291,11 +291,13 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;
vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
cudnnTensor4dDescriptor_t bias_desc_;
vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
size_t workspaceSizeInBytes;
void *workspace;
};
#endif

@@ -534,7 +536,7 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_, top_desc_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
};
13 changes: 8 additions & 5 deletions src/caffe/layers/cudnn_conv_layer.cpp
@@ -25,6 +25,9 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];

workspace = NULL;
workspaceSizeInBytes = (size_t)0;

for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
CUDA_CHECK(cudaStreamCreate(&stream_[g]));
CUDNN_CHECK(cudnnCreate(&handle_[g]));
@@ -43,10 +46,10 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(

// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
cudnnTensor4dDescriptor_t bottom_desc;
cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
cudnnTensor4dDescriptor_t top_desc;
cudnnTensorDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
@@ -104,12 +107,12 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
if (!handles_setup_) { return; }

for (int i = 0; i < bottom_descs_.size(); i++) {
cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
cudnnDestroyTensor4dDescriptor(top_descs_[i]);
cudnnDestroyTensorDescriptor(bottom_descs_[i]);
cudnnDestroyTensorDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
cudnnDestroyTensor4dDescriptor(bias_desc_);
cudnnDestroyTensorDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);

83 changes: 61 additions & 22 deletions src/caffe/layers/cudnn_conv_layer.cu
@@ -21,21 +21,57 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(

// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
cudnnConvolutionFwdAlgo_t algo;

// get the desired convolution algorithm
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g],
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
top_descs_[i],
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0, // memoryLimitInBytes,
&algo));

// get minimum size of the workspace needed for the desired algorithm
size_t workspaceSizeInBytes_temp = 0;

CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g],
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
top_descs_[i],
algo,
&workspaceSizeInBytes_temp));

if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
workspaceSizeInBytes = workspaceSizeInBytes_temp;
// free the existing workspace and allocate a new (larger) one
if (this->workspace != NULL) {
cudaFree(this->workspace);
}
cudaMalloc(&(this->workspace), workspaceSizeInBytes);
CUDA_POST_KERNEL_CHECK;
}

// Filters.
CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
conv_descs_[i],
top_descs_[i], top_data + top_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
conv_descs_[i],
algo, workspace, workspaceSizeInBytes,
cudnn::dataType<Dtype>::zero,
top_descs_[i], top_data + top_offset_ * g));

// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
Dtype alpha = 1.;
CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
bias_desc_, bias_data + bias_offset_ * g,
top_descs_[i], top_data + top_offset_ * g));
CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_data + bias_offset_ * g,
cudnn::dataType<Dtype>::one,
top_descs_[i], top_data + top_offset_ * g));
}
}
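
The R2 forward path has to choose an algorithm and size a scratch workspace before calling cudnnConvolutionForward; the block above does this inline for every group on every forward pass, only ever growing the buffer. A hedged, stand-alone restatement of that query-and-grow pattern (function and variable names are illustrative, not Caffe's):

#include <cudnn.h>
#include <cuda_runtime.h>

// Pick the fastest forward algorithm for these descriptors and grow the
// caller's workspace if that algorithm needs more scratch space than is
// currently allocated.
cudnnConvolutionFwdAlgo_t pick_algo_and_grow_workspace(
    cudnnHandle_t handle,
    cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
    cudnnConvolutionDescriptor_t conv, cudnnTensorDescriptor_t top,
    void** workspace, size_t* workspace_bytes) {
  cudnnConvolutionFwdAlgo_t algo;
  cudnnGetConvolutionForwardAlgorithm(handle, bottom, filter, conv, top,
      CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, 0 /* memoryLimitInBytes */, &algo);

  size_t needed = 0;
  cudnnGetConvolutionForwardWorkspaceSize(handle, bottom, filter, conv, top,
      algo, &needed);

  if (needed > *workspace_bytes) {
    if (*workspace != NULL) {
      cudaFree(*workspace);
    }
    cudaMalloc(workspace, needed);
    *workspace_bytes = needed;
  }
  return algo;
}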

@@ -68,20 +104,22 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
top_descs_[i], top_diff + top_offset_ * g,
bias_desc_, bias_diff + bias_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
cudnn::dataType<Dtype>::one,
top_descs_[i], top_diff + top_offset_ * g,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_diff + bias_offset_ * g));
}

// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
filter_desc_, weight_diff + weight_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
cudnn::dataType<Dtype>::one,
filter_desc_, weight_diff + weight_offset_ * g));
}

// Gradient w.r.t. bottom data.
@@ -91,11 +129,12 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
bottom_descs_[i], bottom_diff + bottom_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
cudnn::dataType<Dtype>::one,
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
cudnn::dataType<Dtype>::zero,
bottom_descs_[i], bottom_diff + bottom_offset_ * g));
}
}
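
The beta coefficients in the backward calls above mirror the old accumulate flags: the bias and weight gradients pass cudnn::dataType<Dtype>::one as beta, reproducing the old CUDNN_RESULT_ACCUMULATE behaviour of summing into the existing diff, while the bottom diff passes zero (the old CUDNN_RESULT_NO_ACCUMULATE) and is overwritten. In sketch form:

// Scaling semantics of the three backward calls above (sketch, not source):
//   bias_diff   = 1 * dBias(top_diff)           + 1 * bias_diff    (accumulate)
//   weight_diff = 1 * dW(bottom_data, top_diff) + 1 * weight_diff  (accumulate)
//   bottom_diff = 1 * dX(weight, top_diff)      + 0 * bottom_diff  (overwrite)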

10 changes: 4 additions & 6 deletions src/caffe/layers/cudnn_pooling_layer.cpp
@@ -13,15 +13,13 @@ template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingLayer<Dtype>::LayerSetUp(bottom, top);
// Sanity check: CUDNN currently only supports pad == 0.
CHECK_EQ(this->pad_h_, 0);
CHECK_EQ(this->pad_w_, 0);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_,
this->stride_h_, this->stride_w_);
handles_setup_ = true;
}

@@ -40,8 +38,8 @@ CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroyTensorDescriptor(bottom_desc_);
cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
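
With R2 the pooling descriptor carries its own padding, so the pad_h_/pad_w_ == 0 sanity checks above are gone and the pads are forwarded to cudnnSetPooling2dDescriptor. A hedged, self-contained sketch; the window, pad, and stride values and the parameter-name comments are illustrative:

#include <cudnn.h>

// A 3x3 max pooling window with stride 2 and one pixel of padding --
// a configuration the R1 binding had to reject.
void make_padded_pooling_desc(cudnnPoolingDescriptor_t* desc) {
  cudnnCreatePoolingDescriptor(desc);
  cudnnSetPooling2dDescriptor(*desc, CUDNN_POOLING_MAX,
      3, 3,   // window h, w
      1, 1,   // pad h, w
      2, 2);  // stride h, w
}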
12 changes: 9 additions & 3 deletions src/caffe/layers/cudnn_pooling_layer.cu
@@ -15,7 +15,10 @@ void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
bottom_desc_, bottom_data, top_desc_, top_data));
cudnn::dataType<Dtype>::one,
bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
top_desc_, top_data));
}

template <typename Dtype>
@@ -29,8 +32,11 @@ void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
top_desc_, top_data, top_desc_, top_diff,
bottom_desc_, bottom_data, bottom_desc_, bottom_diff));
cudnn::dataType<Dtype>::one,
top_desc_, top_data, top_desc_, top_diff,
bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
bottom_desc_, bottom_diff));
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer);
4 changes: 2 additions & 2 deletions src/caffe/layers/cudnn_relu_layer.cpp
@@ -35,8 +35,8 @@ CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}

