
Commit

working update
Yangqing committed Sep 16, 2013
1 parent 002e004 commit 3d9674a
Showing 8 changed files with 82 additions and 39 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -1,4 +1,6 @@
caffeine
========

caffeine.
caffeine: Convolutional Algorithms For Feature Extraction.

Copyright Yangqing Jia
2 changes: 1 addition & 1 deletion src/Makefile
@@ -31,7 +31,7 @@ MKL_LIB_DIR := $(MKL_DIR)/lib $(MKL_DIR)/lib/intel64

INCLUDE_DIRS := . /usr/local/include $(CUDA_INCLUDE_DIR) $(MKL_INCLUDE_DIR)
LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread curand
WARNINGS := -Wall

CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
9 changes: 9 additions & 0 deletions src/caffeine/common.cpp
@@ -7,11 +7,16 @@ shared_ptr<Caffeine> Caffeine::singleton_;
Caffeine::Caffeine()
: mode_(Caffeine::CPU), phase_(Caffeine::TRAIN) {
CUBLAS_CHECK(cublasCreate(&cublas_handle_));
CURAND_CHECK(curandCreateGenerator(&curand_generator_,
CURAND_RNG_PSEUDO_XORWOW));
VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
}

Caffeine::~Caffeine() {
  if (cublas_handle_) CUBLAS_CHECK(cublasDestroy(cublas_handle_));
  if (curand_generator_) {
    CURAND_CHECK(curandDestroyGenerator(curand_generator_));
  }
  if (vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
};

@@ -30,6 +35,10 @@ cublasHandle_t Caffeine::cublas_handle() {
return Get().cublas_handle_;
};

curandGenerator_t Caffeine::curand_generator() {
return Get().curand_generator_;
};
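
For orientation, callers can draw raw 32-bit uniform values from this shared generator through the standard cuRAND API; a minimal sketch, assuming a hypothetical device buffer gpu_buf holding n unsigned ints (not part of the commit):

// Sketch only: seed the shared generator once, then fill a device buffer
// with uniformly distributed 32-bit values. gpu_buf and n are hypothetical.
CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(
    Caffeine::curand_generator(), 1701ULL));
CURAND_CHECK(curandGenerate(Caffeine::curand_generator(), gpu_buf, n));

This is the same pattern the dropout layer uses below to build its GPU random mask.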

Caffeine::Brew Caffeine::mode() {
return Get().mode_;
}
9 changes: 9 additions & 0 deletions src/caffeine/common.hpp
@@ -3,13 +3,16 @@

#include <boost/shared_ptr.hpp>
#include <cublas_v2.h>
#include <cuda.h>
#include <curand.h>
#include <glog/logging.h>
#include <mkl_vsl.h>

#include "driver_types.h"

#define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
#define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)

namespace caffeine {
@@ -21,6 +24,10 @@ using boost::shared_ptr;
// For backward compatibility we will just use 512 threads per block
const int CAFFEINE_CUDA_NUM_THREADS = 512;

inline int CAFFEINE_GET_BLOCKS(const int N) {
return (N + CAFFEINE_CUDA_NUM_THREADS - 1) / CAFFEINE_CUDA_NUM_THREADS;
}
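
This helper, together with CAFFEINE_CUDA_NUM_THREADS, gives the usual ceiling-division launch configuration: one thread per element, with kernels guarding against the overshoot in the last block. A minimal sketch of the pattern (the kernel and its arguments are hypothetical, not part of the commit):

// Sketch only: the launch pattern CAFFEINE_GET_BLOCKS is meant to support.
template <typename Dtype>
__global__ void ScaleKernel(const int n, const Dtype alpha, Dtype* data) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {  // the last block may run past n, so guard every thread
    data[index] *= alpha;
  }
}
// Launched as:
//   ScaleKernel<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
//       count, alpha, gpu_data);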

// A singleton class to hold common caffeine stuff, such as the handler that
// caffeine is going to use for cublas.
class Caffeine {
@@ -32,6 +39,7 @@ class Caffeine {

// The getters for the variables.
static cublasHandle_t cublas_handle();
static curandGenerator_t curand_generator();
static VSLStreamStatePtr vsl_stream();
static Brew mode();
static Phase phase();
@@ -42,6 +50,7 @@ class Caffeine {
Caffeine();
static shared_ptr<Caffeine> singleton_;
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
Brew mode_;
Phase phase_;
77 changes: 51 additions & 26 deletions src/caffeine/dropout_layer.cu
@@ -1,6 +1,11 @@
#include <algorithm>
#include <climits>  // for UINT_MAX
#include <limits>

#include "caffeine/common.hpp"
#include "caffeine/layer.hpp"
#include "caffeine/syncedmem.hpp"
#include "caffeine/vision_layers.hpp"
#include <algorithm>


using std::max;

@@ -11,48 +16,55 @@ void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
NeuronLayer<Dtype>::SetUp(bottom, top);
// Set up the cache for random number generation
rand_mat_.reset(new Blob<float>(bottom.num(), bottom.channels(),
bottom.height(), bottom.width());
filler_.reset(new UniformFiller<float>(FillerParameter()));
rand_vec_.reset(new SyncedMemory(bottom[0]->count() * sizeof(int)));
};

template <typename Dtype>
void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, create the random matrix
filler_->Fill(rand_mat_.get());
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* rand_vals = rand_mat_->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
float threshold = layer_param_->dropout_ratio();
float scale = layer_param_->dropo
float threshold = this->layer_param_.dropout_ratio();
DCHECK(threshold > 0.);
DCHECK(threshold < 1.);
float scale = 1. / threshold;
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
top_data[i] = rand_mat_ > ;
if (Caffeine::phase() == Caffeine::TRAIN) {
// Create random numbers
viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, Caffeine::vsl_stream(),
count, (int*)(rand_vec_->mutable_cpu_data()),
1. - threshold);
    const int* mask = (int*)(rand_vec_->cpu_data());
    for (int i = 0; i < count; ++i) {
      top_data[i] = bottom_data[i] * mask[i] * scale;
}
} else {
memcpy(top_data, bottom_data, bottom[0]->count() * sizeof(Dtype));
}
}
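
As a standalone illustration of the train-time masking above (independent of MKL's viRngBernoulli), each activation is kept with probability 1 - threshold and the survivors are rescaled so the expected output matches the test-time pass-through. A minimal sketch using only the standard library, with the conventional 1 / (1 - p) inverted-dropout scale (the hunk above scales by 1. / threshold); all names here are hypothetical:

#include <random>
#include <vector>

// Sketch only: standalone inverted dropout over a CPU buffer.
void DropoutForwardSketch(const std::vector<float>& bottom, float p,
                          std::vector<float>* top) {
  std::mt19937 gen(1701);
  std::bernoulli_distribution keep(1.0 - p);  // keep with probability 1 - p
  const float scale = 1.0f / (1.0f - p);      // rescale surviving activations
  top->resize(bottom.size());
  for (size_t i = 0; i < bottom.size(); ++i) {
    (*top)[i] = keep(gen) ? bottom[i] * scale : 0.0f;
  }
}

At test time the layer just copies bottom to top, as in the memcpy branch above.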

template <typename Dtype>
Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
CHECK(Caffeine::phase() == Caffeine::TRAIN);
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->cpu_data();
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
const int* mask = (int*)(rand_vec_->cpu_data());
const int count = (*bottom)[0]->count();
for (int i = 0; i < count; ++i) {
bottom_diff[i] = top_diff[i] * (bottom_data[i] >= 0);
bottom_diff[i] = top_diff[i] * mask[i];
}
}
return Dtype(0);
}

template <typename Dtype>
__global__ void DropoutForward(const int n, const Dtype* in, Dtype* out) {
__global__ void DropoutForward(const int n, const Dtype* in,
const unsigned int* mask, const unsigned int threshold, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
out[index] = max(in[index], Dtype(0.));
out[index] = in[index] * (mask[index] > threshold);
}
}

@@ -61,35 +73,48 @@ void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
float threshold = this->layer_param_.dropout_ratio();
DCHECK(threshold > 0.);
DCHECK(threshold < 1.);
float scale = 1. / threshold;
const int count = bottom[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
DropoutForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
top_data);
if (Caffeine::phase() == Caffeine::TRAIN) {
// Create random numbers
CURAND_CHECK(curandGenerate(Caffeine::curand_generator(),
(unsigned int*)(rand_vec_->mutable_gpu_data()), count));
unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
// set thresholds
    DropoutForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
        count, bottom_data, (unsigned int*)(rand_vec_->gpu_data()), uint_thres,
        top_data);
} else {
    CUDA_CHECK(cudaMemcpy(top_data, bottom_data,
        count * sizeof(Dtype), cudaMemcpyDeviceToDevice));
}
}
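
The GPU path avoids a separate Bernoulli generator: curandGenerate fills rand_vec_ with uniformly distributed 32-bit integers, and comparing each one against uint_thres = (unsigned int)(UINT_MAX * threshold) keeps a unit with probability roughly 1 - threshold. With threshold = 0.5, for instance, uint_thres is about 2^31, so roughly half of the mask values exceed it and the corresponding activations survive. A small host-side check of that arithmetic (not part of the commit; it uses double to avoid float rounding of UINT_MAX):

#include <climits>
#include <cstdio>

// Sketch only: the keep probability implied by the unsigned-int threshold.
int main() {
  const double threshold = 0.5;
  const unsigned int uint_thres = (unsigned int)(UINT_MAX * threshold);
  const double keep_prob = (UINT_MAX - uint_thres) / 4294967296.0;
  std::printf("uint_thres=%u keep~=%.3f\n", uint_thres, keep_prob);  // keep~=0.500
  return 0;
}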

template <typename Dtype>
__global__ void DropoutBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff) {
const unsigned int* mask, const unsigned int threshold, Dtype* out_diff) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
out_diff[index] = in_diff[index] * (in_data[index] >= 0);
out_diff[index] = in_diff[index] * (mask[index] > threshold);
}
}

template <typename Dtype>
Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
CHECK(Caffeine::phase() == Caffeine::TRAIN);
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
    const unsigned int* mask = (const unsigned int*)(rand_vec_->gpu_data());
const int count = (*bottom)[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
DropoutBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
bottom_data, bottom_diff);
    const unsigned int uint_thres =
        (unsigned int)(UINT_MAX * this->layer_param_.dropout_ratio());
    DropoutBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
        count, top_diff, mask, uint_thres, bottom_diff);
}
return Dtype(0);
}
4 changes: 4 additions & 0 deletions src/caffeine/filler.hpp
@@ -1,3 +1,7 @@
// Fillers are random number generators that fill a blob using the specified
// algorithm. The expectation is that they are only going to be used during
// initialization time and will not involve any GPUs.

#ifndef CAFFEINE_FILLER_HPP
#define CAFFEINE_FILLER_HPP
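
The rest of filler.hpp is not shown in this hunk. Purely as a hypothetical sketch of the interface the comment above describes, consistent with how the (now removed) dropout code constructed a UniformFiller<float> from a FillerParameter and called Fill on a blob, a base class could look roughly like this; it is not the actual header:

// Hypothetical sketch only; the real filler.hpp is not part of this diff.
template <typename Dtype>
class Filler {
 public:
  explicit Filler(const FillerParameter& param) : filler_param_(param) {}
  virtual ~Filler() {}
  // Fill the blob on the CPU using this filler's algorithm.
  virtual void Fill(Blob<Dtype>* blob) = 0;
 protected:
  FillerParameter filler_param_;
};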

12 changes: 4 additions & 8 deletions src/caffeine/relu_layer.cu
@@ -47,10 +47,8 @@ void ReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
ReLUForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
top_data);
ReLUForward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, bottom_data, top_data);
}

template <typename Dtype>
@@ -71,10 +69,8 @@ Dtype ReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
const int count = (*bottom)[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
ReLUBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
bottom_data, bottom_diff);
ReLUBackward<Dtype><<<CAFFEINE_GET_BLOCKS(count), CAFFEINE_CUDA_NUM_THREADS>>>(
count, top_diff, bottom_data, bottom_diff);
}
return Dtype(0);
}
4 changes: 1 addition & 3 deletions src/caffeine/vision_layers.hpp
@@ -48,9 +48,7 @@ class DropoutLayer : public NeuronLayer<Dtype> {
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
private:
shared_ptr<Blob<float> > rand_mat_;
shared_ptr<UniformFiller<float> > filler_;
shared_ptr<SyncedMemory> rand_vec_;
};
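
Storing the mask in a raw SyncedMemory of integers, rather than in a float Blob filled by a UniformFiller, lets the same buffer be written by the MKL Bernoulli generator on the CPU path and by cuRAND on the GPU path. A minimal sketch of that access pattern, using only accessors that appear in this diff (count is a hypothetical element count):

// Sketch only: one integer mask buffer with host and device views.
shared_ptr<SyncedMemory> mask(new SyncedMemory(count * sizeof(int)));
int* cpu_mask = (int*)(mask->mutable_cpu_data());        // written by viRngBernoulli
unsigned int* gpu_mask =
    (unsigned int*)(mask->mutable_gpu_data());           // written by curandGenerate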


