
Commit

modify operator
unsky committed Aug 2, 2018
1 parent e7fa2b5 commit 8705f84
Showing 17 changed files with 309 additions and 205 deletions.
325 changes: 167 additions & 158 deletions .idea/workspace.xml

Large diffs are not rendered by default.

Binary file modified build/CMakeFiles/sita.dir/main.cpp.o
Binary file not shown.
Binary file modified build/CMakeFiles/sita.dir/src/sita/stuff/graph.cpp.o
Binary file not shown.
Binary file modified build/CMakeFiles/sita.dir/src/sita/stuff/operator.cpp.o
Binary file not shown.
Binary file not shown.
Binary file modified build/CMakeFiles/sita.dir/src/sita/stuff/workspace.cpp.o
Binary file not shown.
Binary file modified build/sita
Binary file not shown.
13 changes: 12 additions & 1 deletion include/sita/stuff/operator.h
@@ -22,12 +22,23 @@ class Operator{
public:
Operator(const OperatorDef& opdef, GlobalWorkSpace<Dtype> *gws):_opdef(opdef),_gws(gws){}
~Operator(){}
virtual void init();
void setup();
void init_param(std::string param_name, std::vector<int> shape);

Tensor<Dtype> * fetch_input(std::string name);
Tensor<Dtype> * fetch_output(std::string name);
Tensor<Dtype> * fetch_param(std::string name);

virtual void init(){};
virtual void forward(){};
virtual void backward(){};
protected:
GlobalWorkSpace<Dtype> *_gws;
OperatorDef _opdef;
Filler _filler;
std::vector<std::string> _inputs;
std::vector<std::string> _outputs;
std::vector<std::string> _params;
};

}//namespace
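For context, a minimal sketch (not part of this commit) of how a concrete operator would use the refactored base class above: setup() now wires inputs, outputs, and the filler, so a subclass only implements init()/forward()/backward() and goes through init_param()/fetch_*() for its tensors. MyOp and its tensor names are hypothetical; the pattern mirrors the add_op.cpp changes further down.

#include <vector>
#include "sita/stuff/operator.h"
#include "sita/stuff/registry.h"

namespace sita{

template<typename Dtype>
class MyOp: public Operator<Dtype>{
public:
    MyOp(const OperatorDef& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef, gws){}
    ~MyOp(){}

    void init(){
        // register a learnable parameter; the base class records the name and
        // forwards the shape and filler to the workspace
        std::vector<int> shape;
        shape.push_back(3);
        shape.push_back(3);
        this->init_param("my_weight", shape);
    }
    void forward(){
        // look tensors up by name through the base-class helpers
        Tensor<Dtype> *in = this->fetch_input(this->_inputs[0]);
        Tensor<Dtype> *w = this->fetch_param("my_weight");
        (void)in; (void)w; // actual computation omitted in this sketch
    }
    void backward(){}
};

INSTANTIATE_CLASS(MyOp);      // registration mirrors add_op.cpp
REGISTER_OPERATOR_CLASS(MyOp);

}//namespace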
9 changes: 1 addition & 8 deletions include/sita/stuff/operators/add_op.h
@@ -13,10 +13,6 @@ template<typename Dtype>
class AddOp: public Operator<Dtype>{
public:
AddOp(const OperatorDef& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
if(_has_param){
_filler = opdef.param.filler;
}
_add_op_param = opdef.param.add_op_param;
}
~AddOp(){};
void init();
@@ -26,10 +22,7 @@ class AddOp: public Operator<Dtype>{

protected:
bool _has_param = true;
Filler _filler;
AddOpParameter _add_op_param;
std::vector<std::string> _inputs;
std::vector<std::string> _outputs;


};
}//namespace
40 changes: 40 additions & 0 deletions include/sita/stuff/operators/data_test_op.h
@@ -0,0 +1,40 @@
//
// Created by cs on 02/08/18.
//

#ifndef SITA_STUFF_DATA_TEST_OP_H
#define SITA_STUFF_DATA_TEST_OP_H
#include <string>
#include <vector>
#include "sita/stuff/operator.h"
#include "sita/stuff/registry.h"
namespace sita{

template<typename Dtype>
class DataTestOp: public Operator<Dtype>{
public:
DataTestOp(const OperatorDef& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
if(_has_param){
_filler = opdef.param.filler;
}
_data_test_op_param = opdef.param.data_test_param;
}
~DataTestOp(){};
void init();
void forward();
void backward();
bool inline has_param(){ return _has_param;}

protected:
bool _has_param = false;
Filler _filler;
DataTestParameter _data_test_op_param;
std::vector<std::string> _inputs;
std::vector<std::string> _outputs;

};
}//namespace



#endif //SITA_STUFF_DATA_TEST_OP_H
8 changes: 6 additions & 2 deletions include/sita/stuff/sita_parameter.h
@@ -9,12 +9,16 @@
namespace sita {

struct AddOpParameter {
int stride_h;
int stride_w;
int stride_h = 1;
int stride_w = 1;
};
struct DataTestParameter{
int batch_size = 1;
};

struct SitaParameter {
AddOpParameter add_op_param;
DataTestParameter data_test_param;
Filler filler;
};
struct OperatorDef{
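A hedged sketch of how the new defaults are consumed when building an OperatorDef: the OperatorDef field list is truncated in this diff, so the members used here (name, type, inputs, outputs, param) are inferred from their use in operator.cpp and graph.cpp and may be incomplete; make_add_opdef is a hypothetical helper, not part of the repository.

#include "sita/stuff/sita_parameter.h"

namespace sita{

// hypothetical helper: builds an OperatorDef relying on the new parameter defaults
inline OperatorDef make_add_opdef(){
    OperatorDef opdef;
    opdef.name = "add1";
    opdef.type = "AddOp";
    opdef.inputs.push_back("data");
    opdef.outputs.push_back("add1_out");
    // stride_h/stride_w now default to 1 and filler.type to "gauss",
    // so only overrides need to be written out
    opdef.param.add_op_param.stride_h = 2;
    return opdef;
}

}//namespace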
2 changes: 1 addition & 1 deletion include/sita/stuff/types.h
@@ -10,7 +10,7 @@
#include "tensor.h"
namespace sita{
struct Filler{
std::string type;
std::string type = "gauss";
};

template <typename Dtype>
5 changes: 2 additions & 3 deletions src/sita/stuff/graph.cpp
@@ -33,8 +33,8 @@ void Graph::append(std::string op_type, std::string name, std::vector<std::strin
}

void Graph::graph_symbol_show(){
LOG(INFO) << "#############################";
LOG(INFO) << "" << _graph_sym.graph_name << "结构如下";
LOG(INFO) << "graph " << _graph_sym.graph_name << ":";
LOG(INFO)<<"-------------------------";
for(int i = 0; i < _graph_sym.ops.size(); i ++) {
LOG(INFO) << "op name: " << _graph_sym.ops[i].name;
LOG(INFO) << "type: " << _graph_sym.ops[i].type;
@@ -48,7 +48,6 @@ }
}
LOG(INFO)<<"-------------------------";
}
LOG(INFO) << "#############################";
}


63 changes: 62 additions & 1 deletion src/sita/stuff/operator.cpp
@@ -5,10 +5,71 @@
namespace sita{

template<typename Dtype>
void Operator<Dtype>::init(){
void Operator<Dtype>::setup(){
//inputs and outputs
_inputs.clear();
for(int i = 0; i < _opdef.inputs.size(); i++){
_gws->init_input(_opdef.inputs[i]);
_inputs.push_back(_opdef.inputs[i]);
}
_outputs.clear();
for(int i = 0; i < _opdef.outputs.size(); i++){
_gws->init_output(_opdef.outputs[i]);
_outputs.push_back(_opdef.outputs[i]);
}
_filler = _opdef.param.filler;
_params.clear();

}

template<typename Dtype>
void Operator<Dtype>::init_param(std::string param_name, std::vector<int> shape){
_gws->init_param(_opdef.name, _opdef.type, param_name, shape, _filler);
_params.push_back(param_name);

}

template<typename Dtype>
Tensor<Dtype> * Operator<Dtype>::fetch_input(std::string name){
bool has_input = false;
for(int i = 0; i < _inputs.size(); i++)
if(_inputs[i] == name)
has_input = true;
if(has_input) {
return this->_gws->fetch_input(name);
}else{
LOG(FATAL) << "no " << name <<" in the inputs of " << _opdef.name;
}

}

template<typename Dtype>
Tensor<Dtype> * Operator<Dtype>::fetch_output(std::string name){
bool has_output = false;
for(int i = 0; i < _outputs.size(); i++)
if(_outputs[i] == name)
has_output = true;
if(has_output) {
return this->_gws->fetch_output(name);
}else{
LOG(FATAL) << "no " << name <<" in the outputs of " << _opdef.name;
}
}

template<typename Dtype>
Tensor<Dtype> * Operator<Dtype>::fetch_param(std::string name){
bool has_param = false;
for(int i = 0; i < _params.size(); i++)
if(_params[i] == name)
has_param = true;

if(has_param) {
return this->_gws->fetch_param(_opdef.name, name);
}else{
LOG(FATAL) << "no " << name << " in the params of " << _opdef.name;
}

}
INSTANTIATE_CLASS(Operator);

}//namespace;
24 changes: 5 additions & 19 deletions src/sita/stuff/operators/add_op.cpp
@@ -7,39 +7,25 @@ namespace sita{
template<typename Dtype>
void AddOp<Dtype>::init(){

//inputs and outputs
_inputs.clear();
for(int i = 0; i < this->_opdef.inputs.size(); i++){
this->_gws->init_input(this->_opdef.inputs[i]);
_inputs.push_back(this->_opdef.inputs[i]);
}
_outputs.clear();
for(int i = 0; i < this->_opdef.outputs.size(); i++){
this->_gws->init_output(this->_opdef.outputs[i]);
_outputs.push_back(this->_opdef.outputs[i]);
}

// params
std::vector<int> shape;
shape.push_back(5);
shape.push_back(6);
shape.push_back(7);
shape.push_back(8);
this->_gws->init_param(this->_opdef.name,this->_opdef.type, "add_weight", shape, _filler);
this->_gws->init_param(this->_opdef.name,this->_opdef.type, "add_bias", shape, _filler);
this->init_param("add_weight", shape);
this->init_param("add_bias", shape);
}

template<typename Dtype>
void AddOp<Dtype>::forward(){
Tensor<Dtype> * data = this->_gws->fetch_input(this->_opdef.inputs[0]);
Tensor<Dtype> * add_weight = this->_gws->fetch_param(this->_opdef.name, "add_weight");
// Tensor<Dtype> * bias = this->_gws->fetch_param(this->_opdef.name(), "add_bias");
// Tensor<Dtype> * output1 = this->_gws->fetch_output("aaa");
Tensor<Dtype> * data = this->fetch_input(this->_inputs[0]);
Tensor<Dtype> * add_weight = this->fetch_param("add_weight");


};
template<typename Dtype>
void AddOp<Dtype>::backward(){
LOG(INFO)<<"NUM: ";
}
INSTANTIATE_CLASS(AddOp);
REGISTER_OPERATOR_CLASS(AddOp);
4 changes: 4 additions & 0 deletions src/sita/stuff/operators/data_test_op.cpp
@@ -0,0 +1,4 @@
//
// Created by cs on 02/08/18.
//

21 changes: 9 additions & 12 deletions src/sita/stuff/workspace.cpp
@@ -11,22 +11,18 @@ void WorkSpace::device_query(){
cudaError = cudaGetDeviceCount(&deviceCount);
for (int i = 0; i < deviceCount; i++) {
cudaError = cudaGetDeviceProperties(&deviceProp, i);
LOG(INFO) << "===============================================================";
LOG(INFO) << "设备 " << i << " 的主要属性: ";
LOG(INFO) << "设备显卡型号: " << deviceProp.name;
LOG(INFO) << "设备全局内存总量(以MB为单位): " << deviceProp.totalGlobalMem / 1024 / 1024;
LOG(INFO) << "设备上一个线程块(Block)中可用的最大共享内存(以KB为单位): " << deviceProp.sharedMemPerBlock / 1024;
LOG(INFO) << "设备上一个线程块(Block)种可用的32位寄存器数量: " << deviceProp.regsPerBlock;
LOG(INFO) << "设备上一个线程块(Block)可包含的最大线程数量: " << deviceProp.maxThreadsPerBlock;
LOG(INFO) << "设备的计算功能集(Compute Capability)的版本号: " << deviceProp.major << "." << deviceProp.minor;
LOG(INFO) << "设备上多处理器的数量: " << deviceProp.multiProcessorCount;
}
LOG(INFO) << "===============================================================";
LOG(INFO) << "device " << i << " properties ";
LOG(INFO) << "device name: " << deviceProp.name;
LOG(INFO) << "Total memory: " << deviceProp.totalGlobalMem / 1024 / 1024;
LOG(INFO) << "The max Threads: " << deviceProp.maxThreadsPerBlock;
LOG(INFO) << "The Compute Capability version: " << deviceProp.major << "." << deviceProp.minor;
LOG(INFO) << "The number of multi-processor in this device: " << deviceProp.multiProcessorCount;
}
}
void WorkSpace::set_device(int gpu_id){
_gpu_id = gpu_id;
CUDA_CHECK(cudaSetDevice(_gpu_id));
LOG(INFO) << "正在使用GPU:" << _gpu_id << "进行计算...";
LOG(INFO) << "using GPU:" << _gpu_id;
}

template <typename Dtype>
@@ -194,6 +190,7 @@ void GlobalWorkSpace<Dtype>::global_init(){
GlobalWorkSpace<Dtype> *gws = this;
OperatorDef opdef = _graph->graph_sym()->ops[i];
boost::shared_ptr<Operator<Dtype> > op = OperatorRegistry<Dtype>::CreateOperator(opdef,gws);
op->setup();
op->init();
_ops.push_back(op);
}
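For reference, a sketch of the per-operator lifecycle that global_init() now drives, factored into a hypothetical helper (build_op is not part of the repository): the registry constructs the operator, setup() does the base-class wiring added in this commit, init() lets the concrete operator register its parameters, and forward()/backward() would follow at run time.

#include <boost/shared_ptr.hpp>
#include "sita/stuff/operator.h"
#include "sita/stuff/registry.h"

namespace sita{

template<typename Dtype>
boost::shared_ptr<Operator<Dtype> > build_op(const OperatorDef &opdef, GlobalWorkSpace<Dtype> *gws){
    boost::shared_ptr<Operator<Dtype> > op = OperatorRegistry<Dtype>::CreateOperator(opdef, gws);
    op->setup();  // wires inputs/outputs and the filler in the Operator base class
    op->init();   // concrete op registers its params via init_param()
    return op;
}

}//namespace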
