Skip to content

Commit

Permalink
hack for ps in image_rec-inl
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed May 29, 2015
1 parent fc79d65 commit 32836ea
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/io/iter_image_recordio-inl.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*!
* \file iter_image_recordio-inl.hpp
* \brief recordio data
* \brief recordio data
iterator
*/
#ifndef ITER_IMAGE_RECORDIO_INL_HPP_
Expand All @@ -25,10 +25,10 @@ iterator
namespace cxxnet {

/*! \brief data structure to hold labels for images */
class ImageLabelMap {
class ImageLabelMap {
public:
/*!
* \brief initialize the label list into memory
* \brief initialize the label list into memory
* \param path_imglist path to the image list
* \param label_width predefined label_width
*/
Expand Down Expand Up @@ -75,15 +75,15 @@ class ImageLabelMap {
= idx2label_.find(imid);
CHECK(it != idx2label_.end()) << "fail to find imagelabel for id " << imid;
return mshadow::Tensor<cpu, 1>(it->second, mshadow::Shape1(label_width_));
}
}

private:
// label with_
mshadow::index_t label_width_;
// image index of each record
std::vector<size_t> image_index_;
// real label content
std::vector<real_t> label_;
std::vector<real_t> label_;
// map index to label
std::unordered_map<size_t, real_t*> idx2label_;
};
Expand Down Expand Up @@ -180,6 +180,13 @@ inline void ImageRecordIOParser::Init(void) {
}
CHECK(path_imgrec_.length() != 0)
<< "ImageRecordIOIterator: must specify image_rec";
#if MSHADOW_DIST_PS
// TODO move to a better place
dist_num_worker_ = ::ps::RankSize();
dist_worker_rank_ = ::ps::MyRank();
LOG(INFO) << "rank " << dist_worker_rank_
<< " in " << dist_num_worker_;
#endif
source_ = dmlc::InputSplit::Create
(path_imgrec_.c_str(), dist_worker_rank_,
dist_num_worker_, "recordio");
Expand Down Expand Up @@ -227,7 +234,7 @@ ParseNext(std::vector<InstVector> *out_vec) {
rec.Load(blob.dptr, blob.size);
cv::Mat buf(1, rec.content_size, CV_8U, rec.content);
res = cv::imdecode(buf, 1);
res = augmenters_[tid]->Process(res, prnds_[tid]);
res = augmenters_[tid]->Process(res, prnds_[tid]);
out.Push(static_cast<unsigned>(rec.image_index()),
mshadow::Shape3(3, res.rows, res.cols),
mshadow::Shape1(label_width_));
Expand Down Expand Up @@ -259,7 +266,7 @@ class ImageRecordIOIterator : public IIterator<DataInst> {
rnd_.Seed(kRandMagic);
shuffle_ = 0;
}
virtual ~ImageRecordIOIterator(void) {
virtual ~ImageRecordIOIterator(void) {
iter_.Destroy();
// data can be NULL
delete data_;
Expand Down Expand Up @@ -304,7 +311,7 @@ class ImageRecordIOIterator : public IIterator<DataInst> {
for (unsigned i = 0; i < data_->size(); ++i) {
const InstVector &tmp = (*data_)[i];
for (unsigned j = 0; j < tmp.Size(); ++j) {
inst_order_.push_back(std::make_pair(i, j));
inst_order_.push_back(std::make_pair(i, j));
}
}
// shuffle instance order if needed
Expand All @@ -319,7 +326,7 @@ class ImageRecordIOIterator : public IIterator<DataInst> {
virtual const DataInst &Value(void) const {
return out_;
}

private:
// random magic
static const int kRandMagic = 111;
Expand Down

0 comments on commit 32836ea

Please sign in to comment.