diff --git a/README.md b/README.md
index 7b930ccd..72ff00a4 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
 caffeine
 ========
 
-caffeine.
+caffeine: Convolutional Algorithms For Feature Extraction.
+
+Copyright Yangqing Jia
diff --git a/src/Makefile b/src/Makefile
index 9ab43e5b..111ca164 100644
--- a/src/Makefile
+++ b/src/Makefile
@@ -31,7 +31,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64
 
 INCLUDE_DIRS := . /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
 LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
-LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
+LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand
 WARNINGS := -Wall
 
 CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
diff --git a/src/caffeine/common.cpp b/src/caffeine/common.cpp
index 7ac0eada..d5a7b224 100644
--- a/src/caffeine/common.cpp
+++ b/src/caffeine/common.cpp
@@ -7,11 +7,16 @@ shared_ptr<Caffeine> Caffeine::singleton_;
 
 Caffeine::Caffeine()
     : mode_(Caffeine::CPU), phase_(Caffeine::TRAIN) {
   CUBLAS_CHECK(cublasCreate(&cublas_handle_));
+  CURAND_CHECK(curandCreateGenerator(&curand_generator_,
+      CURAND_RNG_PSEUDO_XORWOW));
   VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
 }
 
 Caffeine::~Caffeine() {
   if (!cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
+  // Only destroy the generator if it was actually created.
+  if (curand_generator_) {
+    CURAND_CHECK(curandDestroyGenerator(curand_generator_));
+  }
   if (!vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
 };
 
@@ -30,6 +35,10 @@ cublasHandle_t Caffeine::cublas_handle() {
   return Get().cublas_handle_;
 };
 
+curandGenerator_t Caffeine::curand_generator() {
+  return Get().curand_generator_;
+};
+
 Caffeine::Brew Caffeine::mode() {
   return Get().mode_;
 }
diff --git a/src/caffeine/common.hpp b/src/caffeine/common.hpp
index 060d1f7d..080cb9a6 100644
--- a/src/caffeine/common.hpp
+++ b/src/caffeine/common.hpp
@@ -3,6 +3,8 @@
 
 #include <boost/shared_ptr.hpp>
 #include <cublas_v2.h>
+#include <cuda.h>
+#include <curand.h>
 #include <glog/logging.h>
 #include <mkl_vsl.h>
 
@@ -10,6 +12,7 @@
 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
 
 namespace caffeine {
@@ -21,6 +24,10 @@ using boost::shared_ptr;
 // For backward compatibility we will just use 512 threads per block
 const int CAFFEINE_CUDA_NUM_THREADS = 512;
 
+inline int CAFFEINE_GET_BLOCKS(const int N) {
+  return (N + CAFFEINE_CUDA_NUM_THREADS - 1) / CAFFEINE_CUDA_NUM_THREADS;
+}
+
 // A singleton class to hold common caffeine stuff, such as the handler that
 // caffeine is going to use for cublas.
 class Caffeine {
@@ -32,6 +39,7 @@ class Caffeine {
 
   // The getters for the variables.
   static cublasHandle_t cublas_handle();
+  static curandGenerator_t curand_generator();
   static VSLStreamStatePtr vsl_stream();
   static Brew mode();
   static Phase phase();
@@ -42,6 +50,7 @@ class Caffeine {
   Caffeine();
   static shared_ptr<Caffeine> singleton_;
   cublasHandle_t cublas_handle_;
+  curandGenerator_t curand_generator_;
   VSLStreamStatePtr vsl_stream_;
   Brew mode_;
   Phase phase_;
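A note on the new CAFFEINE_GET_BLOCKS helper: it computes ceil(count / CAFFEINE_CUDA_NUM_THREADS), so every element gets exactly one thread and the final, possibly partial block is handled by the `if (index < n)` guard inside each kernel. A minimal sketch of the launch pattern it standardizes (the kernel and wrapper names here are illustrative, not part of the patch):

    // illustrative_kernel.cu -- assumes caffeine/common.hpp for the helpers;
    // compile with nvcc, which pulls in the CUDA runtime automatically.
    #include "caffeine/common.hpp"

    using namespace caffeine;

    __global__ void ScaleByTwo(const int n, const float* in, float* out) {
      int index = threadIdx.x + blockIdx.x * blockDim.x;
      if (index < n) {  // the last block may overshoot n; guard it
        out[index] = in[index] * 2.0f;
      }
    }

    void scale_by_two_gpu(const float* in, float* out, const int count) {
      ScaleByTwo<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
          count, in, out);
      CUDA_CHECK(cudaDeviceSynchronize());
    }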
diff --git a/src/caffeine/dropout_layer.cu b/src/caffeine/dropout_layer.cu
index 23999fb0..29398dd6 100644
--- a/src/caffeine/dropout_layer.cu
+++ b/src/caffeine/dropout_layer.cu
@@ -1,6 +1,11 @@
+#include <algorithm>
+#include <climits>
+
+#include "caffeine/common.hpp"
 #include "caffeine/layer.hpp"
+#include "caffeine/syncedmem.hpp"
 #include "caffeine/vision_layers.hpp"
-#include <algorithm>
+
 using std::max;
 
@@ -11,24 +16,29 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   NeuronLayer<Dtype>::SetUp(bottom, top);
   // Set up the cache for random number generation
-  rand_mat_.reset(new Blob<Dtype>(bottom.num(), bottom.channels(),
-      bottom.height(), bottom.width());
-  filler_.reset(new UniformFiller<Dtype>(FillerParameter()));
+  rand_vec_.reset(new SyncedMemory(bottom[0]->count() * sizeof(int)));
 };
 
 template <typename Dtype>
 void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
-  // First, create the random matrix
-  filler_->Fill(rand_mat_.get());
   const Dtype* bottom_data = bottom[0]->cpu_data();
-  const Dtype* rand_vals = rand_mat_->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
-  float threshold = layer_param_->dropout_ratio();
-  float scale = layer_param_->dropo
+  float threshold = this->layer_param_.dropout_ratio();
+  DCHECK(threshold > 0.);
+  DCHECK(threshold < 1.);
+  // Scale the kept activations by 1 / (keep probability) so the expected
+  // output matches the input; at test time we can then simply copy.
+  float scale = 1. / (1. - threshold);
   const int count = bottom[0]->count();
-  for (int i = 0; i < count; ++i) {
-    top_data[i] = rand_mat_ > ;
+  if (Caffeine::phase() == Caffeine::TRAIN) {
+    // Create random numbers: mask[i] is 1 with probability 1 - threshold.
+    int* mask = (int*)(rand_vec_->mutable_cpu_data());
+    viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
+        count, mask, 1. - threshold);
+    for (int i = 0; i < count; ++i) {
+      top_data[i] = bottom_data[i] * mask[i] * scale;
+    }
+  } else {
+    memcpy(top_data, bottom_data, bottom[0]->count() * sizeof(Dtype));
+  }
 }
 
@@ -36,23 +46,25 @@ template <typename Dtype>
 Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
+  CHECK(Caffeine::phase() == Caffeine::TRAIN);
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->cpu_data();
     const Dtype* top_diff = top[0]->cpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+    const int* mask = (const int*)(rand_vec_->cpu_data());
+    float scale = 1. / (1. - this->layer_param_.dropout_ratio());
     const int count = (*bottom)[0]->count();
     for (int i = 0; i < count; ++i) {
-      bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
+      // The gradient passes through the same mask and scale as the forward.
+      bottom_diff[i] = top_diff[i] * mask[i] * scale;
     }
   }
   return Dtype(0);
 }
 
 template <typename Dtype>
-__global__ void DropoutForward(const int n, const Dtype* in, Dtype* out) {
+__global__ void DropoutForward(const int n, const Dtype* in,
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out[index] = max(in[index], Dtype(0.));
+    // Keep the activation iff its random draw clears the threshold.
+    out[index] = in[index] * (mask[index] > threshold) * scale;
   }
 }
 
@@ -61,19 +73,32 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
+  float threshold = this->layer_param_.dropout_ratio();
+  DCHECK(threshold > 0.);
+  DCHECK(threshold < 1.);
+  float scale = 1. / (1. - threshold);
   const int count = bottom[0]->count();
-  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-      CAFFEINE_CUDA_NUM_THREADS;
-  DropoutForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
-      top_data);
+  if (Caffeine::phase() == Caffeine::TRAIN) {
+    // Create random numbers
+    CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
+        (unsigned int*)(rand_vec_->mutable_gpu_data()), count));
+    const unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
+    // set thresholds
+    DropoutForward<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, bottom_data, (const unsigned int*)(rand_vec_->gpu_data()),
+        uint_thres, scale, top_data);
+  } else {
+    CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
+        count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
+  }
 }
 
 template <typename Dtype>
 __global__ void DropoutBackward(const int n, const Dtype* in_diff,
-    const Dtype* in_data, Dtype* out_diff) {
+    const unsigned int* mask, const unsigned int threshold, const float scale,
+    Dtype* out_diff) {
   int index = threadIdx.x + blockIdx.x * blockDim.x;
   if (index < n) {
-    out_diff[index] = in_diff[index] * (in_data[index] >= 0);
+    out_diff[index] = in_diff[index] * (mask[index] > threshold) * scale;
  }
 }
 
@@ -81,15 +106,15 @@ template <typename Dtype>
 Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const bool propagate_down,
     vector<Blob<Dtype>*>* bottom) {
+  CHECK(Caffeine::phase() == Caffeine::TRAIN);
   if (propagate_down) {
-    const Dtype* bottom_data = (*bottom)[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+    const unsigned int* mask = (const unsigned int*)(rand_vec_->gpu_data());
+    float threshold = this->layer_param_.dropout_ratio();
+    const unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
     const int count = (*bottom)[0]->count();
-    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-        CAFFEINE_CUDA_NUM_THREADS;
-    DropoutBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
-        bottom_data, bottom_diff);
+    DropoutBackward<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, mask, uint_thres, 1. / (1. - threshold), bottom_diff);
   }
   return Dtype(0);
 }
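Why the TRAIN branch multiplies surviving activations by scale = 1 / (1 - threshold): each unit is kept with probability 1 - threshold, so the scaling keeps the expected output equal to the input, which is exactly what the TEST branch's plain copy produces. A worked instance, assuming dropout_ratio = 0.5:

    // keep probability q = 1 - threshold = 0.5, so scale = 1 / q = 2
    // E[top_data[i]] = q * (bottom_data[i] * scale) + (1 - q) * 0
    //               = 0.5 * 2 * bottom_data[i]
    //               = bottom_data[i]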
diff --git a/src/caffeine/filler.hpp b/src/caffeine/filler.hpp
index 04ba6499..880e6152 100644
--- a/src/caffeine/filler.hpp
+++ b/src/caffeine/filler.hpp
@@ -1,3 +1,7 @@
+// Fillers are random number generators that fill a blob using the specified
+// algorithm. The expectation is that they are only going to be used during
+// initialization time and will not involve any GPUs.
+
 #ifndef CAFFEINE_FILLER_HPP
 #define CAFFEINE_FILLER_HPP
 
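For context, a filler is constructed from a FillerParameter and then fills a blob in place on the CPU. A minimal sketch, assuming the Blob constructor and UniformFiller API seen elsewhere in this patch (the header paths and blob shape are assumptions):

    #include "caffeine/blob.hpp"    // assumed header locations
    #include "caffeine/filler.hpp"

    using namespace caffeine;

    void init_weights() {
      Blob<float> weights(1, 1, 4, 4);            // num, channels, height, width
      UniformFiller<float> filler(FillerParameter());
      filler.Fill(&weights);                      // CPU-only, initialization time
    }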
diff --git a/src/caffeine/relu_layer.cu b/src/caffeine/relu_layer.cu
index 158131a0..fb95b043 100644
--- a/src/caffeine/relu_layer.cu
+++ b/src/caffeine/relu_layer.cu
@@ -47,10 +47,8 @@ void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   const int count = bottom[0]->count();
-  const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-      CAFFEINE_CUDA_NUM_THREADS;
-  ReLUForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
-      top_data);
+  ReLUForward<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+      count, bottom_data, top_data);
 }
 
 template <typename Dtype>
@@ -71,10 +69,8 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const Dtype* bottom_data = (*bottom)[0]->gpu_data();
     const Dtype* top_diff = top[0]->gpu_diff();
     Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
     const int count = (*bottom)[0]->count();
-    const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
-        CAFFEINE_CUDA_NUM_THREADS;
-    ReLUBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
-        bottom_data, bottom_diff);
+    ReLUBackward<<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
+        count, top_diff, bottom_data, bottom_diff);
   }
   return Dtype(0);
 }
diff --git a/src/caffeine/vision_layers.hpp b/src/caffeine/vision_layers.hpp
index 08561bce..f1cea344 100644
--- a/src/caffeine/vision_layers.hpp
+++ b/src/caffeine/vision_layers.hpp
@@ -48,9 +48,7 @@ class DropoutLayer : public NeuronLayer<Dtype> {
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down, vector<Blob<Dtype>*>* bottom);
- private:
-  shared_ptr<Blob<Dtype> > rand_mat_;
-  shared_ptr<UniformFiller<Dtype> > filler_;
+  shared_ptr<SyncedMemory> rand_vec_;
 };
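One detail worth spelling out: curandGenerate produces uniform 32-bit integers rather than floats, so the dropout mask comes from comparing each draw against UINT_MAX * threshold; for uniform u, P(u > UINT_MAX * threshold) is approximately 1 - threshold, matching the CPU viRngBernoulli path. A standalone host-side sanity check of that identity (not part of the patch):

    #include <climits>
    #include <cstdio>
    #include <random>

    // Estimate the keep rate of the uint-threshold trick used by DropoutForward.
    int main() {
      const float threshold = 0.3f;  // dropout_ratio
      const unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
      std::mt19937 rng(1701);  // same seed the VSL stream uses
      std::uniform_int_distribution<unsigned int> uniform(0u, UINT_MAX);
      const int trials = 1000000;
      int kept = 0;
      for (int i = 0; i < trials; ++i) {
        kept += (uniform(rng) > uint_thres);
      }
      // Prints a keep rate close to 1 - threshold = 0.7.
      printf("keep rate = %.4f\n", kept / (double)trials);
      return 0;
    }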