Only allocate cudnn-rnn dropout memory if dropout p > 0 and acquire descriptors during initialization (#11004)

* cudnn-rnn: Only allocate dropout memory if dropout p > 0

Also request cudnn descriptors during class initialization

* Don't call cudnnDropoutGetStatesSize when not allocating states

* Fixes
leezu authored and szha committed May 20, 2018
1 parent bea5fd1 commit 10ac529
Showing 1 changed file with 47 additions and 40 deletions.
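Note: the following is a minimal standalone sketch (not MXNet code) of the allocation pattern this commit adopts: the dropout descriptor is always created, but cudnnDropoutGetStatesSize and the GPU state buffer are only used when the dropout probability is non-zero; with a zero probability the descriptor is set with a NULL states pointer and zero size, which, per my reading of the cuDNN docs, skips the RNG state initialization. Names such as CHECK_CUDNN, dropout_p and states are placeholders introduced for this sketch.

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Simple error check for this sketch (the real file uses MXNet's CUDNN_CALL).
#define CHECK_CUDNN(call)                                                       \
  do {                                                                          \
    cudnnStatus_t status = (call);                                              \
    if (status != CUDNN_STATUS_SUCCESS) {                                       \
      std::fprintf(stderr, "cuDNN error: %s\n", cudnnGetErrorString(status));   \
      std::exit(1);                                                             \
    }                                                                           \
  } while (0)

int main() {
  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  cudnnDropoutDescriptor_t dropout_desc;
  CHECK_CUDNN(cudnnCreateDropoutDescriptor(&dropout_desc));

  const float dropout_p = 0.5f;  // try 0.f: the branch below is then skipped
  void *states = nullptr;
  size_t state_bytes = 0;
  if (dropout_p > 0) {
    // Query the RNG state size and allocate it on the GPU only when dropout
    // is actually used.
    CHECK_CUDNN(cudnnDropoutGetStatesSize(handle, &state_bytes));
    cudaMalloc(&states, state_bytes);
  }
  // With dropout_p == 0, states stays NULL and state_bytes stays 0, so no
  // RNG states are initialized and only the dropout value is recorded.
  CHECK_CUDNN(cudnnSetDropoutDescriptor(dropout_desc, handle, dropout_p,
                                        states, state_bytes,
                                        /* seed = */ 1337ULL));

  // ... the descriptor would now be attached to an RNN descriptor ...

  CHECK_CUDNN(cudnnDestroyDropoutDescriptor(dropout_desc));
  if (states != nullptr) cudaFree(states);
  CHECK_CUDNN(cudnnDestroy(handle));
  return 0;
}
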
87 changes: 47 additions & 40 deletions src/operator/cudnn_rnn-inl.h
@@ -76,37 +76,56 @@ class CuDNNRNNOp : public Operator{
param_.lstm_q_ = true;
else
param_.lstm_q_ = false;

// Create descriptors
CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));

CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));

CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
}

~CuDNNRNNOp() {
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));

CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));

if (init_cudnn_) {
for (size_t i = 0; i < x_desc_vec_.size(); ++i) {
CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_vec_[i]));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_vec_[i]));
}
CUDNN_CALL(cudnnDestroyTensorDescriptor(hx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(hy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(cy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dhy_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dcy_desc_));

CUDNN_CALL(cudnnDestroyFilterDescriptor(w_desc_));
CUDNN_CALL(cudnnDestroyFilterDescriptor(dw_desc_));
CUDNN_CALL(cudnnDestroyRNNDescriptor(rnn_desc_));
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
Storage::Get()->Free(dropout_states_);
Storage::Get()->Free(reserve_space_);
init_cudnn_ = false;

Storage::Get()->Free(reserve_space_);
if (param_.p > 0) {
Storage::Get()->Free(dropout_states_);
}
}
}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
virtual void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
@@ -395,15 +414,6 @@ class CuDNNRNNOp : public Operator{
strideA[1] = dimA[2];
strideA[2] = 1;

CUDNN_CALL(cudnnCreateTensorDescriptor(&hx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&hy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&cy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dhy_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dcy_desc_));

CUDNN_CALL(cudnnSetTensorNdDescriptor(hx_desc_,
dtype_,
3,
@@ -446,20 +456,19 @@ class CuDNNRNNOp : public Operator{
strideA));

// Create Dropout descriptors
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_,
&dropout_byte_));
dropout_size_ = dropout_byte_ / sizeof(DType);
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_,
s->dnn_handle_,
param_.p, // keep probability
dropout_states_.dptr,
dropout_byte_,
if (param_.p > 0) {
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
dropout_size_ = dropout_byte_ / sizeof(DType);
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
} else {
dropout_states_ = {};
dropout_byte_ = 0;
}
CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
param_.p, // discard probability
dropout_states_.dptr, dropout_byte_,
seed_));
// RNN descriptors
CUDNN_CALL(cudnnCreateRNNDescriptor(&rnn_desc_));

#if CUDNN_MAJOR >= 6
cudnnRNNAlgo_t rnn_algo = CUDNN_RNN_ALGO_STANDARD;
CUDNN_CALL(cudnnSetRNNDescriptor_v6(s->dnn_handle_,
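The cudnnSetRNNDescriptor_v6 call above is cut off by the collapsed diff, and its hidden arguments are not reproduced here. As a hypothetical, self-contained illustration (not this file's code) of how a dropout descriptor is consumed when configuring an RNN descriptor with the cuDNN 6/7 API, with hidden_size and num_layers as placeholder parameters:

#include <cudnn.h>

// Hypothetical helper, error checking omitted: attach a dropout descriptor
// while configuring an LSTM descriptor (cudnnSetRNNDescriptor_v6 exists in
// cuDNN 6/7 and was removed in cuDNN 8).
cudnnRNNDescriptor_t MakeLstmDescriptor(cudnnHandle_t handle,
                                        cudnnDropoutDescriptor_t dropout_desc,
                                        int hidden_size, int num_layers) {
  cudnnRNNDescriptor_t rnn_desc;
  cudnnCreateRNNDescriptor(&rnn_desc);
  cudnnSetRNNDescriptor_v6(handle, rnn_desc,
                           hidden_size, num_layers,
                           dropout_desc,             // dropout between layers
                           CUDNN_LINEAR_INPUT,
                           CUDNN_UNIDIRECTIONAL,
                           CUDNN_LSTM,
                           CUDNN_RNN_ALGO_STANDARD,
                           CUDNN_DATA_FLOAT);
  return rnn_desc;
}
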
@@ -514,8 +523,6 @@ class CuDNNRNNOp : public Operator{
CHECK_EQ(w.shape_[0] * sizeof(DType), cudnn_param_size);

// Set param descriptors
CUDNN_CALL(cudnnCreateFilterDescriptor(&w_desc_));
CUDNN_CALL(cudnnCreateFilterDescriptor(&dw_desc_));
int dim_w[3] = {1, 1, 1};
dim_w[0] = w.shape_[0];
CUDNN_CALL(cudnnSetFilterNdDescriptor(w_desc_,
