working halfway into dropout, machine down, changing machine
Yangqing committed Sep 16, 2013
1 parent 0a526a4 commit 002e004
Showing 6 changed files with 167 additions and 18 deletions.
6 changes: 3 additions & 3 deletions src/Makefile
@@ -34,7 +34,7 @@ LIBRARY_DIRS := . /usr/local/lib $(CUDA_LIB_DIR) $(MKL_LIB_DIR)
LIBRARIES := cuda cudart cublas protobuf glog mkl_rt mkl_intel_thread
WARNINGS := -Wall

-CXXFLAGS += -fPIC $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
+CXXFLAGS += -fPIC -O2 $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
LDFLAGS += $(foreach librarydir,$(LIBRARY_DIRS),-L$(librarydir))
LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))

@@ -53,8 +53,8 @@ $(TEST_NAME): $(OBJS) $(TEST_OBJS)
$(NAME): $(PROTO_GEN_CC) $(OBJS)
$(LINK) -shared $(OBJS) -o $(NAME)

-$(CU_OBJS): $(CU_SRCS)
-	$(NVCC) -c -o $(CU_OBJS) $(CU_SRCS)
+$(CU_OBJS): %.o: %.cu
+	$(NVCC) -c $< -o $@

$(PROTO_GEN_CC): $(PROTO_SRCS)
protoc $(PROTO_SRCS) --cpp_out=. --python_out=.
101 changes: 101 additions & 0 deletions src/caffeine/dropout_layer.cu
@@ -0,0 +1,101 @@
#include "caffeine/layer.hpp"
#include "caffeine/vision_layers.hpp"
#include <algorithm>

using std::max;

namespace caffeine {

template <typename Dtype>
void DropoutLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
NeuronLayer<Dtype>::SetUp(bottom, top);
// Set up the cache for random number generation
  rand_mat_.reset(new Blob<float>(bottom[0]->num(), bottom[0]->channels(),
      bottom[0]->height(), bottom[0]->width()));
  filler_.reset(new UniformFiller<float>(FillerParameter()));
}

template <typename Dtype>
void DropoutLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
// First, create the random matrix
filler_->Fill(rand_mat_.get());
const Dtype* bottom_data = bottom[0]->cpu_data();
const Dtype* rand_vals = rand_mat_->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
  float threshold = this->layer_param_.dropout_ratio();
  // Inverted dropout: scale the kept activations by 1 / (1 - ratio) so the
  // expected output matches the no-dropout case.
  float scale = 1. / (1. - threshold);
const int count = bottom[0]->count();
for (int i = 0; i < count; ++i) {
    // Keep each unit with probability (1 - threshold), rescaling the survivors.
    top_data[i] = bottom_data[i] * (rand_vals[i] > threshold) * scale;
}
}

template <typename Dtype>
Dtype DropoutLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
    const Dtype* top_diff = top[0]->cpu_diff();
    const Dtype* rand_vals = rand_mat_->cpu_data();
    Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
    float threshold = this->layer_param_.dropout_ratio();
    float scale = 1. / (1. - threshold);
    const int count = (*bottom)[0]->count();
    for (int i = 0; i < count; ++i) {
      // Gradients flow only through the units kept by the forward-pass mask.
      bottom_diff[i] = top_diff[i] * (rand_vals[i] > threshold) * scale;
    }
}
}
return Dtype(0);
}

template <typename Dtype>
__global__ void DropoutForward(const int n, const Dtype* in, Dtype* out) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
out[index] = max(in[index], Dtype(0.));
}
}

template <typename Dtype>
void DropoutLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
const int count = bottom[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
DropoutForward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, bottom_data,
top_data);
}

template <typename Dtype>
__global__ void DropoutBackward(const int n, const Dtype* in_diff,
const Dtype* in_data, Dtype* out_diff) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
if (index < n) {
out_diff[index] = in_diff[index] * (in_data[index] >= 0);
}
}

template <typename Dtype>
Dtype DropoutLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down,
vector<Blob<Dtype>*>* bottom) {
if (propagate_down) {
const Dtype* bottom_data = (*bottom)[0]->gpu_data();
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
const int count = (*bottom)[0]->count();
const int blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) /
CAFFEINE_CUDA_NUM_THREADS;
DropoutBackward<<<blocks, CAFFEINE_CUDA_NUM_THREADS>>>(count, top_diff,
bottom_data, bottom_diff);
}
return Dtype(0);
}

template class DropoutLayer<float>;
template class DropoutLayer<double>;


} // namespace caffeine
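
The CPU forward pass above follows the "inverted dropout" scheme: each activation survives with probability 1 - dropout_ratio, and survivors are scaled by 1 / (1 - dropout_ratio) so the expected output matches the input and no rescaling is needed at test time. The GPU kernels in this half-finished commit still mirror the ReLU kernels and do not consume the random matrix yet; the snippet below is only a sketch of what a mask-based forward kernel could look like (the kernel name, the extra arguments, and the idea of passing rand_mat_'s device pointer are assumptions, not part of this commit):

// Sketch only: a dropout forward kernel driven by a precomputed uniform mask.
// The name and signature are hypothetical; this is not the committed kernel.
template <typename Dtype>
__global__ void DropoutMaskForward(const int n, const Dtype* in,
    const float* rand_vals, const float threshold, const float scale,
    Dtype* out) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < n) {
    // Keep the unit when its uniform draw exceeds the dropout ratio,
    // and rescale survivors so the expected output is unchanged.
    out[index] = in[index] * (rand_vals[index] > threshold) * scale;
  }
}

Launching such a kernel would reuse the grid-sizing idiom already in this file: blocks = (count + CAFFEINE_CUDA_NUM_THREADS - 1) / CAFFEINE_CUDA_NUM_THREADS.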
18 changes: 18 additions & 0 deletions src/caffeine/neuron_layer.cpp
@@ -0,0 +1,18 @@
#include "caffeine/layer.hpp"
#include "caffeine/vision_layers.hpp"

namespace caffeine {

template <typename Dtype>
void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
bottom[0]->height(), bottom[0]->width());
};

template class NeuronLayer<float>;
template class NeuronLayer<double>;

} // namespace caffeine
22 changes: 19 additions & 3 deletions src/caffeine/proto/layer_param.proto
@@ -1,8 +1,24 @@
package caffeine;

message LayerParameter {
-required string name = 1;
-required string type = 2;
+required string name = 1; // the layer name
+required string type = 2; // the string to specify the layer type

// Parameters to specify layers with inner products.
optional int32 num_output = 3; // The number of outputs for the layer
optional bool biasterm = 4 [default = true]; // whether to have bias terms
optional FillerParameter weight_filler = 5; // The filler for the weight
optional FillerParameter bias_filler = 6; // The filler for the bias

optional uint32 pad = 7 [default = 0]; // The padding size
optional uint32 kernelsize = 8; // The kernel size
optional uint32 group = 9 [default = 1]; // The group size for group conv
optional uint32 stride = 10 [default = 1]; // The stride
optional string pool = 11 [default = 'max']; // The pooling method
optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio

optional float alpha = 13 [default = 1.]; // for local response norm
optional float beta = 14 [default = 0.75]; // for local response norm
}

message FillerParameter {
@@ -21,4 +37,4 @@ message BlobProto {
optional int32 channels = 4 [default = 0];
repeated float data = 5;
repeated float diff = 6;
-}
+}
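
Since protoc is run with --cpp_out (see the Makefile hunk above), the new LayerParameter fields come with generated C++ accessors. A minimal sketch of how they would be used, assuming the generated header lands at caffeine/proto/layer_param.pb.h and with made-up field values:

#include "caffeine/proto/layer_param.pb.h"  // emitted by protoc; path assumed

// Sketch: populating a LayerParameter for a dropout layer and reading defaults.
void ExampleLayerParameter() {
  caffeine::LayerParameter param;
  param.set_name("drop1");          // hypothetical layer name
  param.set_type("dropout");        // illustrative type string
  param.set_dropout_ratio(0.4);     // overrides the declared default of 0.5
  // Optional fields that were never set report their declared defaults.
  float beta = param.beta();        // 0.75
  bool padded = param.has_pad();    // false until pad is explicitly set
  (void)beta; (void)padded;
}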
12 changes: 0 additions & 12 deletions src/caffeine/neuron_layer.cu → src/caffeine/relu_layer.cu
@@ -6,18 +6,6 @@ using std::max;

namespace caffeine {

-template <typename Dtype>
-void NeuronLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
-vector<Blob<Dtype>*>* top) {
-CHECK_EQ(bottom.size(), 1) << "Neuron Layer takes a single blob as input.";
-CHECK_EQ(top->size(), 1) << "Neuron Layer takes a single blob as output.";
-(*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
-bottom[0]->height(), bottom[0]->width());
-};
-
-template class NeuronLayer<float>;
-template class NeuronLayer<double>;
-
template <typename Dtype>
void ReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top) {
26 changes: 26 additions & 0 deletions src/caffeine/vision_layers.hpp
@@ -31,6 +31,32 @@ class ReLULayer : public NeuronLayer<Dtype> {
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
};

template <typename Dtype>
class DropoutLayer : public NeuronLayer<Dtype> {
public:
explicit DropoutLayer(const LayerParameter& param)
: NeuronLayer<Dtype>(param) {};
virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
vector<Blob<Dtype>*>* top);

virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
const bool propagate_down, vector<Blob<Dtype>*>* bottom);
private:
shared_ptr<Blob<float> > rand_mat_;
shared_ptr<UniformFiller<float> > filler_;
};





} // namespace caffeine

#endif // CAFFEINE_VISION_LAYERS_HPP_
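
For completeness, a rough sketch of how the class declared above would be exercised, assuming a four-argument Blob constructor (num, channels, height, width) as used in DropoutLayer::SetUp and a default-constructible Blob; every name not declared in the header is illustrative:

#include <vector>
#include "caffeine/vision_layers.hpp"

// Sketch: constructing a DropoutLayer and letting SetUp shape its top blob.
void ExampleDropoutSetUp() {
  caffeine::LayerParameter param;
  param.set_dropout_ratio(0.5);
  caffeine::DropoutLayer<float> layer(param);

  caffeine::Blob<float> input(2, 3, 4, 5);  // assumed (num, channels, height, width)
  caffeine::Blob<float> output;             // reshaped by SetUp
  std::vector<caffeine::Blob<float>*> bottom(1, &input);
  std::vector<caffeine::Blob<float>*> top(1, &output);

  // NeuronLayer::SetUp checks the blob counts and reshapes the top blob;
  // DropoutLayer::SetUp additionally allocates rand_mat_ and the uniform filler.
  layer.SetUp(bottom, &top);
}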
