Skip to content

Commit

Permalink
Integrate dataprovider into global workspace
Browse files Browse the repository at this point in the history
  • Loading branch information
unsky committed Aug 17, 2018
1 parent f506962 commit 7d2a1ab
Show file tree
Hide file tree
Showing 15 changed files with 836 additions and 184 deletions.
23 changes: 15 additions & 8 deletions include/sita/dataprovider/dataprovider.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ namespace sita{
template <typename Dtype>
// Abstract batch interface. NOTE(review): this scrape merged the pre- and
// post-commit lines of the diff (duplicate constructors and the removed
// _batch_size accessor were still present); reconstructed here as the
// post-commit "product" interface.
class Batch{
public:
    Batch(){}
    // Virtual destructor: Batch is a polymorphic base (fetch_batch() returns
    // Batch<Dtype>*), so deleting through the base must reach the derived dtor.
    virtual ~Batch(){}
    // Name of the i-th tensor held by this batch (e.g. "data", "label").
    virtual std::string product_name(int i) = 0;
    // The i-th tensor held by this batch.
    virtual Tensor<Dtype>* product(int i) = 0;
    // Number of tensors in this batch.
    virtual int product_size() = 0;
    // Convenience accessors for the two standard products.
    virtual Tensor<Dtype> *data() = 0;
    virtual Tensor<Dtype> *label() = 0;
};


template <typename Dtype>
class DataProviderEntry: public InternalThread{
public:
Expand All @@ -38,10 +39,11 @@ template <typename Dtype>
class DataProvider{
public:
DataProvider(std::string data_file, std::string label_file, std::vector<Dtype> means, int batch_size, int thread_num,
bool shuffle):_means(means), _batch_size(batch_size), _num_thread(thread_num){
bool shuffle, std::string type):_means(means), _batch_size(batch_size), _num_thread(thread_num), _type(type){
};
~DataProvider(){};
static const int PREFETCH_COUNT = 3;
virtual Batch<Dtype>* fetch_batch()=0;

inline int num_thread(){
return _num_thread;
Expand All @@ -52,6 +54,10 @@ class DataProvider{
inline std::vector<Dtype> * means(){
return &_means;
}
inline std::string type(){
return _type;
}

template <class RandomAccessIterator>
void shuffle_data(RandomAccessIterator begin, RandomAccessIterator end){
LOG(INFO) << "shuffling data ...";
Expand All @@ -64,6 +70,7 @@ class DataProvider{
int _num_thread;
int _batch_size;
std::vector<Dtype> _means;
std::string _type;

};

Expand Down
33 changes: 27 additions & 6 deletions include/sita/dataprovider/mnist_dataprovider.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,38 @@ namespace sita {
template <typename Dtype>
// MNIST batch: owns a data tensor and a label tensor and exposes them through
// the generic Batch "product" interface. NOTE(review): the scraped diff had
// merged the old inline data()/label() with the new virtual ones (duplicate
// definitions); reconstructed as the post-commit version.
class MnistBatch : public Batch<Dtype> {
public:
    MnistBatch(): Batch<Dtype>(){
        // Register the two member tensors as named products.
        // NOTE(review): _product stores pointers to this object's own members,
        // so the implicitly generated copy ctor/assignment would leave the
        // copy pointing into the source object — confirm batches are never
        // copied (prefetch array is constructed in place).
        _product.clear();
        _product.push_back(&_data);
        _product.push_back(&_label);
        _product_name.clear();
        _product_name.push_back("data");
        _product_name.push_back("label");
    }

    virtual ~MnistBatch(){}
    virtual Tensor<Dtype> *data(){
        return &_data;
    }
    virtual Tensor<Dtype> *label(){
        return &_label;
    }
    virtual std::string product_name(int i){
        // Fix: bounds-check the container actually being indexed
        // (_product_name, not _product; they have the same size in practice).
        CHECK_GT(_product_name.size(), i) << "dont have that product!!";
        return _product_name[i];
    }
    virtual Tensor<Dtype>* product(int i){
        CHECK_GT(_product.size(), i) << "dont have that product!!";
        return _product[i];
    }
    virtual int product_size(){
        return _product.size();
    }
private:
    Tensor<Dtype> _data;   // image pixels for the batch
    Tensor<Dtype> _label;  // class labels for the batch
    std::vector<std::string> _product_name;
    std::vector<Tensor<Dtype>* > _product;
};


Expand Down Expand Up @@ -62,7 +82,8 @@ class MnistDataProvider: public DataProvider<Dtype>{

}

MnistBatch<Dtype> * fetch_batch();
virtual MnistBatch<Dtype> * fetch_batch();

private:
std::vector<MnistDataProviderEntry<Dtype>> _threads;
MnistBatch<Dtype> _prefetch[DataProvider<Dtype>::PREFETCH_COUNT];
Expand Down
4 changes: 2 additions & 2 deletions include/sita/dlflow/operators/convolution_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ConvolutionOp: public Operator<Dtype>{
}
~ConvolutionOp(){};
void init();
void forward(){};
void forward();
void backward(){};
bool inline has_param(){ return _has_param;}

Expand All @@ -27,4 +27,4 @@ class ConvolutionOp: public Operator<Dtype>{

};
}
#endif //CS_WORK_CONVOLUTION_H
#endif //SITA_DLFLOW_CONVOLUTION_OP_H
2 changes: 1 addition & 1 deletion include/sita/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ class Tensor {
void reshape(const int num, const int channels, const int height, const int width);
void reshape_like(const Tensor<Dtype> &t_other);

void copy_from(const Tensor<Dtype> &t_other, bool reshape = true);
void copy_from(const Tensor<Dtype> *t_other, bool reshape = false);

void set_data_zero();
void set_diff_zero();
Expand Down
15 changes: 9 additions & 6 deletions include/sita/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "macros.h"
#include "sita/dlflow/graph.h"
#include "types.h"
#include "sita/dataprovider/mnist_dataprovider.h"
#include "sita/dataprovider/dataprovider.h"
#include "sita/dlflow/registry.h"
#include "sita/dlflow/operator.h"
namespace sita{
Expand Down Expand Up @@ -53,31 +55,30 @@ class GlobalWorkSpace : public WorkSpace{
//temp tensor
TempTensor<Dtype> new_tensor();
void free_tensor(TempTensor<Dtype> t);
void temp_tensor_memory();
void temp_memory();

//flow tensor
void init_input(std::string name);
void init_output(std::string name);
Tensor<Dtype>* fetch_input(std::string name);
Tensor<Dtype>* fetch_output(std::string name);
std::string flow_tensor_list();
void flow_memory();

//params
void init_param(std::string op_name, std::string op_type, std::string param_name, std::vector<int> shape,
ParamConfig p_config, bool is_shared);
Tensor<Dtype>* fetch_param(std::string op_name, std::string param_name, bool is_shared);
std::string param_list();
void param_memory();


//graph
inline void build_graph(Graph * graph){
_graph = graph;
graph_show();
};
inline void graph_show(){
_graph->graph_symbol_show();
}

void global_init();
void global_init(Graph * graph, DataProvider<Dtype> * data_provider);
void forward();
void backward();
void train();
Expand All @@ -94,6 +95,8 @@ class GlobalWorkSpace : public WorkSpace{
std::map<std::string, Tensor<Dtype> > _flow_tensor;
// params name type weight/bias name weight/bias
std::map<std::string, OperatorParam<Dtype> > _params;

DataProvider<Dtype> * _data_provider;
DISABLE_COPY_AND_ASSIGN(GlobalWorkSpace);
};

Expand Down
69 changes: 38 additions & 31 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,44 +16,51 @@ int main(int argc, char** argv) {
std::string model_file = "../test.prototxt";
sita::Graph graph(model_file);

gws.build_graph(&graph);
gws.global_init();

std::vector<float> means;
means.push_back(float(0));
sita::MnistDataProvider<float > mnistdp("../data/mnist/train-images-idx3-ubyte",
"../data/mnist/train-labels-idx1-ubyte", means, 1, 1,true);
"../data/mnist/train-labels-idx1-ubyte", means, 10, 10,true);

gws.global_init(&graph, &mnistdp);

int k = 0;

while(k != 10) {
k++;
sita::TempTensor<float> t = gws.new_tensor();
sita::Tensor<float> * a = t.tensor;
a->reshape(9,4,5,6);
// gws.free_tensor(t);


sita::MnistBatch<float> * batch = mnistdp.fetch_batch();
// LOG(INFO)<<batch->label()->cpu_data()[0];

const float *blob_data = batch->data()->cpu_data();
cv::Mat cv_img_original(batch->data()->shape(2), batch->data()->shape(3), CV_32FC1);
for(int b = 0; b<batch->data()->shape(0); b++){
int offset = batch->data()->get_site_by_coord(b, 0, 0, 0);
for(int h = 0; h < batch->data()->shape(2); h++){
for(int w = 0; w < batch->data()->shape(3); w++){
float value = blob_data[offset + h*batch->data()->shape(3) + w];
cv_img_original.at<float>(h, w) = value;
// std::cout<<value;
}
// std::cout<<std::endl;
}
cv::imwrite("vis/" + std::to_string(k)+"_"+ std::to_string(b)+"__"+std::to_string(int(batch->label()->cpu_data()[b]))+ ".jpg", cv_img_original);

}
while(k != 10){
gws.train();
k++;
}
gws.temp_tensor_memory();

gws.temp_memory();
gws.flow_memory();
gws.param_memory();
return 0;
}



//
//k++;
//sita::TempTensor<float> t = gws.new_tensor();
//sita::Tensor<float> * a = t.tensor;
//a->reshape(9,4,5,6);
//// gws.free_tensor(t);
//
//
//sita::MnistBatch<float> * batch = mnistdp.fetch_batch();
//// LOG(INFO)<<batch->label()->cpu_data()[0];
//
//const float *blob_data = batch->data()->cpu_data();
//cv::Mat cv_img_original(batch->data()->shape(2), batch->data()->shape(3), CV_32FC1);
//for(int b = 0; b<batch->data()->shape(0); b++){
//int offset = batch->data()->get_site_by_coord(b, 0, 0, 0);
//for(int h = 0; h < batch->data()->shape(2); h++){
//for(int w = 0; w < batch->data()->shape(3); w++){
//float value = blob_data[offset + h*batch->data()->shape(3) + w];
//cv_img_original.at<float>(h, w) = value;
//// std::cout<<value;
//}
//// std::cout<<std::endl;
//}
//cv::imwrite("vis/" + std::to_string(k)+"_"+ std::to_string(b)+"__"+std::to_string(int(batch->label()->cpu_data()[b]))+ ".jpg", cv_img_original);
//
//}
4 changes: 2 additions & 2 deletions src/sita/dataprovider/mnist_dataprovider.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace sita{
template <typename Dtype>
MnistDataProvider<Dtype>::MnistDataProvider(std::string data_file, std::string label_file,
std::vector<Dtype> means, int batch_size, int thread_num, bool shuffle):DataProvider<Dtype>(data_file,
label_file, means, batch_size, thread_num, shuffle) {
label_file, means, batch_size, thread_num, shuffle, "mnsit") {
LOG(INFO) << "loading mnist dataset using "<< thread_num <<" threads ...";
_threads.resize(thread_num);
_thread_images.resize(thread_num);
Expand Down Expand Up @@ -87,8 +87,8 @@ void MnistDataProviderEntry<Dtype>::internal_thread_entry(){
{
// Interrupted exception is expected on shutdown
}

}

template <typename Dtype>
void MnistDataProviderEntry<Dtype>::load_batch(MnistBatch<Dtype>* batch){
Dtype* data = batch->data()->mutable_cpu_data();
Expand Down
10 changes: 6 additions & 4 deletions src/sita/dlflow/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@ void Graph::graph_symbol_show(){
for(int i = 0; i < _graph.operatordef_size(); i ++) {
LOG(INFO) << "operator name: " << _graph.operatordef(i).name();
LOG(INFO) << "type: " << _graph.operatordef(i).type();
LOG(INFO) << "inputs:";
std::string input_str = "inputs: ";
for(int in = 0; in < _graph.operatordef(i).input_size(); in++ ) {
LOG(INFO) << _graph.operatordef(i).input(in);
input_str += (_graph.operatordef(i).input(in) + " ");
}
LOG(INFO) << "outputs:";
LOG(INFO) << input_str;
std::string output_str = "outputs: ";
for(int ou = 0; ou < _graph.operatordef(i).output_size(); ou++){
LOG(INFO) << _graph.operatordef(i).output(ou) << " ";
output_str += (_graph.operatordef(i).output(ou)+ " ");
}
LOG(INFO)<< output_str;
LOG(INFO)<<"-------------------------";
}
}
Expand Down
18 changes: 12 additions & 6 deletions src/sita/dlflow/operators/convolution_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@ template<typename Dtype>
void ConvolutionOp<Dtype>::init(){
// params
std::vector<int> shape;
shape.push_back(5);
shape.push_back(6);
shape.push_back(7);
shape.push_back(8);
this->init_param("convolution_weight", shape, this->_param_configs[0]);
this->init_param("convolution_bias", shape, this->_param_configs[1]);
// shape.push_back();
// shape.push_back(0);
// shape.push_back(0);
// shape.push_back(0);
// this->init_param("convolution_weight", shape, this->_param_configs[0]);
// this->init_param("convolution_bias", shape, this->_param_configs[1]);
}

template<typename Dtype>
// Forward pass of the convolution operator.
// TODO(review): not implemented yet — the intended input/param fetches are
// sketched below; this is currently a no-op stub.
void ConvolutionOp<Dtype>::forward(){
    // Tensor<Dtype> * data = this->fetch_input(this->_inputs[0]);
    // Tensor<Dtype> * add_weight = this->fetch_param("add_weight");
    //LOG(INFO)<<_add_op_param.kernel_h();
}


INSTANTIATE_CLASS(ConvolutionOp);
Expand Down
Loading

0 comments on commit 7d2a1ab

Please sign in to comment.