Skip to content

Commit

Permalink
fix bug in temp tensor strategy
Browse files Browse the repository at this point in the history
  • Loading branch information
unsky committed Aug 16, 2018
1 parent 3ef31f3 commit cbf2340
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 40 deletions.
7 changes: 7 additions & 0 deletions include/sita/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,19 @@
#include "sita/proto/sita.h"
namespace sita{

// Number of temp-tensor slots pre-allocated by GlobalWorkSpace at construction
// (its _temp_tensor / _temp_tensor_control vectors are resized to this, see
// workspace.h); pre-sizing keeps element addresses stable, since handed-out
// Tensor* pointers would dangle if the vector reallocated on push_back.
const int PRE_TEMP_TENSOR_NUM = 64;
// Handle returned by GlobalWorkSpace::new_tensor().
// key:    slot index inside the workspace pools; pass the whole handle back
//         to free_tensor() to mark the slot reusable.
// tensor: non-owning pointer into the workspace's _temp_tensor storage —
//         valid only while the owning GlobalWorkSpace is alive.
template <typename Dtype>
struct TempTensor{
int key;
Tensor<Dtype> * tensor;
};

// Per-operator bundle of learnable parameters, stored in
// GlobalWorkSpace::_params keyed by op name (see init_param in workspace.cpp).
template <typename Dtype>
struct OperatorParam{
// operator type string — presumably copied from the op_type argument of
// GlobalWorkSpace::init_param; assignment not visible here, confirm.
std::string type;
// param name (e.g. "weight"/"bias") -> the parameter tensor itself
std::map<std::string, Tensor<Dtype> > params;
// param name -> filler configuration used to initialize that tensor
std::map<std::string, FillerParameter> fillers;
// param name -> whether the filler has been applied yet; init_param
// records false for every freshly created parameter.
std::map<std::string, bool> is_inited;
};

};
Expand Down
28 changes: 11 additions & 17 deletions include/sita/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,30 @@ template <typename Dtype>
class GlobalWorkSpace : public WorkSpace{

public:
GlobalWorkSpace():WorkSpace(){};
// Pre-size the temp-tensor pools to PRE_TEMP_TENSOR_NUM fixed slots so that
// new_tensor() can hand out stable Tensor* pointers into _temp_tensor
// (growing the vector later would invalidate them). _temp_tensor_num tracks
// how many slots have actually been handed out.
// NOTE(review): new_tensor() indexes _temp_tensor[_temp_tensor_num] with no
// bounds check — more than PRE_TEMP_TENSOR_NUM live temp tensors would write
// out of range; confirm intended capacity limit.
GlobalWorkSpace():WorkSpace(){
_temp_tensor_control.resize(PRE_TEMP_TENSOR_NUM);
_temp_tensor.resize(PRE_TEMP_TENSOR_NUM);
_temp_tensor_num = 0;
};
~GlobalWorkSpace(){};

//temp tensor
std::pair<int, Tensor<Dtype> * > fetch_temp_tensor();

void release_temp_tensor(int released_id);

float temp_tensor_memory_size();
TempTensor<Dtype> new_tensor();
void free_tensor(TempTensor<Dtype> t);
void temp_tensor_memory();

//flow tensor
void init_input(std::string name);

void init_output(std::string name);

Tensor<Dtype>* fetch_input(std::string name);

Tensor<Dtype>* fetch_output(std::string name);

std::string flow_tensor_list();


//params
void init_param(std::string op_name, std::string op_type, std::string param_name, std::vector<int> shape, Filler);

void init_param(std::string op_name, std::string op_type, std::string param_name, std::vector<int> shape, FillerParameter filler);
Tensor<Dtype>* fetch_param(std::string op_name, std::string param_name);

std::string param_list();



//graph
inline void build_graph(Graph * graph){
_graph = graph;
Expand All @@ -92,12 +85,13 @@ class GlobalWorkSpace : public WorkSpace{
// temp tensor bookkeeping: bool is true while the slot is in use, false once released
std::vector<std::pair<Tensor<Dtype> *, bool> > _temp_tensor_control;
std::vector<Tensor<Dtype> > _temp_tensor;
int _temp_tensor_num;
//graph
Graph * _graph;
std::vector<boost::shared_ptr<Operator<Dtype> > > _ops;
// input/output name
std::map<std::string, Tensor<Dtype> > _flow_tensor;
// params name type weight/bias name weight/bias
// params: op name -> OperatorParam (op type, plus per-name weight/bias tensors and fillers)
std::map<std::string, OperatorParam<Dtype> > _params;
DISABLE_COPY_AND_ASSIGN(GlobalWorkSpace);
};
Expand Down
12 changes: 9 additions & 3 deletions main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,14 @@ int main(int argc, char** argv) {

int k = 0;

while(k != 40) {
while(k != 10) {
k++;
sita::TempTensor<float> t = gws.new_tensor();
sita::Tensor<float> * a = t.tensor;
a->reshape(9,4,5,6);
// gws.free_tensor(t);


sita::MnistBatch<float> * batch = mnistdp.fetch_batch();
// LOG(INFO)<<batch->label()->cpu_data()[0];

Expand All @@ -43,11 +49,11 @@ int main(int argc, char** argv) {
// std::cout<<std::endl;
}
cv::imwrite("vis/" + std::to_string(k)+"_"+ std::to_string(b)+"__"+std::to_string(int(batch->label()->cpu_data()[b]))+ ".jpg", cv_img_original);
// LOG(INFO) << gws.temp_tensor_memory_size();

}
gws.train();
}

gws.temp_tensor_memory();

return 0;
}
51 changes: 31 additions & 20 deletions src/sita/workspace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,58 @@ void WorkSpace::set_device(int gpu_id){
}

template <typename Dtype>
std::pair<int, Tensor<Dtype> * > GlobalWorkSpace<Dtype>::fetch_temp_tensor() {
if (_temp_tensor.size() == 0) {
TempTensor<Dtype> GlobalWorkSpace<Dtype>::new_tensor() {
TempTensor<Dtype> t;
if (_temp_tensor_num == 0) {
Tensor<Dtype> temp_tensor;
_temp_tensor.push_back(temp_tensor);
_temp_tensor_control.push_back(std::make_pair(&(_temp_tensor[0]), true));
return std::make_pair(0, &(_temp_tensor[0]));
_temp_tensor[0] = temp_tensor;
_temp_tensor_control[0] = std::make_pair(&(_temp_tensor[0]), true);
_temp_tensor_num++;
t.key = 0;
t.tensor = &(_temp_tensor[0]);
return t;
} else {
int released_id = -1;
for (int i = 0; i < _temp_tensor.size(); i++) {
for (int i = 0; i < _temp_tensor_num; i++) {
if (_temp_tensor_control[i].second == false) {
released_id = i;
}
}
if (released_id == -1) {
Tensor<Dtype> temp_tensor;
_temp_tensor.push_back(temp_tensor);
_temp_tensor_control.push_back(std::make_pair(&(_temp_tensor[_temp_tensor.size() - 1]), true));
return std::make_pair(0, &(_temp_tensor[_temp_tensor.size()-1]));

_temp_tensor[_temp_tensor_num] = temp_tensor;
_temp_tensor_control[_temp_tensor_num]= std::make_pair(&(_temp_tensor[_temp_tensor_num]), true);
t.key = 0;
t.tensor = &(_temp_tensor[_temp_tensor_num]);
_temp_tensor_num++;
return t;
} else {
_temp_tensor_control[released_id].second = true;
return std::make_pair(released_id, &_temp_tensor[released_id]);
t.key = released_id;
t.tensor = &_temp_tensor[released_id];
return t;
}
}
}

template <typename Dtype>
void GlobalWorkSpace<Dtype>::release_temp_tensor(int released_id) {
_temp_tensor_control[released_id].second = false;
void GlobalWorkSpace<Dtype>::free_tensor(TempTensor<Dtype> t) {
_temp_tensor_control[t.key].second = false;
}

template <typename Dtype>
float GlobalWorkSpace<Dtype>::temp_tensor_memory_size(){
int memory_size = 0;
for(int i = 0; i < _temp_tensor.size(); i++){
void GlobalWorkSpace<Dtype>::temp_tensor_memory(){
float memory_size = 0;
for(int i = 0; i < _temp_tensor_num; i++){
memory_size += (_temp_tensor[i].count() * sizeof(Dtype));
}
return memory_size/(1024 * 1024 * 8);
LOG(INFO) << "the fact of temp tensor being used: "<<"[number:" << _temp_tensor_num << " " << "memory size: " <<
std::to_string(memory_size/(1024 * 8))<<" KB].";

return;
}

// flow tensor
//std::map<std::string, Tensor<Dtype> > _flow_tensor;
template <typename Dtype>
void GlobalWorkSpace<Dtype>::init_input(std::string name){
bool has_flow_tensor = false;
Expand Down Expand Up @@ -115,7 +125,7 @@ Tensor<Dtype>* GlobalWorkSpace<Dtype>::fetch_output(std::string name){
return &(it->second);
}
}
LOG(FATAL) << "no this onput in flow tensors, do you have init it?" << flow_tensor_list();
LOG(FATAL) << "no this output in flow tensors, do you have init it?" << flow_tensor_list();
}

template <typename Dtype>
Expand All @@ -129,7 +139,7 @@ std::string GlobalWorkSpace<Dtype>::flow_tensor_list(){
}

template <typename Dtype>
void GlobalWorkSpace<Dtype>::init_param(std::string op_name, std::string op_type, std::string param_name, std::vector<int> shape, Filler filler){
void GlobalWorkSpace<Dtype>::init_param(std::string op_name, std::string op_type, std::string param_name, std::vector<int> shape, FillerParameter filler){
bool has_param = false;
bool has_op_name = false;
for(auto i = _params.begin(); i != _params.end(); i++){
Expand All @@ -153,6 +163,7 @@ void GlobalWorkSpace<Dtype>::init_param(std::string op_name, std::string op_type
Tensor<Dtype> t(shape);
_params[op_name].params[param_name] = t;
_params[op_name].fillers[param_name] = filler;
_params[op_name].is_inited[param_name] = false;

}
}
Expand Down

0 comments on commit cbf2340

Please sign in to comment.