Skip to content

Commit

Permalink
Merge pull request #5713 from Noiredd/filler
Browse files Browse the repository at this point in the history
fix bilinear filler (and make constant filler more strict, as it should be)
  • Loading branch information
shelhamer authored Oct 3, 2017
2 parents effcdb0 + 888597e commit ef2eb4b
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
6 changes: 3 additions & 3 deletions include/caffe/filler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -250,10 +250,10 @@ class BilinearFiller : public Filler<Dtype> {
CHECK_EQ(blob->width(), blob->height()) << "Filter must be square";
Dtype* data = blob->mutable_cpu_data();
int f = ceil(blob->width() / 2.);
float c = (2 * f - 1 - f % 2) / (2. * f);
Dtype c = (blob->width() - 1) / (2. * f);
for (int i = 0; i < blob->count(); ++i) {
float x = i % blob->width();
float y = (i / blob->width()) % blob->height();
Dtype x = i % blob->width();
Dtype y = (i / blob->width()) % blob->height();
data[i] = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
}
CHECK_EQ(this->filler_param_.sparse(), -1)
Expand Down
43 changes: 42 additions & 1 deletion src/caffe/test/test_filler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ TYPED_TEST(ConstantFillerTest, TestFill) {
const int count = this->blob_->count();
const TypeParam* data = this->blob_->cpu_data();
for (int i = 0; i < count; ++i) {
EXPECT_GE(data[i], this->filler_param_.value());
EXPECT_EQ(data[i], this->filler_param_.value());
}
}

Expand Down Expand Up @@ -238,4 +238,45 @@ TYPED_TEST(MSRAFillerTest, TestFillAverage) {
this->test_params(FillerParameter_VarianceNorm_AVERAGE, n);
}

template <typename Dtype>
class BilinearFillerTest : public ::testing::Test {
protected:
BilinearFillerTest() : filler_param_() {}
virtual void test_params(const int n) {
this->blob_ = new Blob<Dtype>(1000, 2, n, n);
this->filler_.reset(new BilinearFiller<Dtype>(this->filler_param_));
this->filler_->Fill(blob_);
EXPECT_TRUE(this->blob_);
const int outer_num = this->blob_->count(0, 2);
const int inner_num = this->blob_->count(2, 4);
const Dtype* data = this->blob_->cpu_data();
int f = ceil(this->blob_->width() / 2.);
Dtype c = (this->blob_->width() - 1) / (2. * f);
for (int i = 0; i < outer_num; ++i) {
for (int j = 0; j < inner_num; ++j) {
Dtype x = j % this->blob_->width();
Dtype y = (j / this->blob_->width()) % this->blob_->height();
Dtype expected_value = (1 - fabs(x / f - c)) * (1 - fabs(y / f - c));
const Dtype actual_value = data[i * inner_num + j];
EXPECT_NEAR(expected_value, actual_value, 0.01);
}
}
}
virtual ~BilinearFillerTest() { delete blob_; }
Blob<Dtype>* blob_;
FillerParameter filler_param_;
shared_ptr<BilinearFiller<Dtype> > filler_;
};

TYPED_TEST_CASE(BilinearFillerTest, TestDtypes);

TYPED_TEST(BilinearFillerTest, TestFillOdd) {
const int n = 7;
this->test_params(n);
}
TYPED_TEST(BilinearFillerTest, TestFillEven) {
const int n = 6;
this->test_params(n);
}

} // namespace caffe

0 comments on commit ef2eb4b

Please sign in to comment.