Add pruning support to inner_product_layer #4294

Open. Wants to merge 1 commit into base: master.
4 changes: 4 additions & 0 deletions include/caffe/layers/inner_product_layer.hpp
@@ -43,6 +43,10 @@ class InnerProductLayer : public Layer<Dtype> {
int K_;
int N_;
bool bias_term_;
bool pruned_;
// masks for deep compression pruning (0 = pruned entry, 1 = kept)
vector<shared_ptr<Blob<Dtype> > > masks_;
Dtype pruning_coeff_; // fraction of weights to prune for deep compression
Blob<Dtype> bias_multiplier_;
bool transpose_; ///< if true, assume transposed weights
};
4 changes: 4 additions & 0 deletions include/caffe/util/math_functions.hpp
@@ -105,6 +105,10 @@ Dtype caffe_cpu_strided_dot(const int n, const Dtype* x, const int incx,
template <typename Dtype>
Dtype caffe_cpu_asum(const int n, const Dtype* x);

// Prune a blob in place: zero the floor(coeff * n) entries of x with the
// lowest absolute values and record them as zeros in mask.
template <typename Dtype>
void caffe_cpu_prune(const int n, const Dtype coeff, Dtype* x, Dtype* mask);

// the branchless, type-safe version from
// https://stackoverflow.com/questions/1903954/is-there-a-standard-sign-function-signum-sgn-in-c-c
template<typename Dtype>
38 changes: 38 additions & 0 deletions src/caffe/layers/inner_product_layer.cpp
@@ -12,6 +12,10 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const int num_output = this->layer_param_.inner_product_param().num_output();
bias_term_ = this->layer_param_.inner_product_param().bias_term();
transpose_ = this->layer_param_.inner_product_param().transpose();
pruning_coeff_ = this->layer_param_.pruning_param().coeff();
CHECK_GE(pruning_coeff_, 0);
CHECK_GT(1, pruning_coeff_);
pruned_ = (pruning_coeff_ == 0);
N_ = num_output;
const int axis = bottom[0]->CanonicalAxisIndex(
this->layer_param_.inner_product_param().axis());
@@ -28,6 +32,10 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
} else {
this->blobs_.resize(1);
}
// Deep compression pruning
if (pruning_coeff_ > 0) {
masks_.resize(this->blobs_.size());
}
// Initialize the weights
vector<int> weight_shape(2);
if (transpose_) {
@@ -42,13 +50,23 @@ void InnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
shared_ptr<Filler<Dtype> > weight_filler(GetFiller<Dtype>(
this->layer_param_.inner_product_param().weight_filler()));
weight_filler->Fill(this->blobs_[0].get());
if (pruning_coeff_ != 0) {
masks_[0].reset(new Blob<Dtype>(weight_shape));
caffe_set<Dtype>(this->blobs_[0]->count(), (Dtype)1.,
masks_[0]->mutable_cpu_data());
}
// If necessary, initialize and fill the bias term
if (bias_term_) {
vector<int> bias_shape(1, N_);
this->blobs_[1].reset(new Blob<Dtype>(bias_shape));
shared_ptr<Filler<Dtype> > bias_filler(GetFiller<Dtype>(
this->layer_param_.inner_product_param().bias_filler()));
bias_filler->Fill(this->blobs_[1].get());
if (pruning_coeff_ != 0) {
masks_[1].reset(new Blob<Dtype>(bias_shape));
caffe_set<Dtype>(this->blobs_[1]->count(), (Dtype)1.,
masks_[1]->mutable_cpu_data());
}
}
} // parameter initialization
this->param_propagate_down_.resize(this->blobs_.size(), true);
@@ -86,6 +104,16 @@ void InnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = top[0]->mutable_cpu_data();
const Dtype* weight = this->blobs_[0]->cpu_data();
// prune only once after loading the caffemodel
if (!pruned_) {
caffe_cpu_prune(this->blobs_[0]->count(), pruning_coeff_,
this->blobs_[0]->mutable_cpu_data(), masks_[0]->mutable_cpu_data());
if (bias_term_) {
caffe_cpu_prune(this->blobs_[1]->count(), pruning_coeff_,
this->blobs_[1]->mutable_cpu_data(), masks_[1]->mutable_cpu_data());
}
pruned_ = true;
}
caffe_cpu_gemm<Dtype>(CblasNoTrans, transpose_ ? CblasNoTrans : CblasTrans,
M_, N_, K_, (Dtype)1.,
bottom_data, weight, (Dtype)0., top_data);
@@ -138,6 +166,16 @@ void InnerProductLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
(Dtype)0., bottom[0]->mutable_cpu_diff());
}
}
if (pruning_coeff_ > 0) {
if (this->param_propagate_down_[0]) {
caffe_mul(this->blobs_[0]->count(), this->blobs_[0]->cpu_diff(),
masks_[0]->cpu_data(), this->blobs_[0]->mutable_cpu_diff());
}
if (bias_term_ && this->param_propagate_down_[1]) {
caffe_mul(this->blobs_[1]->count(), this->blobs_[1]->cpu_diff(),
masks_[1]->cpu_data(), this->blobs_[1]->mutable_cpu_diff());
}
}
}

#ifdef CPU_ONLY
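The Backward_cpu change above multiplies each parameter diff by its 0/1 mask, so weights zeroed by the one-time pruning in Forward_cpu receive no gradient and cannot be regrown by the solver. As a rough illustration of the net effect (a standalone sketch with plain float arrays and a made-up function name, not Caffe code):

```cpp
#include <cstddef>

// Sketch of the effective update once the diffs of pruned weights are zeroed:
// mask[i] is 0 for a pruned weight and 1 otherwise, so pruned weights stay at
// exactly zero while the remaining weights train normally.
void masked_sgd_step(float* weights, const float* diffs, const float* mask,
                     std::size_t n, float lr) {
  for (std::size_t i = 0; i < n; ++i) {
    weights[i] -= lr * diffs[i] * mask[i];
  }
}
```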
20 changes: 20 additions & 0 deletions src/caffe/layers/inner_product_layer.cu
@@ -9,6 +9,16 @@ namespace caffe {
template <typename Dtype>
void InnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
// prune only once after loading the caffemodel
if (!pruned_) {
caffe_cpu_prune(this->blobs_[0]->count(), pruning_coeff_,
this->blobs_[0]->mutable_cpu_data(), masks_[0]->mutable_cpu_data());
if (bias_term_) {
caffe_cpu_prune(this->blobs_[1]->count(), pruning_coeff_,
this->blobs_[1]->mutable_cpu_data(), masks_[1]->mutable_cpu_data());
}
pruned_ = true;
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
const Dtype* weight = this->blobs_[0]->gpu_data();
@@ -72,6 +82,16 @@ void InnerProductLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
(Dtype)0., bottom[0]->mutable_gpu_diff());
}
}
if (pruning_coeff_ > 0) {
if (this->param_propagate_down_[0]) {
caffe_gpu_mul(this->blobs_[0]->count(), this->blobs_[0]->gpu_diff(),
masks_[0]->gpu_data(), this->blobs_[0]->mutable_gpu_diff());
}
if (bias_term_ && this->param_propagate_down_[1]) {
caffe_gpu_mul(this->blobs_[1]->count(), this->blobs_[1]->gpu_diff(),
masks_[1]->gpu_data(), this->blobs_[1]->mutable_gpu_diff());
}
}
}

INSTANTIATE_LAYER_GPU_FUNCS(InnerProductLayer);
8 changes: 7 additions & 1 deletion src/caffe/proto/caffe.proto
@@ -306,7 +306,7 @@ message ParamSpec {
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
// LayerParameter next available layer-specific ID: 147 (last added: recurrent_param)
// LayerParameter next available layer-specific ID: 148 (last added: pruning_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
@@ -389,6 +389,7 @@ message LayerParameter {
optional PoolingParameter pooling_param = 121;
optional PowerParameter power_param = 122;
optional PReLUParameter prelu_param = 131;
optional PruningParameter pruning_param = 148;
Review comment (Member): Shouldn't it be 147? Since in the comment above you say that the next available ID is 148.

Reply (Author): Indeed, it should be 147.

optional PythonParameter python_param = 130;
optional RecurrentParameter recurrent_param = 146;
optional ReductionParameter reduction_param = 136;
@@ -915,6 +916,11 @@ message PowerParameter {
optional float shift = 3 [default = 0.0];
}

message PruningParameter {
// Pruning coefficient for deep compression
Review comment (@seanbell, Jun 12, 2016): It would be good to document what this parameter does here. The current comment adds no information. It looks like it's the fraction of weights to keep, sorted by absolute value?

optional float coeff = 1 [default = 0];
}

message PythonParameter {
optional string module = 1;
optional string layer = 2;
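For context on how the new PruningParameter would be used, here is a hypothetical prototxt snippet (the layer name, blob names, num_output, and coeff value are illustrative, not part of this PR). With coeff: 0.9, the 90% of the layer's weights (and biases) with the smallest absolute values are zeroed the first time the layer runs forward, and the masks keep them at zero afterwards:

```
layer {
  name: "fc6"
  type: "InnerProduct"
  bottom: "pool5"
  top: "fc6"
  inner_product_param { num_output: 4096 }
  # Prune the 90% smallest-magnitude entries once, then hold them at zero.
  pruning_param { coeff: 0.9 }
}
```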
29 changes: 29 additions & 0 deletions src/caffe/util/math_functions.cpp
@@ -1,7 +1,12 @@
#include <boost/math/special_functions/next.hpp>
#include <boost/random.hpp>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <utility>
#include <vector>

#include "caffe/common.hpp"
#include "caffe/util/math_functions.hpp"
@@ -372,4 +376,29 @@ void caffe_cpu_scale<double>(const int n, const double alpha, const double *x,
cblas_dscal(n, alpha, y, 1);
}

template <typename Dtype>
void caffe_cpu_prune(const int n, const Dtype coeff, Dtype* x,
Dtype* mask) {
// Partially sort to find the floor(coeff * n) entries of x with the
// smallest absolute values.
std::vector<std::pair<Dtype, int> > indexed_x;
for (int k = 0; k < n; ++k) {
indexed_x.push_back(std::make_pair(std::abs(x[k]), k));
}
const int num_pruned = static_cast<int>(std::floor(coeff * n));
std::partial_sort(
indexed_x.begin(), indexed_x.begin() + num_pruned,
indexed_x.end(), std::less<std::pair<Dtype, int> >());
for (int k = 0; k < num_pruned; k++) {
x[indexed_x[k].second] = 0;
mask[indexed_x[k].second] = 0;
}
}

template
void caffe_cpu_prune<double>(const int n, const double coeff, double* x,
double* mask);

template
void caffe_cpu_prune<float>(const int n, const float coeff, float* x,
float* mask);

} // namespace caffe
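A small usage sketch of the new function (assuming a program compiled and linked against Caffe with this patch; the input values are made up). With n = 4 and coeff = 0.5, floor(0.5 * 4) = 2 entries are pruned, namely the two with the smallest absolute values:

```cpp
#include <cstdio>

#include "caffe/util/math_functions.hpp"

int main() {
  float x[4]    = {0.3f, -0.05f, 1.2f, 0.01f};
  float mask[4] = {1.f, 1.f, 1.f, 1.f};  // masks start at 1, as in LayerSetUp
  caffe::caffe_cpu_prune<float>(4, 0.5f, x, mask);
  // Expected result: x = {0.3, 0, 1.2, 0}, mask = {1, 0, 1, 0}.
  for (int i = 0; i < 4; ++i) {
    std::printf("x[%d] = %g  mask[%d] = %g\n", i, x[i], i, mask[i]);
  }
  return 0;
}
```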