Commit
Merge pull request #5924 from bowang/cudnn_deconv
speed-up: add cuDNN deconvolution layer and test
shelhamer committed Jan 29, 2018
2 parents bb4ffa4 + fb31463 commit 7c573ca
Showing 6 changed files with 839 additions and 1 deletion.
68 changes: 68 additions & 0 deletions include/caffe/layers/cudnn_deconv_layer.hpp
@@ -0,0 +1,68 @@
#ifndef CAFFE_CUDNN_DECONV_LAYER_HPP_
#define CAFFE_CUDNN_DECONV_LAYER_HPP_

#include <vector>

#include "caffe/blob.hpp"
#include "caffe/layer.hpp"
#include "caffe/proto/caffe.pb.h"

#include "caffe/layers/deconv_layer.hpp"

namespace caffe {

#ifdef USE_CUDNN
/*
 * @brief cuDNN implementation of DeconvolutionLayer.
 *        Falls back to DeconvolutionLayer for CPU mode.
*
* cuDNN accelerates deconvolution through forward kernels for filtering and
* bias plus backward kernels for the gradient w.r.t. the filters, biases, and
* inputs. Caffe + cuDNN further speeds up the computation through forward
* parallelism across groups and backward parallelism across gradients.
*/
template <typename Dtype>
class CuDNNDeconvolutionLayer : public DeconvolutionLayer<Dtype> {
public:
explicit CuDNNDeconvolutionLayer(const LayerParameter& param)
: DeconvolutionLayer<Dtype>(param), handles_setup_(false) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual ~CuDNNDeconvolutionLayer();

protected:
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down,
const vector<Blob<Dtype>*>& bottom);

bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;

// algorithms for forward and backward convolutions
cudnnConvolutionFwdAlgo_t *fwd_algo_;
cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_;
cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_;

vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, bias_offset_;

size_t *workspace_fwd_sizes_;
size_t *workspace_bwd_data_sizes_;
size_t *workspace_bwd_filter_sizes_;
size_t workspaceSizeInBytes; // size of underlying storage
void *workspaceData; // underlying storage
void **workspace; // aliases into workspaceData
};
#endif

} // namespace caffe

#endif // CAFFE_CUDNN_DECONV_LAYER_HPP_
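For illustration only, here is a hedged sketch of how a forward pass can be organized around these members (this is not the cudnn_deconv_layer.cu added by the commit): deconvolution's forward pass is cuDNN's convolution backward-data kernel, issued once per group on that group's handle, with the bias applied by cudnnAddTensor. Member names follow the header above; the handle indexing, the use of group_, bias_term_, and weight_offset_ inherited from the base convolution layer, and the explicit per-stream synchronization are simplifying assumptions.

#ifdef USE_CUDNN
#include <vector>

#include "caffe/layers/cudnn_deconv_layer.hpp"
#include "caffe/util/cudnn.hpp"

namespace caffe {

// Sketch of a per-group cuDNN deconvolution forward pass. Each group g is
// assumed to run on handle_[g], which Reshape bound to stream_[g].
template <typename Dtype>
void CuDNNDeconvolutionLayer<Dtype>::Forward_gpu(
    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
  const Dtype* weight = this->blobs_[0]->gpu_data();
  for (int i = 0; i < bottom.size(); ++i) {
    const Dtype* bottom_data = bottom[i]->gpu_data();
    Dtype* top_data = top[i]->mutable_gpu_data();
    for (int g = 0; g < this->group_; ++g) {
      // top = weight^T (*) bottom: the backward-data pass of a convolution,
      // using the algorithm and workspace selected during Reshape.
      CUDNN_CHECK(cudnnConvolutionBackwardData(
          handle_[g],
          cudnn::dataType<Dtype>::one,
          filter_desc_,
          weight + this->weight_offset_ * g,
          bottom_descs_[i],
          bottom_data + bottom_offset_ * g,
          conv_descs_[i],
          bwd_data_algo_[i],
          workspace[g],
          workspace_bwd_data_sizes_[i],
          cudnn::dataType<Dtype>::zero,
          top_descs_[i],
          top_data + top_offset_ * g));
      if (this->bias_term_) {
        const Dtype* bias_data = this->blobs_[1]->gpu_data();
        CUDNN_CHECK(cudnnAddTensor(handle_[g],
                                   cudnn::dataType<Dtype>::one,
                                   bias_desc_,
                                   bias_data + bias_offset_ * g,
                                   cudnn::dataType<Dtype>::one,
                                   top_descs_[i],
                                   top_data + top_offset_ * g));
      }
    }
    // Make the per-group streams' work visible to the rest of Caffe,
    // which issues its kernels on the default stream.
    for (int g = 0; g < this->group_; ++g) {
      CUDA_CHECK(cudaStreamSynchronize(stream_[g]));
    }
  }
}

}  // namespace caffe
#endif  // USE_CUDNN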
41 changes: 41 additions & 0 deletions src/caffe/layer_factory.cpp
@@ -8,6 +8,7 @@
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/layers/conv_layer.hpp"
#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
#include "caffe/layers/pooling_layer.hpp"
#include "caffe/layers/relu_layer.hpp"
@@ -18,6 +19,7 @@

#ifdef USE_CUDNN
#include "caffe/layers/cudnn_conv_layer.hpp"
#include "caffe/layers/cudnn_deconv_layer.hpp"
#include "caffe/layers/cudnn_lcn_layer.hpp"
#include "caffe/layers/cudnn_lrn_layer.hpp"
#include "caffe/layers/cudnn_pooling_layer.hpp"
@@ -73,6 +75,45 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(

REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);

// Get deconvolution layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetDeconvolutionLayer(const LayerParameter& param) {
ConvolutionParameter conv_param = param.convolution_param();
ConvolutionParameter_Engine engine = conv_param.engine();
#ifdef USE_CUDNN
bool use_dilation = false;
for (int i = 0; i < conv_param.dilation_size(); ++i) {
if (conv_param.dilation(i) > 1) {
use_dilation = true;
}
}
#endif
if (engine == ConvolutionParameter_Engine_DEFAULT) {
engine = ConvolutionParameter_Engine_CAFFE;
#ifdef USE_CUDNN
if (!use_dilation) {
engine = ConvolutionParameter_Engine_CUDNN;
}
#endif
}
if (engine == ConvolutionParameter_Engine_CAFFE) {
return shared_ptr<Layer<Dtype> >(new DeconvolutionLayer<Dtype>(param));
#ifdef USE_CUDNN
} else if (engine == ConvolutionParameter_Engine_CUDNN) {
if (use_dilation) {
LOG(FATAL) << "CuDNN doesn't support dilated deconvolution at layer "
<< param.name();
}
return shared_ptr<Layer<Dtype> >(new CuDNNDeconvolutionLayer<Dtype>(param));
#endif
} else {
LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
throw; // Avoids missing return warning
}
}

REGISTER_LAYER_CREATOR(Deconvolution, GetDeconvolutionLayer);

// Get pooling layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
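A hedged usage sketch follows (the layer name and shape parameters are made up; only the selection logic mirrors GetDeconvolutionLayer above). With Caffe built with USE_CUDNN, leaving the engine at DEFAULT and requesting no dilation greater than one resolves to CuDNNDeconvolutionLayer; engine: CAFFE or any dilation greater than one keeps the plain DeconvolutionLayer.

#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/proto/caffe.pb.h"

namespace caffe {

// Builds an illustrative Deconvolution LayerParameter and lets the registered
// creator pick the engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > MakeExampleDeconvolution() {
  LayerParameter param;
  param.set_name("upsample1");  // hypothetical layer name
  param.set_type("Deconvolution");
  ConvolutionParameter* conv = param.mutable_convolution_param();
  conv->set_num_output(16);
  conv->add_kernel_size(4);
  conv->add_stride(2);
  // Engine left at DEFAULT: with USE_CUDNN and no dilation > 1 this creates a
  // CuDNNDeconvolutionLayer, otherwise a DeconvolutionLayer.
  // conv->add_dilation(2);  // uncommenting keeps the CAFFE engine
  return LayerRegistry<Dtype>::CreateLayer(param);
}

}  // namespace caffe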
