Commit

Merge branch 'joey/op_exchange' into 'v2.2.1-integration'
Joey/op exchange

See merge request zehuanw/hugectr!214
zehuanw committed Aug 29, 2020
2 parents 585bfbf + 0296f47 commit 8d0cd34
Showing 8 changed files with 340 additions and 1 deletion.
88 changes: 88 additions & 0 deletions HugeCTR/include/layers/dropout_cudnn_layer.hpp
@@ -0,0 +1,88 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cudnn.h>
#include <layer.hpp>

namespace HugeCTR {

/**
* Dropout layer, which sets a randomly selected fraction of its inputs to 0
*/
template <typename T>
class DropoutCudnnLayer : public Layer {
/*
* stores the weight tensors of this layer.
*/
// Tensors<float> weights_;  (inherited from Layer, so not redeclared here)
/*
* stores the weight gradient tensors of this layer.
*/
Tensors2<T> wgrad_;
/*
* stores the references to the input tensors of this layer.
*/
Tensors2<T> in_tensors_;
/*
* stores the references to the output tensors of this layer.
*/
Tensors2<T> out_tensors_;

public:
/**
* Ctor of DropoutCudnnLayer.
* @param in_tensor the input tensor
* @param out_tensor the output tensor, which has the same dimensions as in_tensor
* @param blobs_buff the buffer used to reserve the dropout mask / reserve space
* @param rate fraction of the inputs set to zero, 0 < rate < 1, default = 0.5
* @param cudnn_handle the cuDNN handle used by this layer
* @param device_id the id of the GPU where this layer belongs
*/
DropoutCudnnLayer(const Tensor2<T>& in_tensor, const Tensor2<T>& out_tensor,
const std::shared_ptr<GeneralBuffer2<CudaAllocator>> blobs_buff, float rate,
cudnnHandle_t const& cudnn_handle, int device_id);

~DropoutCudnnLayer() override;

/**
* A method of implementing the forward pass of Dropout
* @param is_train true for training mode (dropout applied), false for evaluation (pass-through copy)
* @param stream CUDA stream where the forward propagation is executed
*/
void fprop(bool is_train, cudaStream_t stream) override;
/**
* A method of implementing the backward pass of Dropout
* @param stream CUDA stream where the backward propagation is executed
*/
void bprop(cudaStream_t stream) override;

const float* mask() const { return mask_.get_ptr(); }

private:
int64_t get_seed() const;  // reads /dev/urandom for a seed, falling back to time(nullptr)
cudnnDropoutDescriptor_t dropout_descriptor_;
float rate_;   // fraction of inputs dropped
float scale_;  // 1 / (1 - rate_)
void* cudnn_status_;  // device buffer holding the cuDNN dropout RNG states
Tensor2<float> mask_;  // cuDNN dropout reserve space shared by fprop and bprop
const cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t in_out_desc_;  // descriptor shared by the input and output tensors
size_t reserveSpaceSizeInBytes_;
int n_sms_;

};

} // namespace HugeCTR
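
For orientation, the call pattern the header above implies looks roughly like the following sketch. It is not code from this merge request: "layer" stands for an already constructed DropoutCudnnLayer<float> whose input/output tensors were reserved on a GeneralBuffer2<CudaAllocator> that has since been allocated.

// Hypothetical call pattern for DropoutCudnnLayer (sketch only, not part of this commit).
cudaStream_t stream;
cudaStreamCreate(&stream);

layer.fprop(true, stream);   // training: cuDNN zeroes roughly "rate" of the inputs and
                             // scales the survivors by 1 / (1 - rate)
layer.bprop(stream);         // backward: the cached reserve space (mask) is reapplied to the gradients
layer.fprop(false, stream);  // evaluation: plain device-to-device copy, no dropout

cudaStreamSynchronize(stream);
cudaStreamDestroy(stream);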
15 changes: 15 additions & 0 deletions HugeCTR/include/utils.hpp
@@ -17,6 +17,7 @@
#pragma once

#include <cuda_runtime_api.h>
#include <cudnn.h>
#include <sys/stat.h>
#include <chrono>
#include <cmath>
@@ -528,5 +529,19 @@ struct TypeConvert<__half> {
};


template <typename T>
struct CudnnDataType;

template <>
struct CudnnDataType<float> {
static cudnnDataType_t getType() { return CUDNN_DATA_FLOAT; }
};

template <>
struct CudnnDataType<__half> {
static cudnnDataType_t getType() { return CUDNN_DATA_HALF; }
};

} // namespace HugeCTR
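
The CudnnDataType trait added above lets a templated layer pick the matching cuDNN enum at compile time. A minimal sketch of how it is typically used when filling a tensor descriptor; the helper name "describe" is hypothetical and not from this diff.

// Hypothetical helper showing the trait in use; mirrors the descriptor setup in
// dropout_cudnn_layer.cu below.
template <typename T>
void describe(cudnnTensorDescriptor_t desc, int batch_size, int num_feature) {
  CK_CUDNN_THROW_(cudnnSetTensor4dDescriptorEx(desc, CudnnDataType<T>::getType(), batch_size, 1, 1,
                                               num_feature, num_feature, 1, 1, 1));
}
// describe<float>(...) resolves to CUDNN_DATA_FLOAT, describe<__half>(...) to CUDNN_DATA_HALF.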
1 change: 1 addition & 0 deletions HugeCTR/src/CMakeLists.txt
@@ -22,6 +22,7 @@ file(GLOB huge_ctr_src
layers/cast_layer.cu
layers/concat_layer.cu
layers/dropout_layer.cu
layers/dropout_cudnn_layer.cu
layers/elu_layer.cu
layers/fully_connected_layer.cu
layers/fully_connected_layer_half.cu
2 changes: 1 addition & 1 deletion HugeCTR/src/layers/batch_norm_layer.cu
@@ -211,4 +211,4 @@ std::unique_ptr<DataSimulator<float>> BatchNormLayer::get_default_initializer(co
return simu;
}

} // namespace HugeCTR
} // namespace HugeCTR
131 changes: 131 additions & 0 deletions HugeCTR/src/layers/dropout_cudnn_layer.cu
@@ -0,0 +1,131 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <HugeCTR/include/utils.hpp>
#include <algorithm>
#include <cstdio>
#include <ctime>
#include <functional>
#include <layers/dropout_cudnn_layer.hpp>
#include <prims/linalg/binary_op.cuh>
#include <utils.hpp>
#include <utils.cuh>

#ifndef NDEBUG
#include <iostream>
#endif

namespace HugeCTR {

template <typename T>
DropoutCudnnLayer<T>::DropoutCudnnLayer(const Tensor2<T>& in_tensor, const Tensor2<T>& out_tensor,
const std::shared_ptr<GeneralBuffer2<CudaAllocator>> blobs_buff, float rate,
cudnnHandle_t const& cudnn_handle, int device_id)
: Layer(device_id),
rate_(rate),
scale_(1.0 / (1.0 - rate)),
cudnn_handle_(cudnn_handle),
n_sms_(0) {
assert(in_tensor.get_num_elements() == out_tensor.get_num_elements());
assert(rate_ > 0.f && rate_ < 1.f);


const auto& in_tensor_dim = in_tensor.get_dimensions();
in_tensors_.emplace_back(in_tensor);
out_tensors_.emplace_back(out_tensor);

CudaDeviceContext context(get_device_id());

size_t num_feature = in_tensor_dim[1];
int batch_size = in_tensor_dim[0];
cudnnDataType_t data_type = CudnnDataType<T>::getType();
int n_stride = num_feature;
int w_stride = 1;
CK_CUDNN_THROW_(cudnnCreateTensorDescriptor(&in_out_desc_));
CK_CUDNN_THROW_(cudnnSetTensor4dDescriptorEx(in_out_desc_, data_type, batch_size, 1, 1,
num_feature, n_stride, 1, 1, w_stride));

CK_CUDNN_THROW_(cudnnCreateDropoutDescriptor(&dropout_descriptor_));

size_t sizeInBytes = 0;

CK_CUDNN_THROW_(cudnnDropoutGetStatesSize(cudnn_handle_, &sizeInBytes));

assert(sizeInBytes != 0);

CK_CUDNN_THROW_(cudnnDropoutGetReserveSpaceSize(in_out_desc_, &reserveSpaceSizeInBytes_));

blobs_buff->reserve({1, reserveSpaceSizeInBytes_}, &mask_);

CK_CUDA_THROW_(cudaMalloc(&cudnn_status_, sizeInBytes));

CK_CUDNN_THROW_(cudnnSetDropoutDescriptor(dropout_descriptor_, cudnn_handle_, rate, cudnn_status_, sizeInBytes, 0));

}

template <typename T>
DropoutCudnnLayer<T>::~DropoutCudnnLayer() {
try {
CK_CUDNN_THROW_(cudnnDestroyDropoutDescriptor(dropout_descriptor_));
cudaFree(cudnn_status_);
CK_CUDNN_THROW_(cudnnDestroyTensorDescriptor(in_out_desc_));
} catch (const std::runtime_error& rt_err) {
std::cerr << rt_err.what() << std::endl;
}

}

template <typename T>
void DropoutCudnnLayer<T>::fprop(bool is_train, cudaStream_t stream) {
CudaDeviceContext context(get_device_id());

if (is_train) {
CK_CUDNN_THROW_(cudnnSetStream(cudnn_handle_, stream));
CK_CUDNN_THROW_(cudnnDropoutForward(cudnn_handle_, dropout_descriptor_, in_out_desc_, in_tensors_[0].get_ptr(), in_out_desc_, out_tensors_[0].get_ptr(), mask_.get_ptr(), reserveSpaceSizeInBytes_));
} else {
CK_CUDA_THROW_(cudaMemcpyAsync(out_tensors_[0].get_ptr(), in_tensors_[0].get_ptr(),
in_tensors_[0].get_size_in_bytes(), cudaMemcpyDeviceToDevice, stream));
}
}

template <typename T>
void DropoutCudnnLayer<T>::bprop(cudaStream_t stream) {
CudaDeviceContext context(get_device_id());
CK_CUDNN_THROW_(cudnnSetStream(cudnn_handle_, stream));
CK_CUDNN_THROW_(cudnnDropoutBackward(
cudnn_handle_, dropout_descriptor_, in_out_desc_, out_tensors_[0].get_ptr(), in_out_desc_, in_tensors_[0].get_ptr(), mask_.get_ptr(), reserveSpaceSizeInBytes_));
}

template <typename T>
int64_t DropoutCudnnLayer<T>::get_seed() const {
FILE* f = fopen("/dev/urandom", "rb");
if (f) {
int64_t seed;
size_t ret = fread(&seed, 1, sizeof(seed), f);
fclose(f);
if (ret == sizeof(seed)) {
return seed;
}
}
return time(nullptr);
}


template class DropoutCudnnLayer<float>;
template class DropoutCudnnLayer<__half>;

} // namespace HugeCTR
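
One detail worth noting: get_seed() above reads /dev/urandom, but the constructor passes a literal 0 as the seed to cudnnSetDropoutDescriptor, so the dropout RNG state is initialized the same way on every run. If non-deterministic masks were wanted, the seed could be plumbed in roughly as follows; this is a hypothetical variant, not what this commit does.

// Hypothetical constructor variant (sketch only): seed the cuDNN dropout states from
// get_seed() instead of the fixed value 0.
unsigned long long seed = static_cast<unsigned long long>(get_seed());
CK_CUDNN_THROW_(cudnnSetDropoutDescriptor(dropout_descriptor_, cudnn_handle_, rate,
                                          cudnn_status_, sizeInBytes, seed));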
13 changes: 13 additions & 0 deletions HugeCTR/src/parser.cpp
@@ -25,6 +25,7 @@
#include <layers/concat_layer.hpp>
#include <layers/dot_product_layer.hpp>
#include <layers/dropout_layer.hpp>
#include <layers/dropout_cudnn_layer.hpp>
#include <layers/elu_layer.hpp>
#include <layers/fm_order2_layer.hpp>
#include <layers/fully_connected_layer.hpp>
@@ -452,9 +453,15 @@ Network* create_network(const nlohmann::json& j_array, const nlohmann::json& j_o
// get dropout params
auto rate_it = j.find("rate");
auto rate = (rate_it != j.end()) ? rate_it->get<float>() : 0.5f;
#ifndef PREFER_CUDNN
layers.emplace_back(new DropoutLayer<__half>(do_in_tensor, do_out_tensor, blobs_buff,
rate, gpu_resource->get_curand_generator(),
device_id));
#else
layers.emplace_back(new DropoutCudnnLayer<__half>(do_in_tensor, do_out_tensor, blobs_buff,
rate, gpu_resource->get_cudnn_handle(),
device_id));
#endif
} else {
// establish out tensor
Tensor2<float> do_in_tensor =
@@ -465,9 +472,15 @@ Network* create_network(const nlohmann::json& j_array, const nlohmann::json& j_o
// get dropout params
auto rate_it = j.find("rate");
auto rate = (rate_it != j.end()) ? rate_it->get<float>() : 0.5f;
#ifndef PREFER_CUDNN
layers.emplace_back(new DropoutLayer<float>(do_in_tensor, do_out_tensor, blobs_buff, rate,
gpu_resource->get_curand_generator(),
device_id));
#else
layers.emplace_back(new DropoutCudnnLayer<float>(do_in_tensor, do_out_tensor, blobs_buff,
rate, gpu_resource->get_cudnn_handle(),
device_id));
#endif
}
network->enable_cuda_graph_ = false;

1 change: 1 addition & 0 deletions test/utest/layers/CMakeLists.txt
@@ -15,6 +15,7 @@

cmake_minimum_required(VERSION 3.8)
file(GLOB layers_test_src
dropout_cudnn_layer_test.cpp
batch_norm_layer_test.cpp
cast_layer_test.cpp
concat_layer_test.cpp
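
The new dropout_cudnn_layer_test.cpp itself is not expanded in this view. Purely as an illustration of what such a test might check, here is a hypothetical gtest-style smoke test: the GeneralBuffer2/Tensor2 wiring follows the patterns visible in this diff, but the create()/allocate() calls, include paths, and all names are assumptions rather than the actual test from this merge request.

// Hypothetical smoke test (not the actual dropout_cudnn_layer_test.cpp from this merge request).
#include <algorithm>
#include <memory>
#include <vector>

#include <cuda_runtime.h>
#include <cudnn.h>
#include <gtest/gtest.h>

#include <general_buffer2.hpp>
#include <layers/dropout_cudnn_layer.hpp>

using namespace HugeCTR;

TEST(dropout_cudnn_layer, fprop_drops_roughly_rate_fraction) {
  constexpr size_t batch = 1024, dim = 512;
  constexpr float rate = 0.5f;

  // Assumed buffer setup modeled on the patterns in this diff.
  std::shared_ptr<GeneralBuffer2<CudaAllocator>> buff = GeneralBuffer2<CudaAllocator>::create();
  Tensor2<float> in, out;
  buff->reserve({batch, dim}, &in);
  buff->reserve({batch, dim}, &out);

  cudnnHandle_t handle;
  cudnnCreate(&handle);
  DropoutCudnnLayer<float> layer(in, out, buff, rate, handle, 0);  // ctor also reserves the mask
  buff->allocate();

  std::vector<float> h_in(batch * dim, 1.f), h_out(batch * dim);
  cudaMemcpy(in.get_ptr(), h_in.data(), in.get_size_in_bytes(), cudaMemcpyHostToDevice);
  layer.fprop(true, 0);  // training mode on the default stream
  cudaMemcpy(h_out.data(), out.get_ptr(), out.get_size_in_bytes(), cudaMemcpyDeviceToHost);

  auto zeros = std::count(h_out.begin(), h_out.end(), 0.f);
  float dropped = static_cast<float>(zeros) / static_cast<float>(batch * dim);
  EXPECT_NEAR(dropped, rate, 0.05f);  // dropped fraction should track the configured rate

  cudnnDestroy(handle);
}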
