Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuDNN R2 #2038

Merged
merged 4 commits into from
Mar 24, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/caffe/common_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -386,8 +386,8 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

Expand Down
12 changes: 6 additions & 6 deletions include/caffe/neuron_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

Expand Down Expand Up @@ -516,8 +516,8 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

Expand Down Expand Up @@ -601,8 +601,8 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_;
cudnnTensor4dDescriptor_t top_desc_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
};
#endif

Expand Down
34 changes: 19 additions & 15 deletions include/caffe/util/cudnn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,41 +50,45 @@ template <typename Dtype> class dataType;
template<> class dataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
static float oneval, zeroval;
static const void *one, *zero;
};
template<> class dataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
static double oneval, zeroval;
static const void *one, *zero;
};

template <typename Dtype>
inline void createTensor4dDesc(cudnnTensor4dDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensor4dDescriptor(desc));
inline void createTensor4dDesc(cudnnTensorDescriptor_t* desc) {
CUDNN_CHECK(cudnnCreateTensorDescriptor(desc));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

template <typename Dtype>
inline void setTensor4dDesc(cudnnTensor4dDescriptor_t* desc,
inline void setTensor4dDesc(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w) {
const int stride_w = 1;
const int stride_h = w * stride_w;
const int stride_c = h * stride_h;
const int stride_n = c * stride_c;
setTensor4dDesc<Dtype>(desc, n, c, h, w,
stride_n, stride_c, stride_h, stride_w);
stride_n, stride_c, stride_h, stride_w);
}

template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilterDescriptor(*desc, dataType<Dtype>::type,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
}

Expand All @@ -95,29 +99,29 @@ inline void createConvolutionDesc(cudnnConvolutionDescriptor_t* conv) {

template <typename Dtype>
inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
cudnnTensor4dDescriptor_t bottom, cudnnFilterDescriptor_t filter,
cudnnTensorDescriptor_t bottom, cudnnFilterDescriptor_t filter,
int pad_h, int pad_w, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetConvolutionDescriptor(*conv, bottom, filter,
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
}

template <typename Dtype>
inline void createPoolingDesc(cudnnPoolingDescriptor_t* conv,
inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
PoolingParameter_PoolMethod poolmethod, cudnnPoolingMode_t* mode,
int h, int w, int stride_h, int stride_w) {
int h, int w, int pad_h, int pad_w, int stride_h, int stride_w) {
switch (poolmethod) {
case PoolingParameter_PoolMethod_MAX:
*mode = CUDNN_POOLING_MAX;
break;
case PoolingParameter_PoolMethod_AVE:
*mode = CUDNN_POOLING_AVERAGE;
*mode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
break;
default:
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(conv));
CUDNN_CHECK(cudnnSetPoolingDescriptor(*conv, *mode, h, w,
stride_h, stride_w));
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
}

} // namespace cudnn
Expand Down
8 changes: 5 additions & 3 deletions include/caffe/vision_layers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,11 +246,13 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
bool handles_setup_;
cudnnHandle_t* handle_;
cudaStream_t* stream_;
vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
cudnnTensor4dDescriptor_t bias_desc_;
vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
cudnnTensorDescriptor_t bias_desc_;
cudnnFilterDescriptor_t filter_desc_;
vector<cudnnConvolutionDescriptor_t> conv_descs_;
int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
size_t workspaceSizeInBytes;
void *workspace;
};
#endif

Expand Down Expand Up @@ -445,7 +447,7 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {

bool handles_setup_;
cudnnHandle_t handle_;
cudnnTensor4dDescriptor_t bottom_desc_, top_desc_;
cudnnTensorDescriptor_t bottom_desc_, top_desc_;
cudnnPoolingDescriptor_t pooling_desc_;
cudnnPoolingMode_t mode_;
};
Expand Down
10 changes: 5 additions & 5 deletions src/caffe/layers/cudnn_conv_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(

// Create tensor descriptor(s) for data and corresponding convolution(s).
for (int i = 0; i < bottom.size(); i++) {
cudnnTensor4dDescriptor_t bottom_desc;
cudnnTensorDescriptor_t bottom_desc;
cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
bottom_descs_.push_back(bottom_desc);
cudnnTensor4dDescriptor_t top_desc;
cudnnTensorDescriptor_t top_desc;
cudnn::createTensor4dDesc<Dtype>(&top_desc);
top_descs_.push_back(top_desc);
cudnnConvolutionDescriptor_t conv_desc;
Expand Down Expand Up @@ -104,12 +104,12 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
if (!handles_setup_) { return; }

for (int i = 0; i < bottom_descs_.size(); i++) {
cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
cudnnDestroyTensor4dDescriptor(top_descs_[i]);
cudnnDestroyTensorDescriptor(bottom_descs_[i]);
cudnnDestroyTensorDescriptor(top_descs_[i]);
cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
}
if (this->bias_term_) {
cudnnDestroyTensor4dDescriptor(bias_desc_);
cudnnDestroyTensorDescriptor(bias_desc_);
}
cudnnDestroyFilterDescriptor(filter_desc_);

Expand Down
83 changes: 61 additions & 22 deletions src/caffe/layers/cudnn_conv_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,57 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(

// Forward through cuDNN in parallel over groups.
for (int g = 0; g < this->group_; g++) {
cudnnConvolutionFwdAlgo_t algo;

// pick the convolution algorithm
// TODO(shelhamer) this should be done during reshape
// TODO(shelhamer) the choice of automatic or manual algorithm picking
// should be exposed in proto
CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g],
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
top_descs_[i],
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0, // memoryLimitInBytes,
&algo));

// get minimum size of the workspace needed for the desired algorithm
size_t workspaceSizeInBytes_temp = 0;

CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g],
bottom_descs_[i],
filter_desc_,
conv_descs_[i],
top_descs_[i],
algo,
&workspaceSizeInBytes));

if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
workspaceSizeInBytes = workspaceSizeInBytes_temp;
// free the existing workspace and allocate a new (larger) one
cudaFree(this->workspace);
cudaMalloc(&(this->workspace), workspaceSizeInBytes);
}

// Filters.
CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
conv_descs_[i],
top_descs_[i], top_data + top_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
filter_desc_, weight + weight_offset_ * g,
conv_descs_[i],
algo, workspace, workspaceSizeInBytes,
cudnn::dataType<Dtype>::zero,
top_descs_[i], top_data + top_offset_ * g));

// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
Dtype alpha = 1.;
CUDNN_CHECK(cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
bias_desc_, bias_data + bias_offset_ * g,
top_descs_[i], top_data + top_offset_ * g));
CUDNN_CHECK(cudnnAddTensor(handle_[g], CUDNN_ADD_SAME_C,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_data + bias_offset_ * g,
cudnn::dataType<Dtype>::one,
top_descs_[i], top_data + top_offset_ * g));
}
}

Expand Down Expand Up @@ -68,20 +104,22 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Gradient w.r.t. bias.
if (this->bias_term_ && this->param_propagate_down_[1]) {
CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
top_descs_[i], top_diff + top_offset_ * g,
bias_desc_, bias_diff + bias_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
cudnn::dataType<Dtype>::one,
top_descs_[i], top_diff + top_offset_ * g,
cudnn::dataType<Dtype>::one,
bias_desc_, bias_diff + bias_offset_ * g));
}

// Gradient w.r.t. weights.
if (this->param_propagate_down_[0]) {
const Dtype* bottom_data = bottom[i]->gpu_data();
CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
filter_desc_, weight_diff + weight_offset_ * g,
CUDNN_RESULT_ACCUMULATE));
cudnn::dataType<Dtype>::one,
bottom_descs_[i], bottom_data + bottom_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
cudnn::dataType<Dtype>::one,
filter_desc_, weight_diff + weight_offset_ * g));
}

// Gradient w.r.t. bottom data.
Expand All @@ -91,11 +129,12 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
}
Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
bottom_descs_[i], bottom_diff + bottom_offset_ * g,
CUDNN_RESULT_NO_ACCUMULATE));
cudnn::dataType<Dtype>::one,
filter_desc_, weight + weight_offset_ * g,
top_descs_[i], top_diff + top_offset_ * g,
conv_descs_[i],
cudnn::dataType<Dtype>::zero,
bottom_descs_[i], bottom_diff + bottom_offset_ * g));
}
}

Expand Down
10 changes: 4 additions & 6 deletions src/caffe/layers/cudnn_pooling_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@ template <typename Dtype>
void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
PoolingLayer<Dtype>::LayerSetUp(bottom, top);
// Sanity check: CUDNN currently only supports pad == 0.
CHECK_EQ(this->pad_h_, 0);
CHECK_EQ(this->pad_w_, 0);
CUDNN_CHECK(cudnnCreate(&handle_));
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
this->layer_param_.pooling_param().pool(), &mode_,
this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
this->kernel_h_, this->kernel_w_, this->pad_h_, this->pad_w_,
this->stride_h_, this->stride_w_);
handles_setup_ = true;
}

Expand All @@ -40,8 +38,8 @@ CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

cudnnDestroyTensor4dDescriptor(bottom_desc_);
cudnnDestroyTensor4dDescriptor(top_desc_);
cudnnDestroyTensorDescriptor(bottom_desc_);
cudnnDestroyTensorDescriptor(top_desc_);
cudnnDestroyPoolingDescriptor(pooling_desc_);
cudnnDestroy(handle_);
}
Expand Down
12 changes: 9 additions & 3 deletions src/caffe/layers/cudnn_pooling_layer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ void CuDNNPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
CUDNN_CHECK(cudnnPoolingForward(handle_, pooling_desc_,
bottom_desc_, bottom_data, top_desc_, top_data));
cudnn::dataType<Dtype>::one,
bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
top_desc_, top_data));
}

template <typename Dtype>
Expand All @@ -29,8 +32,11 @@ void CuDNNPoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
CUDNN_CHECK(cudnnPoolingBackward(handle_, pooling_desc_,
top_desc_, top_data, top_desc_, top_diff,
bottom_desc_, bottom_data, bottom_desc_, bottom_diff));
cudnn::dataType<Dtype>::one,
top_desc_, top_data, top_desc_, top_diff,
bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
bottom_desc_, bottom_diff));
}

INSTANTIATE_LAYER_GPU_FUNCS(CuDNNPoolingLayer);
Expand Down
4 changes: 2 additions & 2 deletions src/caffe/layers/cudnn_relu_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }

cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
cudnnDestroyTensor4dDescriptor(this->top_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
cudnnDestroy(this->handle_);
}

Expand Down
Loading