From ec705918e7bea8430692895edd395e2c9084d112 Mon Sep 17 00:00:00 2001
From: unsky <2081264@qq.com>
Date: Tue, 25 Sep 2018 20:37:40 +0800
Subject: [PATCH] cuDNN convolution support

---
 CMakeLists.txt                               |    7 +-
 include/sita/context.h                       |   28 +-
 include/sita/dlflow/operator.h               |    2 +
 include/sita/dlflow/operators/add_op.h       |   31 -
 .../{convolution_op.h => convolution.h}      |   18 +-
 include/sita/dlflow/operators/data_test_op.h |   40 -
 include/sita/macros.h                        |   12 +
 include/sita/tensor.h                        |    3 +
 include/sita/workspace.h                     |    1 -
 src/sita/context.cpp                         |   19 +
 src/sita/dlflow/operator.cpp                 |    2 +-
 src/sita/dlflow/operators/add_op.cpp         |   30 -
 src/sita/dlflow/operators/convolution.cpp    |  334 ++++++
 src/sita/dlflow/operators/convolution.cu     |  123 ++
 src/sita/dlflow/operators/convolution_op.cpp |  157 ---
 src/sita/dlflow/operators/data_test_op.cpp   |    4 -
 src/sita/proto/sita_operators.pb.cc          | 1008 ++++------------
 src/sita/proto/sita_operators.pb.h           |  820 ++++----------
 src/sita/proto/sita_operators.proto          |   40 +-
 src/sita/workspace.cpp                       |   15 +-
 test.prototxt                                |   28 +-
 tools/{main.cpp => sita.cpp}                 |    8 +-
 22 files changed, 1005 insertions(+), 1725 deletions(-)
 delete mode 100644 include/sita/dlflow/operators/add_op.h
 rename include/sita/dlflow/operators/{convolution_op.h => convolution.h} (74%)
 delete mode 100644 include/sita/dlflow/operators/data_test_op.h
 create mode 100644 src/sita/context.cpp
 delete mode 100644 src/sita/dlflow/operators/add_op.cpp
 create mode 100644 src/sita/dlflow/operators/convolution.cpp
 create mode 100644 src/sita/dlflow/operators/convolution.cu
 delete mode 100644 src/sita/dlflow/operators/convolution_op.cpp
 delete mode 100644 src/sita/dlflow/operators/data_test_op.cpp
 rename tools/{main.cpp => sita.cpp} (94%)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 99aeb354..15830d16 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -44,11 +44,12 @@ execute_process(COMMAND protoc --cpp_out=${PROJECT_SOURCE_DIR}/src/sita/proto/
 set(SITA_FILES
         src/sita/proto/sita_operators.pb.cc
         src/sita/proto/sita_utils.pb.cc
-        tools/main.cpp
+        tools/sita.cpp
         src/sita/memory_control.cpp
         src/sita/tensor.cpp
-        # src/sita/dlflow/operators/add_op.cpp
-        src/sita/dlflow/operators/convolution_op.cpp
+        src/sita/context.cpp
+        src/sita/dlflow/operators/convolution.cpp
+        src/sita/dlflow/operators/convolution.cu
         src/sita/workspace.cpp
         src/sita/dlflow/operator.cpp
         src/sita/dlflow/graph.cpp
diff --git a/include/sita/context.h b/include/sita/context.h
index 75bcab9f..b19cacc9 100644
--- a/include/sita/context.h
+++ b/include/sita/context.h
@@ -12,25 +12,21 @@ namespace sita{
-template <typename Dtype> class dataType;
-template<> class dataType<float> {
+template <typename Dtype> class CudnnDataType;
+template<> class CudnnDataType<float> {
 public:
     static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-    static float oneval = 1.0;
-    static float zeroval = 0.0;
-    static const void *one = static_cast<const void *>(&dataType<float>::oneval);
-    static const void *zero = static_cast<const void *>(&dataType<float>::zeroval);
+    static float oneval, zeroval;
+    static const void *one, *zero;
 };
-template<> class dataType<double> {
+template<> class CudnnDataType<double> {
 public:
     static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-    static double oneval = 1.0;
-    static double zeroval = 0.0;
-    static const void *one = static_cast<const void *>(&dataType<double>::oneval);
-    static const void *zero = static_cast<const void *>(&dataType<double>::zeroval);
+    static double oneval, zeroval;
+    static const void *one, *zero;
 };
- class Context{
+class Context{
 public:
     Context() {}
     ~Context() {}
@@ -111,7 +107,7 @@ template<> class dataType<double> {
     inline static void set_tensor4d_descriptor(cudnnTensorDescriptor_t*
         desc, int n, int c, int h, int w, int stride_n, int stride_c, int stride_h, int stride_w) {
-        CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
+        CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CudnnDataType<Dtype>::type,
             n, c, h, w, stride_n, stride_c, stride_h, stride_w));
     }
@@ -132,10 +128,10 @@ template<> class dataType<double> {
         CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
 #if CUDNN_VERSION_MIN(5, 0, 0)
-        CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
+        CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, CudnnDataType<Dtype>::type,
             CUDNN_TENSOR_NCHW, n, c, h, w));
 #else
-        CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
+        CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, CudnnDataType<Dtype>::type,
             CUDNN_TENSOR_NCHW, n, c, h, w));
 #endif
     }
@@ -152,7 +148,7 @@ template<> class dataType<double> {
 #if CUDNN_VERSION_MIN(6, 0, 0)
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, pad_h, pad_w,
             stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
-            dataType<Dtype>::type));
+            CudnnDataType<Dtype>::type));
 #else
         CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv, pad_h, pad_w,
             stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
diff --git a/include/sita/dlflow/operator.h b/include/sita/dlflow/operator.h
index 4cd3c6f8..8c87620f 100644
--- a/include/sita/dlflow/operator.h
+++ b/include/sita/dlflow/operator.h
@@ -43,8 +43,10 @@ class Operator{
     std::vector _params;
     std::map<std::string, std::vector<int> > _input_shapes;
     std::map<std::string, std::vector<int> > _output_shapes;
+    std::map<std::string, std::vector<int> > _param_shapes;
     bool _is_shared;
     std::vector<std::pair<std::string, std::string> > _shared_param_pairs;
+    bool _gradient_block;
 };
 }//namespace
diff --git a/include/sita/dlflow/operators/add_op.h b/include/sita/dlflow/operators/add_op.h
deleted file mode 100644
index 1efac297..00000000
--- a/include/sita/dlflow/operators/add_op.h
+++ /dev/null
@@ -1,31 +0,0 @@
-//---------------------------------
-//write by unsky
-//---------------------------------
-#ifndef SITA_STUFF_OPERATORS_ADD_OP_H
-#define SITA_STUFF_OPERATORS_ADD_OP_H
-#include
-#include
-#include "sita/dlflow/operator.h"
-#include "sita/dlflow/registry.h"
-namespace sita{
-
-template <typename Dtype>
-class AddOp: public Operator<Dtype>{
-public:
-    AddOp(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
-        _add_op_param = opdef.add_op_param();
-    }
-    ~AddOp(){};
-    void init();
-    void forward();
-    void backward();
-    bool inline has_param(){ return _has_param;}
-
-protected:
-    bool _has_param = true;
-    AddOpParameter _add_op_param;
-
-
-};
-}//namespace
-#endif //SITA_STUFF_OPERATORS_ADD_OP_H
diff --git a/include/sita/dlflow/operators/convolution_op.h b/include/sita/dlflow/operators/convolution.h
similarity index 74%
rename from include/sita/dlflow/operators/convolution_op.h
rename to include/sita/dlflow/operators/convolution.h
index 12f0f0b3..79f2af65 100644
--- a/include/sita/dlflow/operators/convolution_op.h
+++ b/include/sita/dlflow/operators/convolution.h
@@ -2,29 +2,29 @@
 // Created by unsky on 15/08/18.
 //
-#ifndef SITA_DLFLOW_CONVOLUTION_OP_H
-#define SITA_DLFLOW_CONVOLUTION_OP_H
+#ifndef SITA_DLFLOW_CONVOLUTION_H
+#define SITA_DLFLOW_CONVOLUTION_H
 #include "sita/dlflow/operator.h"
 #include "sita/proto/sita.h"
 namespace sita{
 template <typename Dtype>
-class ConvolutionOp: public Operator<Dtype>{
+class Convolution: public Operator<Dtype>{
 public:
-    ConvolutionOp(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
-        _op_param = opdef.convolution_op_param();
+    Convolution(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
+        _op_param = opdef.convolution_param();
     }
-    ~ConvolutionOp(){};
+    ~Convolution();
     void init();
     void infer_shape();
     void forward();
-    void backward(){};
+    void backward();
     bool inline has_param(){ return _has_param;}

 protected:
     bool _has_param = true;
-    ConvolutionOpParameter _op_param;
+    ConvolutionParameter _op_param;

 private:
     bool _handles_setup;
@@ -54,4 +54,4 @@ class ConvolutionOp: public Operator<Dtype>{
 };
 }
-#endif //SITA_DLFLOW_CONVOLUTION_OP_H
+#endif //SITA_DLFLOW_CONVOLUTION_H
diff --git a/include/sita/dlflow/operators/data_test_op.h b/include/sita/dlflow/operators/data_test_op.h
deleted file mode 100644
index d28306c5..00000000
--- a/include/sita/dlflow/operators/data_test_op.h
+++ /dev/null
@@ -1,40 +0,0 @@
-//
-// Created by cs on 02/08/18.
-//
-
-#ifndef SITA_STUFF_DATA_TEST_OP_H
-#define SITA_STUFF_DATA_TEST_OP_H
-#include
-#include
-#include "sita/stuff/operator.h"
-#include "sita/stuff/registry.h"
-namespace sita{
-
-    template <typename Dtype>
-    class DataTestOp: public Operator<Dtype>{
-    public:
-        DataTestOp(const OperatorDef& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
-            if(_has_param){
-                _filler = opdef.param.filler;
-            }
-            _data_test_op_param = opdef.param.data_test_op_param;
-        }
-        ~DataTestOp(){};
-        void init();
-        void forward();
-        void backward();
-        bool inline has_param(){ return _has_param;}
-
-    protected:
-        bool _has_param = false;
-        Filler _filler;
-        AddOpParameter _data_op_param;
-        std::vector<std::string> _inputs;
-        std::vector<std::string> _outputs;
-
-    };
-}//namespace
-
-
-
-#endif //SITA_DATA_TEST_OP_H
diff --git a/include/sita/macros.h b/include/sita/macros.h
index 22bdd20d..d656099f 100644
--- a/include/sita/macros.h
+++ b/include/sita/macros.h
@@ -80,5 +80,17 @@ private:\
 template class classname; \
+
+#define INSTANTIATE_OPERATOR_GPU_FORWARD(classname) \
+    template void classname<float>::forward(); \
+    template void classname<double>::forward();
+
+#define INSTANTIATE_OPERATOR_GPU_BACKWARD(classname) \
+    template void classname<float>::backward(); \
+    template void classname<double>::backward()
+
+#define INSTANTIATE_OPERATOR_GPU_FUNCS(classname) \
+    INSTANTIATE_OPERATOR_GPU_FORWARD(classname); \
+    INSTANTIATE_OPERATOR_GPU_BACKWARD(classname)
 }//namespace
 #endif //SITA_MACROS_H
diff --git a/include/sita/tensor.h b/include/sita/tensor.h
index db3cff35..dc9572ab 100644
--- a/include/sita/tensor.h
+++ b/include/sita/tensor.h
@@ -77,6 +77,9 @@ class Tensor {
     inline const int dim() const {
         return _dim;
     }
+    inline const int size(){
+        return _count * sizeof(Dtype);
+    }
     inline const int count() const {
         return _count;
diff --git a/include/sita/workspace.h b/include/sita/workspace.h
index b8cfc7dd..6ad30d61 100644
--- a/include/sita/workspace.h
+++ b/include/sita/workspace.h
@@ -79,7 +79,6 @@ class GlobalWorkSpace : public WorkSpace<Dtype>{
     }
     void global_init(Graph * graph, DataProvider * data_provider);
-    void infer_shape();
     void forward();
     void backward();
     void train();
diff --git a/src/sita/context.cpp b/src/sita/context.cpp
new file mode 100644
index 00000000..fb6c0c61
--- /dev/null
+++ b/src/sita/context.cpp
@@ -0,0 +1,19 @@
+#include "sita/context.h"
+
+namespace sita{
+
+float CudnnDataType<float>::oneval = 1.0;
+float CudnnDataType<float>::zeroval = 0.0;
+const void* CudnnDataType<float>::one =
+    static_cast<const void *>(&CudnnDataType<float>::oneval);
+const void* CudnnDataType<float>::zero =
+    static_cast<const void *>(&CudnnDataType<float>::zeroval);
+
+double CudnnDataType<double>::oneval = 1.0;
+double CudnnDataType<double>::zeroval = 0.0;
+const void* CudnnDataType<double>::one =
+    static_cast<const void *>(&CudnnDataType<double>::oneval);
+const void* CudnnDataType<double>::zero =
+    static_cast<const void *>(&CudnnDataType<double>::zeroval);
+
+}
\ No newline at end of file
diff --git a/src/sita/dlflow/operator.cpp b/src/sita/dlflow/operator.cpp
index d049fd71..9bb9862b 100644
--- a/src/sita/dlflow/operator.cpp
+++ b/src/sita/dlflow/operator.cpp
@@ -24,8 +24,8 @@ void Operator<Dtype>::setup(){
         _param_configs.push_back(_opdef.param(i));
     }
     _is_shared = false;
+    _gradient_block = _opdef.gradient_block();
     _shared_param_pairs.clear();
-
 }

 template <typename Dtype>
diff --git a/src/sita/dlflow/operators/add_op.cpp b/src/sita/dlflow/operators/add_op.cpp
deleted file mode 100644
index 4a72c62b..00000000
--- a/src/sita/dlflow/operators/add_op.cpp
+++ /dev/null
@@ -1,30 +0,0 @@
-//---------------------------------
-//write by unsky
-//---------------------------------
-#include "sita/dlflow/operators/add_op.h"
-namespace sita{
-
-template <typename Dtype>
-void AddOp<Dtype>::init(){
-    // params
-    std::vector<int> shape;
-    shape.push_back(5);
-    shape.push_back(6);
-    shape.push_back(7);
-    shape.push_back(8);
-    this->init_param("add_weight", shape);
-    this->init_param("add_bias", shape);
-}
-
-template <typename Dtype>
-void AddOp<Dtype>::forward(){
-    Tensor<Dtype> * data = this->fetch_input(this->_inputs[0]);
-    Tensor<Dtype> * add_weight = this->fetch_param("add_weight");
-    //LOG(INFO)<<_add_op_param.kernel_h();
-};
-template <typename Dtype>
-void AddOp<Dtype>::backward(){
-}
-INSTANTIATE_CLASS(AddOp);
-REGISTER_OPERATOR_CLASS(AddOp);
-}//namespace
diff --git a/src/sita/dlflow/operators/convolution.cpp b/src/sita/dlflow/operators/convolution.cpp
new file mode 100644
index 00000000..9bb4110a
--- /dev/null
+++ b/src/sita/dlflow/operators/convolution.cpp
@@ -0,0 +1,334 @@
+//
+// Created by unsky on 16/08/18.
+//
+#include "sita/dlflow/operators/convolution.h"
+namespace sita{
+template <typename Dtype>
+Convolution<Dtype>::~Convolution(){
+    // Check that handles have been setup before destroying.
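+    // (init() is the only place that creates the streams, cuDNN handles and
+    // algorithm/workspace arrays, so there is nothing to release if it never ran.)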
+    if (!_handles_setup) { return; }
+
+    for (int i = 0; i < _input_descs.size(); i++) {
+        cudnnDestroyTensorDescriptor(_input_descs[i]);
+        cudnnDestroyTensorDescriptor(_output_descs[i]);
+        cudnnDestroyConvolutionDescriptor(_conv_descs[i]);
+    }
+    if (this->_op_param.bias_term()) {
+        cudnnDestroyTensorDescriptor(_bias_desc);
+    }
+    cudnnDestroyFilterDescriptor(_filter_desc);
+
+    for (int g = 0; g < this->_op_param.group() * CUDNN_STREAMS_PER_GROUP; g++) {
+        cudaStreamDestroy(_stream[g]);
+        cudnnDestroy(_handle[g]);
+    }
+
+    cudaFree(workspaceData);
+    delete [] workspace;
+    delete [] _stream;
+    delete [] _handle;
+    delete [] _fwd_algo;
+    delete [] _bwd_filter_algo;
+    delete [] _bwd_data_algo;
+    delete [] _workspace_fwd_sizes;
+    delete [] _workspace_bwd_data_sizes;
+    delete [] _workspace_bwd_filter_sizes;
+
+}
+
+template <typename Dtype>
+void Convolution<Dtype>::init(){
+
+
+    CHECK_EQ(this->_inputs.size(), this->_outputs.size()) << "input size should equal output size";
+    LOG(INFO) << "Inputs:";
+    //Initialize param input and outputs
+    for(int i = 0; i < this->_inputs.size(); i++){
+        Tensor<Dtype> *input = this->fetch_input(this->_inputs[i]);
+        this->_input_shapes[this->_inputs[i]] = input->shape();
+        CHECK_GT(input->count(), 0) << "check your graph, cannot infer " << this->_inputs[i] << " shape, in " << this->operator_name()<<"!!";
+        LOG(INFO) << this->_inputs[i]<<": "<< this->fetch_input(this->_inputs[i])->shape_string();
+    }
+
+    LOG(INFO) << "Params:";
+    int kernel_h;
+    int kernel_w;
+    if(this->_op_param.has_kernel_h() && this->_op_param.has_kernel_w()){
+        kernel_h = this->_op_param.kernel_h();
+        kernel_w = this->_op_param.kernel_w();
+    }else{
+        kernel_h = this->_op_param.kernel_size();
+        kernel_w = this->_op_param.kernel_size();
+    }
+    std::vector<int> weight_shape;
+    weight_shape.push_back(this->_op_param.num_output());
+    weight_shape.push_back(this->_input_shapes[this->_inputs[0]][1]/this->_op_param.group());
+    weight_shape.push_back(kernel_h);
+    weight_shape.push_back(kernel_w);
+    this->init_param("convolution_weight",weight_shape,this->_param_configs[0]);
+    Tensor<Dtype> * weight = this->fetch_param("convolution_weight");
+    LOG(INFO) << "convolution_weight:" << weight->shape_string();
+    this->_param_shapes["convolution_weight"] = weight_shape;
+    if(this->_op_param.bias_term()) {
+        std::vector<int> bias_shape;
+        bias_shape.push_back(this->_op_param.num_output());
+        this->init_param("convolution_bias", bias_shape,this->_param_configs[1]);
+
+        Tensor<Dtype> * bias = this->fetch_param("convolution_bias");
+        LOG(INFO) << "convolution_bias:" << bias->shape_string();
+        this->_param_shapes["convolution_bias"] = bias_shape;
+    }
+    LOG(INFO) << "Outputs:";
+    for(int i = 0; i < this->_outputs.size(); i++){
+        std::vector<int> output_shape;
+        output_shape.push_back(this->_input_shapes[this->_inputs[i]][0]);
+        output_shape.push_back(this->_op_param.num_output());
+
+        int in_height = this->_input_shapes[this->_inputs[i]][2];
+        int in_width = this->_input_shapes[this->_inputs[i]][3];
+
+        int pad_h, pad_w, stride_h, stride_w;
+
+        if(this->_op_param.has_pad_h() && this->_op_param.has_pad_w()){
+            pad_h = int(this->_op_param.pad_h());
+            pad_w = int(this->_op_param.pad_w());
+        }else{
+            pad_h = this->_op_param.pad();
+            pad_w = this->_op_param.pad();
+        }
+
+        if(this->_op_param.has_stride_w() && this->_op_param.has_stride_h()){
+            stride_h = int(this->_op_param.stride_h());
+            stride_w = int(this->_op_param.stride_w());
+        }else{
+            stride_h = this->_op_param.stride();
+            stride_w = this->_op_param.stride();
+        }
+        int out_height = (in_height + 2 * pad_h - kernel_h) / stride_h + 1;
+        int out_width = (in_width + 2 * pad_w - kernel_w) / stride_w + 1;
+        output_shape.push_back(out_height);
+        output_shape.push_back(out_width);
+
+        Tensor<Dtype> *output_data = this->fetch_output(this->_outputs[i]);
+        output_data->reshape(output_shape);
+        this->_output_shapes[this->_outputs[i]] = output_shape;
+        LOG(INFO) << this->_outputs[i]<<": "<< this->fetch_output(this->_outputs[i])->shape_string();
+    }
+
+    // Initialize CUDA streams and cuDNN.
+    _stream = new cudaStream_t[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
+    _handle = new cudnnHandle_t[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
+
+    // Initialize algorithm arrays
+    _fwd_algo = new cudnnConvolutionFwdAlgo_t[this->_inputs.size()];
+    _bwd_filter_algo = new cudnnConvolutionBwdFilterAlgo_t[this->_inputs.size()];
+    _bwd_data_algo = new cudnnConvolutionBwdDataAlgo_t[this->_inputs.size()];
+
+    // initialize size arrays
+    _workspace_fwd_sizes = new size_t[this->_inputs.size()];
+    _workspace_bwd_filter_sizes = new size_t[this->_inputs.size()];
+    _workspace_bwd_data_sizes = new size_t[this->_inputs.size()];
+
+    // workspace data
+    workspaceSizeInBytes = 0;
+    workspaceData = NULL;
+    workspace = new void*[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
+
+    for (size_t i = 0; i < this->_inputs.size(); ++i) {
+        // initialize all to default algorithms
+        _fwd_algo[i] = (cudnnConvolutionFwdAlgo_t)0;
+        _bwd_filter_algo[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+        _bwd_data_algo[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+        // default algorithms don't require workspace
+        _workspace_fwd_sizes[i] = 0;
+        _workspace_bwd_data_sizes[i] = 0;
+        _workspace_bwd_filter_sizes[i] = 0;
+    }
+    for (int g = 0; g < this->_op_param.group() * CUDNN_STREAMS_PER_GROUP; g++) {
+        CUDA_CHECK(cudaStreamCreate(&_stream[g]));
+        CUDNN_CHECK(cudnnCreate(&_handle[g]));
+        CUDNN_CHECK(cudnnSetStream(_handle[g], _stream[g]));
+        workspace[g] = NULL;
+    }
+
+    // Create filter descriptor.
+    Context::create_filter_descriptor<Dtype>(&_filter_desc,
+        this->_op_param.num_output()/this->_op_param.group(), weight_shape[1],
+        kernel_h, kernel_w);
+
+    // Create tensor descriptor(s) for data and corresponding convolution(s).
+    for (int i = 0; i < this->_inputs.size(); i++) {
+        cudnnTensorDescriptor_t input_desc;
+        Context::create_tensor4d_descriptor<Dtype>(&input_desc);
+        _input_descs.push_back(input_desc);
+
+        cudnnTensorDescriptor_t output_desc;
+        Context::create_tensor4d_descriptor<Dtype>(&output_desc);
+        _output_descs.push_back(output_desc);
+
+        cudnnConvolutionDescriptor_t conv_desc;
+        Context::create_convolution_descriptor<Dtype>(&conv_desc);
+        _conv_descs.push_back(conv_desc);
+    }
+
+    // Tensor descriptor for bias.
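+    // The bias descriptor is only created here; infer_shape() later fills it in
+    // as 1 x (num_output / group) x 1 x 1, which is the shape cudnnAddTensor
+    // broadcasts across the batch and spatial dimensions in forward().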
+    if (this->_op_param.bias_term()) {
+        Context::create_tensor4d_descriptor<Dtype>(&_bias_desc);
+    }
+    _handles_setup = true;
+
+    infer_shape();
+}
+
+template <typename Dtype>
+void Convolution<Dtype>::infer_shape() {
+    size_t workspace_limit_bytes = 8*1024*1024;
+
+    for (int i = 0; i < this->_inputs.size(); i++) {
+        Context::set_tensor4d_descriptor<Dtype>(&_input_descs[i],
+            this->_input_shapes[this->_inputs[i]][0],
+            this->_input_shapes[this->_inputs[i]][1] / this->_op_param.group(),
+            this->_input_shapes[this->_inputs[i]][2], this->_input_shapes[this->_inputs[i]][3],
+            this->_input_shapes[this->_inputs[i]][1] * this->_input_shapes[this->_inputs[i]][2] * this->_input_shapes[this->_inputs[i]][3],
+            this->_input_shapes[this->_inputs[i]][2] * this->_input_shapes[this->_inputs[i]][3],
+            this->_input_shapes[this->_inputs[i]][3], 1);
+        Context::set_tensor4d_descriptor<Dtype>(&_output_descs[i],
+            this->_input_shapes[this->_inputs[i]][0],
+            this->_output_shapes[this->_outputs[i]][1] / this->_op_param.group(),
+            this->_output_shapes[this->_outputs[i]][2], this->_output_shapes[this->_outputs[i]][3],
+            this->_output_shapes[this->_outputs[i]][1]* this->_output_shapes[this->_outputs[i]][2] * this->_output_shapes[this->_outputs[i]][3],
+            this->_output_shapes[this->_outputs[i]][2] * this->_output_shapes[this->_outputs[i]][3],
+            this->_output_shapes[this->_outputs[i]][3], 1);
+
+        int pad_h, pad_w, stride_h, stride_w;
+        if(this->_op_param.has_pad_h() && this->_op_param.has_pad_w()){
+            pad_h = int(this->_op_param.pad_h());
+            pad_w = int(this->_op_param.pad_w());
+        }else{
+            pad_h = this->_op_param.pad();
+            pad_w = this->_op_param.pad();
+        }
+
+        if(this->_op_param.has_stride_w() && this->_op_param.has_stride_h()){
+            stride_h = int(this->_op_param.stride_h());
+            stride_w = int(this->_op_param.stride_w());
+        }else{
+            stride_h = this->_op_param.stride();
+            stride_w = this->_op_param.stride();
+        }
+
+        Context::set_convolution_descriptor<Dtype>(&_conv_descs[i], _input_descs[i],
+            _filter_desc, pad_h, pad_w,
+            stride_h, stride_w);
+
+        // choose forward and backward algorithms + workspace(s)
+        CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(_handle[0],
+            _input_descs[i],
+            _filter_desc,
+            _conv_descs[i],
+            _output_descs[i],
+            CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+            workspace_limit_bytes,
+            &_fwd_algo[i]));
+
+        CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(_handle[0],
+            _input_descs[i],
+            _filter_desc,
+            _conv_descs[i],
+            _output_descs[i],
+            _fwd_algo[i],
+            &(_workspace_fwd_sizes[i])));
+
+        // choose backward algorithm for filter
+        CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(_handle[0],
+            _input_descs[i], _output_descs[i], _conv_descs[i], _filter_desc,
+            CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+            workspace_limit_bytes, &_bwd_filter_algo[i]) );
+
+        // get workspace for backwards filter algorithm
+        CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(_handle[0],
+            _input_descs[i], _output_descs[i], _conv_descs[i], _filter_desc,
+            _bwd_filter_algo[i], &_workspace_bwd_filter_sizes[i]));
+
+        // choose backward algo for data
+        CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(_handle[0],
+            _filter_desc, _output_descs[i], _conv_descs[i], _input_descs[i],
+            CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+            workspace_limit_bytes, &_bwd_data_algo[i]));
+
+        // get workspace size
+        CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(_handle[0],
+            _filter_desc, _output_descs[i], _conv_descs[i], _input_descs[i],
+            _bwd_data_algo[i], &_workspace_bwd_data_sizes[i]) );
+    }
+
+
+    // reduce over all workspace sizes to get a maximum to allocate / reallocate
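+    // Sizing strategy: take the largest workspace any chosen algorithm reported,
+    // scale it by group * CUDNN_STREAMS_PER_GROUP so every (group, stream) pair
+    // owns a slice, and alias the per-stream pointers into one big allocation.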
+    size_t total_workspace_fwd = 0;
+    size_t total_workspace_bwd_data = 0;
+    size_t total_workspace_bwd_filter = 0;
+
+    for (size_t i = 0; i < this->_inputs.size(); i++) {
+        total_workspace_fwd = std::max(total_workspace_fwd,
+            _workspace_fwd_sizes[i]);
+        total_workspace_bwd_data = std::max(total_workspace_bwd_data,
+            _workspace_bwd_data_sizes[i]);
+        total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+            _workspace_bwd_filter_sizes[i]);
+    }
+    // get max over all operations
+    size_t max_workspace = std::max(total_workspace_fwd,
+        total_workspace_bwd_data);
+    max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+    // ensure all groups have enough workspace
+    size_t total_max_workspace = max_workspace *
+        (this->_op_param.group() * CUDNN_STREAMS_PER_GROUP);
+
+    // this is the total amount of storage needed over all groups + streams
+    if (total_max_workspace > workspaceSizeInBytes) {
+        DLOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+        workspaceSizeInBytes = total_max_workspace;
+
+        // free the existing workspace and allocate a new (larger) one
+        cudaFree(this->workspaceData);
+
+        cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+        if (err != cudaSuccess) {
+            // force zero memory path
+            for (int i = 0; i < this->_inputs.size(); i++) {
+                _workspace_fwd_sizes[i] = 0;
+                _workspace_bwd_filter_sizes[i] = 0;
+                _workspace_bwd_data_sizes[i] = 0;
+                _fwd_algo[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+                _bwd_filter_algo[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+                _bwd_data_algo[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+            }
+
+            // NULL out all workspace pointers
+            for (int g = 0; g < (this->_op_param.group() * CUDNN_STREAMS_PER_GROUP); g++) {
+                workspace[g] = NULL;
+            }
+            // NULL out underlying data
+            workspaceData = NULL;
+            workspaceSizeInBytes = 0;
+        }
+
+        // if we succeed in the allocation, set pointer aliases for workspaces
+        for (int g = 0; g < (this->_op_param.group() * CUDNN_STREAMS_PER_GROUP); g++) {
+            workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+        }
+    }
+
+    // Tensor descriptor for bias.
+    if (this->_op_param.bias_term()) {
+        Context::set_tensor4d_descriptor<Dtype>(&_bias_desc,
+            1, this->_output_shapes[this->_outputs[0]][1] / this->_op_param.group(), 1, 1);
+    }
+
+}
+
+
+
+INSTANTIATE_CLASS(Convolution);
+REGISTER_OPERATOR_CLASS(Convolution);
+}//namespace
diff --git a/src/sita/dlflow/operators/convolution.cu b/src/sita/dlflow/operators/convolution.cu
new file mode 100644
index 00000000..addd356a
--- /dev/null
+++ b/src/sita/dlflow/operators/convolution.cu
@@ -0,0 +1,123 @@
+#include "sita/dlflow/operators/convolution.h"
+namespace sita{
+
+__global__ void sync_conv_groups() { }
+
+template <typename Dtype>
+void Convolution<Dtype>::forward(){
+    Tensor<Dtype>* weight_tensor = this->fetch_param("convolution_weight");
+    const Dtype * weight = weight_tensor->gpu_data();
+    for (int i = 0; i < this->_inputs.size(); ++i) {
+        Tensor<Dtype>* input_tensor = this->fetch_input(this->_inputs[i]);
+        const Dtype * input_data = input_tensor->gpu_data();
+        Tensor<Dtype>* output_tensor = this->fetch_output(this->_outputs[i]);
+        Dtype * output_data = output_tensor->mutable_gpu_data();
+        int input_offset = input_tensor->count() / this->_op_param.group();
+        int output_offset = output_tensor->count() / this->_op_param.group();
+        int weight_offset = weight_tensor->count() / this->_op_param.group();
+
+
+        // Forward through cuDNN in parallel over groups.
+        for (int g = 0; g < this->_op_param.group(); g++) {
+            //Filters.
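+            // cudnnConvolutionForward computes y = alpha * conv(x, w) + beta * y.
+            // CudnnDataType<Dtype>::one / ::zero act as the alpha / beta pointers
+            // here, so the convolution result overwrites this group's output slice;
+            // the bias step below passes beta = one to accumulate on top of it.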
+            CUDNN_CHECK(cudnnConvolutionForward(_handle[g],
+                CudnnDataType<Dtype>::one,
+                _input_descs[i], input_data + input_offset * g,
+                _filter_desc, weight + weight_offset * g,
+                _conv_descs[i],
+                _fwd_algo[i], workspace[g], _workspace_fwd_sizes[i],
+                CudnnDataType<Dtype>::zero,
+                _output_descs[i], output_data + output_offset * g));
+
+            // Bias.
+            if (this->_op_param.bias_term()) {
+
+                Tensor<Dtype> * bias_tensor = this->fetch_param("convolution_bias");
+                const Dtype * bias_data = bias_tensor->gpu_data();
+                int bias_offset = bias_tensor->count() / this->_op_param.group();
+                CUDNN_CHECK(cudnnAddTensor(_handle[g],
+                    CudnnDataType<Dtype>::one,
+                    _bias_desc, bias_data + bias_offset * g,
+                    CudnnDataType<Dtype>::one,
+                    _output_descs[i], output_data + output_offset * g));
+            }
+        }
+
+        // Synchronize the work across groups, each of which went into its own
+        // stream, by launching an empty kernel into the default (null) stream.
+        // NOLINT_NEXT_LINE(whitespace/operators)
+        sync_conv_groups<<<1, 1>>>();
+    }
+};
+
+template <typename Dtype>
+void Convolution<Dtype>::backward(){
+    Tensor<Dtype>* weight_tensor = this->fetch_param("convolution_weight");
+
+    for (int i = 0; i < this->_inputs.size(); ++i) {
+        Tensor<Dtype>* input_tensor = this->fetch_input(this->_inputs[i]);
+        const Dtype * input_data = input_tensor->gpu_data();
+        Dtype * input_diff = input_tensor->mutable_gpu_diff();
+        if(this->_gradient_block){
+            Context::gpu_memset(input_diff, input_tensor->size());
+
+        }else{
+            Tensor<Dtype>* output_tensor = this->fetch_output(this->_outputs[i]);
+            const Dtype * output_data = output_tensor->gpu_data();
+            Dtype * output_diff = output_tensor->mutable_gpu_diff();
+            const Dtype * weight = weight_tensor->gpu_data();
+            Dtype * weight_diff = weight_tensor->mutable_gpu_diff();
+            Dtype * bias_diff = NULL;
+            int input_offset = input_tensor->count() / this->_op_param.group();
+            int output_offset = output_tensor->count() / this->_op_param.group();
+            int weight_offset = weight_tensor->count() / this->_op_param.group();
+
+            for (int g = 0; g < this->_op_param.group(); g++) {
+                // Gradient w.r.t. bias.
+                if (this->_op_param.bias_term()) {
+                    Tensor<Dtype>* bias_tensor = this->fetch_param("convolution_bias");
+                    bias_diff = bias_tensor->mutable_gpu_diff();
+                    int bias_offset = bias_tensor->count() / this->_op_param.group();
+                    CUDNN_CHECK(cudnnConvolutionBackwardBias(_handle[0 * this->_op_param.group() + g],
+                        CudnnDataType<Dtype>::one,
+                        _output_descs[i],
+                        output_diff + output_offset * g,
+                        CudnnDataType<Dtype>::one,
+                        _bias_desc, bias_diff + bias_offset * g));
+                }
+
+                // Gradient w.r.t. weights.
+                CUDNN_CHECK(cudnnConvolutionBackwardFilter(
+                    _handle[1*this->_op_param.group() + g],
+                    CudnnDataType<Dtype>::one,
+                    _input_descs[i], input_data + input_offset * g,
+                    _output_descs[i], output_diff + output_offset * g,
+                    _conv_descs[i],
+                    _bwd_filter_algo[i], workspace[1*this->_op_param.group() + g],
+                    _workspace_bwd_filter_sizes[i],
+                    CudnnDataType<Dtype>::one,
+                    _filter_desc, weight_diff + weight_offset * g));
+
+
+                // Gradient w.r.t. bottom data.
+                CUDNN_CHECK(cudnnConvolutionBackwardData(
+                    _handle[2*this->_op_param.group() + g],
+                    CudnnDataType<Dtype>::one,
+                    _filter_desc, weight + weight_offset * g,
+                    _output_descs[i], output_diff + output_offset * g,
+                    _conv_descs[i],
+                    _bwd_data_algo[i], workspace[2 * this->_op_param.group() + g],
+                    _workspace_bwd_data_sizes[i],
+                    CudnnDataType<Dtype>::zero,
+                    _input_descs[i], input_diff + input_offset * g));
+            }
+        }
+        // Synchronize the work across groups, each of which went into its own
+        // stream, by launching an empty kernel into the default (null) stream.
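+        // (This works because the legacy default stream is synchronizing: the
+        // empty kernel cannot start until the per-group streams above drain,
+        // and later launches cannot start until the empty kernel completes.)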
+        // NOLINT_NEXT_LINE(whitespace/operators)
+        sync_conv_groups<<<1, 1>>>();
+    }
+};
+
+INSTANTIATE_OPERATOR_GPU_FUNCS(Convolution);
+}//namespace
diff --git a/src/sita/dlflow/operators/convolution_op.cpp b/src/sita/dlflow/operators/convolution_op.cpp
deleted file mode 100644
index 22dfd5b3..00000000
--- a/src/sita/dlflow/operators/convolution_op.cpp
+++ /dev/null
@@ -1,157 +0,0 @@
-//
-// Created by unsky on 16/08/18.
-//
-#include "sita/dlflow/operators/convolution_op.h"
-namespace sita{
-
-template <typename Dtype>
-void ConvolutionOp<Dtype>::init(){
-
-    //Initialize param input and outputs
-    for(int i = 0; i < this->_inputs.size(); i++){
-        Tensor<Dtype> *input = this->fetch_input(this->_inputs[i]);
-        this->_input_shapes[this->_inputs[i]] = input->shape();
-        CHECK_GT(input->count(), 0) << "check your graph, cannot infer " << this->_inputs[i] << " shape,in " << this->operator_name()<<"!!";
-        LOG(INFO) << this->_inputs[i]<<": "<< this->fetch_input(this->_inputs[i])->shape_string();
-    }
-
-    int kernel_h;
-    int kernel_w;
-    if(this->_op_param.has_kernel_size()){
-        kernel_h = this->_op_param.kernel_size();
-        kernel_w = this->_op_param.kernel_size();
-    }else{
-        kernel_h = this->_op_param.kernel_h();
-        kernel_w = this->_op_param.kernel_w();
-    }
-    std::vector<int> weight_shape;
-    weight_shape.push_back(this->_op_param.num_output());
-    weight_shape.push_back(this->_input_shapes[this->_inputs[0]][1]/this->_op_param.group());
-    weight_shape.push_back(kernel_h);
-    weight_shape.push_back(kernel_w);
-    this->init_param("convolution_weight",weight_shape,this->_param_configs[0]);
-
-    if(this->_op_param.bias_term()) {
-        std::vector<int> bias_shape;
-        bias_shape.push_back(this->_op_param.num_output());
-        this->init_param("convolution_bias", bias_shape,this->_param_configs[1]);
-    }
-
-    Tensor<Dtype> * weight = this->fetch_param("convolution_weight");
-    LOG(INFO) << "convolution_weight:" << weight->shape_string();
-    Tensor<Dtype> * bias = this->fetch_param("convolution_bias");
-    LOG(INFO) << "convolution_bias:" << bias->shape_string();
-
-    std::vector<int> output_shape;
-    output_shape.push_back(this->_input_shapes[this->_inputs[0]][0]);
-    output_shape.push_back(this->_op_param.num_output());
-
-    int in_height = this->_input_shapes[this->_inputs[0]][2];
-    int in_width = this->_input_shapes[this->_inputs[0]][3];
-
-    int pad_h, pad_w, dilation, stride_h, stride_w;
-
-    if(this->_op_param.has_pad()){
-        pad_h = int(this->_op_param.pad());
-        kernel_w = int(this->_op_param.pad());
-    }else{
-        pad_h = this->_op_param.pad_h();
-        pad_w = this->_op_param.pad_w();
-    }
-
-    if(this->_op_param.has_stride()){
-        stride_h = int(this->_op_param.stride());
-        stride_w = int(this->_op_param.stride());
-    }else{
-        stride_h = this->_op_param.stride_h();
-        stride_w = this->_op_param.stride_w();
-    }
-    dilation = this->_op_param.dilation();
-    int out_height = (in_height + 2 * pad_h - dilation)/stride_h + 1;
-    int out_width = (in_width + 2* pad_w - dilation)/stride_w +1;
-    output_shape.push_back(out_height);
-    output_shape.push_back(out_width);
-
-    Tensor<Dtype> *output_data = this->fetch_output(this->_outputs[0]);
-    output_data->reshape(output_shape);
-    LOG(INFO) << this->_outputs[0]<<": "<< this->fetch_output(this->_outputs[0])->shape_string();
-
-
-    // Initialize CUDA streams and cuDNN.
-    _stream = new cudaStream_t[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
-    _handle = new cudnnHandle_t[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
-
-    // Initialize algorithm arrays
-    _fwd_algo = new cudnnConvolutionFwdAlgo_t[this->_inputs.size()];
-    _bwd_filter_algo = new cudnnConvolutionBwdFilterAlgo_t[this->_inputs.size()];
-    _bwd_data_algo = new cudnnConvolutionBwdDataAlgo_t[this->_inputs.size()];
-
-    // initialize size arrays
-    _workspace_fwd_sizes = new size_t[this->_inputs.size()];
-    _workspace_bwd_filter_sizes = new size_t[this->_inputs.size()];
-    _workspace_bwd_data_sizes = new size_t[this->_inputs.size()];
-
-    // workspace data
-    workspaceSizeInBytes = 0;
-    workspaceData = NULL;
-    workspace = new void*[this->_op_param.group() * CUDNN_STREAMS_PER_GROUP];
-
-    for (size_t i = 0; i < this->_inputs.size(); ++i) {
-        // initialize all to default algorithms
-        _fwd_algo[i] = (cudnnConvolutionFwdAlgo_t)0;
-        _bwd_filter_algo[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
-        _bwd_data_algo[i] = (cudnnConvolutionBwdDataAlgo_t)0;
-        // default algorithms don't require workspace
-        _workspace_fwd_sizes[i] = 0;
-        _workspace_bwd_data_sizes[i] = 0;
-        _workspace_bwd_filter_sizes[i] = 0;
-    }
-    for (int g = 0; g < this->_op_param.group() * CUDNN_STREAMS_PER_GROUP; g++) {
-        CUDA_CHECK(cudaStreamCreate(&_stream[g]));
-        CUDNN_CHECK(cudnnCreate(&_handle[g]));
-        CUDNN_CHECK(cudnnSetStream(_handle[g], _stream[g]));
-        workspace[g] = NULL;
-    }
-
-    // Create filter descriptor.
-    Context::create_filter_descriptor<Dtype>(&_filter_desc,
-        this->_op_param.num_output()/this->_op_param.group(), weight_shape[1],
-        kernel_h, kernel_w);
-
-    // Create tensor descriptor(s) for data and corresponding convolution(s).
-    for (int i = 0; i < this->_inputs.size(); i++) {
-        cudnnTensorDescriptor_t input_desc;
-        Context::create_tensor4d_descriptor<Dtype>(&input_desc);
-        _input_descs.push_back(input_desc);
-
-        cudnnTensorDescriptor_t output_desc;
-        Context::create_tensor4d_descriptor<Dtype>(&output_desc);
-        _output_descs.push_back(output_desc);
-
-        cudnnConvolutionDescriptor_t conv_desc;
-        Context::create_convolution_descriptor<Dtype>(&conv_desc);
-        _conv_descs.push_back(conv_desc);
-    }
-
-    // Tensor descriptor for bias.
-    if (this->_op_param.bias_term()) {
-        Context::create_tensor4d_descriptor<Dtype>(&_bias_desc);
-    }
-    _handles_setup = true;
-}
-
-template <typename Dtype>
-void ConvolutionOp<Dtype>::infer_shape() {
-
-}
-template <typename Dtype>
-void ConvolutionOp<Dtype>::forward(){
-//    Tensor<Dtype> * data = this->fetch_input(this->_inputs[0]);
-//    Tensor<Dtype> * add_weight = this->fetch_param("add_weight");
-    //LOG(INFO)<<_add_op_param.kernel_h();
-
-};
-
-INSTANTIATE_CLASS(ConvolutionOp);
-REGISTER_OPERATOR_CLASS(ConvolutionOp);
-}//namespace
diff --git a/src/sita/dlflow/operators/data_test_op.cpp b/src/sita/dlflow/operators/data_test_op.cpp
deleted file mode 100644
index 3633d36d..00000000
--- a/src/sita/dlflow/operators/data_test_op.cpp
+++ /dev/null
@@ -1,4 +0,0 @@
-//
-// Created by cs on 02/08/18.
-// - diff --git a/src/sita/proto/sita_operators.pb.cc b/src/sita/proto/sita_operators.pb.cc index df4b90f3..3e72ffc8 100644 --- a/src/sita/proto/sita_operators.pb.cc +++ b/src/sita/proto/sita_operators.pb.cc @@ -26,12 +26,9 @@ const ::google::protobuf::internal::GeneratedMessageReflection* const ::google::protobuf::Descriptor* OperatorParameter_descriptor_ = NULL; const ::google::protobuf::internal::GeneratedMessageReflection* OperatorParameter_reflection_ = NULL; -const ::google::protobuf::Descriptor* AddOpParameter_descriptor_ = NULL; +const ::google::protobuf::Descriptor* ConvolutionParameter_descriptor_ = NULL; const ::google::protobuf::internal::GeneratedMessageReflection* - AddOpParameter_reflection_ = NULL; -const ::google::protobuf::Descriptor* ConvolutionOpParameter_descriptor_ = NULL; -const ::google::protobuf::internal::GeneratedMessageReflection* - ConvolutionOpParameter_reflection_ = NULL; + ConvolutionParameter_reflection_ = NULL; } // namespace @@ -59,15 +56,14 @@ void protobuf_AssignDesc_sita_5foperators_2eproto() { ::google::protobuf::MessageFactory::generated_factory(), sizeof(GraphParameter)); OperatorParameter_descriptor_ = file->message_type(1); - static const int OperatorParameter_offsets_[8] = { + static const int OperatorParameter_offsets_[7] = { GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, name_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, type_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, input_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, output_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, loss_weight_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, gradient_block_), GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, param_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, add_op_param_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, convolution_op_param_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(OperatorParameter, convolution_param_), }; OperatorParameter_reflection_ = new ::google::protobuf::internal::GeneratedMessageReflection( @@ -80,53 +76,32 @@ void protobuf_AssignDesc_sita_5foperators_2eproto() { ::google::protobuf::DescriptorPool::generated_pool(), ::google::protobuf::MessageFactory::generated_factory(), sizeof(OperatorParameter)); - AddOpParameter_descriptor_ = file->message_type(2); - static const int AddOpParameter_offsets_[6] = { - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, pad_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, pad_w_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, kernel_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, kernel_w_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, stride_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, stride_w_), + ConvolutionParameter_descriptor_ = file->message_type(2); + static const int ConvolutionParameter_offsets_[12] = { + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, num_output_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, bias_term_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, pad_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, kernel_size_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, stride_), + 
GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, pad_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, pad_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, kernel_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, kernel_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, stride_h_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, stride_w_), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, group_), }; - AddOpParameter_reflection_ = + ConvolutionParameter_reflection_ = new ::google::protobuf::internal::GeneratedMessageReflection( - AddOpParameter_descriptor_, - AddOpParameter::default_instance_, - AddOpParameter_offsets_, - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, _has_bits_[0]), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(AddOpParameter, _unknown_fields_), + ConvolutionParameter_descriptor_, + ConvolutionParameter::default_instance_, + ConvolutionParameter_offsets_, + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, _has_bits_[0]), + GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionParameter, _unknown_fields_), -1, ::google::protobuf::DescriptorPool::generated_pool(), ::google::protobuf::MessageFactory::generated_factory(), - sizeof(AddOpParameter)); - ConvolutionOpParameter_descriptor_ = file->message_type(3); - static const int ConvolutionOpParameter_offsets_[13] = { - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, num_output_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, bias_term_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, pad_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, kernel_size_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, stride_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, dilation_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, pad_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, pad_w_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, kernel_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, kernel_w_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, stride_h_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, stride_w_), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, group_), - }; - ConvolutionOpParameter_reflection_ = - new ::google::protobuf::internal::GeneratedMessageReflection( - ConvolutionOpParameter_descriptor_, - ConvolutionOpParameter::default_instance_, - ConvolutionOpParameter_offsets_, - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, _has_bits_[0]), - GOOGLE_PROTOBUF_GENERATED_MESSAGE_FIELD_OFFSET(ConvolutionOpParameter, _unknown_fields_), - -1, - ::google::protobuf::DescriptorPool::generated_pool(), - ::google::protobuf::MessageFactory::generated_factory(), - sizeof(ConvolutionOpParameter)); + sizeof(ConvolutionParameter)); } namespace { @@ -144,9 +119,7 @@ void protobuf_RegisterTypes(const ::std::string&) { ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( OperatorParameter_descriptor_, &OperatorParameter::default_instance()); ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( - AddOpParameter_descriptor_, 
&AddOpParameter::default_instance()); - ::google::protobuf::MessageFactory::InternalRegisterGeneratedMessage( - ConvolutionOpParameter_descriptor_, &ConvolutionOpParameter::default_instance()); + ConvolutionParameter_descriptor_, &ConvolutionParameter::default_instance()); } } // namespace @@ -156,10 +129,8 @@ void protobuf_ShutdownFile_sita_5foperators_2eproto() { delete GraphParameter_reflection_; delete OperatorParameter::default_instance_; delete OperatorParameter_reflection_; - delete AddOpParameter::default_instance_; - delete AddOpParameter_reflection_; - delete ConvolutionOpParameter::default_instance_; - delete ConvolutionOpParameter_reflection_; + delete ConvolutionParameter::default_instance_; + delete ConvolutionParameter_reflection_; } void protobuf_AddDesc_sita_5foperators_2eproto() { @@ -173,33 +144,26 @@ void protobuf_AddDesc_sita_5foperators_2eproto() { "\n\024sita_operators.proto\022\004sita\032\020sita_utils" ".proto\"L\n\016GraphParameter\022\014\n\004name\030\001 \001(\t\022," "\n\013operatordef\030\002 \003(\0132\027.sita.OperatorParam" - "eter\"\355\001\n\021OperatorParameter\022\014\n\004name\030\001 \001(\t" + "eter\"\306\001\n\021OperatorParameter\022\014\n\004name\030\001 \001(\t" "\022\014\n\004type\030\002 \001(\t\022\r\n\005input\030\003 \003(\t\022\016\n\006output\030" - "\004 \003(\t\022\023\n\013loss_weight\030\005 \003(\002\022 \n\005param\030\006 \003(" - "\0132\021.sita.ParamConfig\022*\n\014add_op_param\030d \001" - "(\0132\024.sita.AddOpParameter\022:\n\024convolution_" - "op_param\030e \001(\0132\034.sita.ConvolutionOpParam" - "eter\"|\n\016AddOpParameter\022\020\n\005pad_h\030\001 \001(\r:\0010" - "\022\020\n\005pad_w\030\002 \001(\r:\0010\022\020\n\010kernel_h\030\003 \001(\r\022\020\n\010" - "kernel_w\030\004 \001(\r\022\020\n\010stride_h\030\005 \001(\r\022\020\n\010stri" - "de_w\030\006 \001(\r\"\220\002\n\026ConvolutionOpParameter\022\022\n" - "\nnum_output\030\001 \001(\r\022\027\n\tbias_term\030\002 \001(\010:\004tr" - "ue\022\013\n\003pad\030\003 \001(\r\022\023\n\013kernel_size\030\004 \001(\r\022\016\n\006" - "stride\030\005 \001(\r\022\023\n\010dilation\030\006 \001(\r:\0011\022\020\n\005pad" - "_h\030\007 \001(\r:\0010\022\020\n\005pad_w\030\010 \001(\r:\0010\022\023\n\010kernel_" - "h\030\t \001(\r:\0013\022\023\n\010kernel_w\030\n \001(\r:\0013\022\020\n\010strid" - "e_h\030\013 \001(\r\022\020\n\010stride_w\030\014 \001(\r\022\020\n\005group\030\r \001" - "(\r:\0011", 765); + "\004 \003(\t\022\035\n\016gradient_block\030\005 \001(\010:\005false\022 \n\005" + "param\030\006 \003(\0132\021.sita.ParamConfig\0225\n\021convol" + "ution_param\030e \001(\0132\032.sita.ConvolutionPara" + "meter\"\366\001\n\024ConvolutionParameter\022\022\n\nnum_ou" + "tput\030\001 \001(\r\022\027\n\tbias_term\030\002 \001(\010:\004true\022\016\n\003p" + "ad\030\003 \001(\r:\0011\022\026\n\013kernel_size\030\004 \001(\r:\0013\022\021\n\006s" + "tride\030\005 \001(\r:\0011\022\r\n\005pad_h\030\006 \001(\r\022\r\n\005pad_w\030\007" + " \001(\r\022\020\n\010kernel_h\030\010 \001(\r\022\020\n\010kernel_w\030\t \001(\r" + "\022\020\n\010stride_h\030\n \001(\r\022\020\n\010stride_w\030\013 \001(\r\022\020\n\005" + "group\030\014 \001(\r:\0011", 574); ::google::protobuf::MessageFactory::InternalRegisterGeneratedFile( "sita_operators.proto", &protobuf_RegisterTypes); GraphParameter::default_instance_ = new GraphParameter(); OperatorParameter::default_instance_ = new OperatorParameter(); - AddOpParameter::default_instance_ = new AddOpParameter(); - 
ConvolutionOpParameter::default_instance_ = new ConvolutionOpParameter(); + ConvolutionParameter::default_instance_ = new ConvolutionParameter(); GraphParameter::default_instance_->InitAsDefaultInstance(); OperatorParameter::default_instance_->InitAsDefaultInstance(); - AddOpParameter::default_instance_->InitAsDefaultInstance(); - ConvolutionOpParameter::default_instance_->InitAsDefaultInstance(); + ConvolutionParameter::default_instance_->InitAsDefaultInstance(); ::google::protobuf::internal::OnShutdown(&protobuf_ShutdownFile_sita_5foperators_2eproto); } @@ -499,10 +463,9 @@ const int OperatorParameter::kNameFieldNumber; const int OperatorParameter::kTypeFieldNumber; const int OperatorParameter::kInputFieldNumber; const int OperatorParameter::kOutputFieldNumber; -const int OperatorParameter::kLossWeightFieldNumber; +const int OperatorParameter::kGradientBlockFieldNumber; const int OperatorParameter::kParamFieldNumber; -const int OperatorParameter::kAddOpParamFieldNumber; -const int OperatorParameter::kConvolutionOpParamFieldNumber; +const int OperatorParameter::kConvolutionParamFieldNumber; #endif // !_MSC_VER OperatorParameter::OperatorParameter() @@ -512,8 +475,7 @@ OperatorParameter::OperatorParameter() } void OperatorParameter::InitAsDefaultInstance() { - add_op_param_ = const_cast< ::sita::AddOpParameter*>(&::sita::AddOpParameter::default_instance()); - convolution_op_param_ = const_cast< ::sita::ConvolutionOpParameter*>(&::sita::ConvolutionOpParameter::default_instance()); + convolution_param_ = const_cast< ::sita::ConvolutionParameter*>(&::sita::ConvolutionParameter::default_instance()); } OperatorParameter::OperatorParameter(const OperatorParameter& from) @@ -528,8 +490,8 @@ void OperatorParameter::SharedCtor() { _cached_size_ = 0; name_ = const_cast< ::std::string*>(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); type_ = const_cast< ::std::string*>(&::google::protobuf::internal::GetEmptyStringAlreadyInited()); - add_op_param_ = NULL; - convolution_op_param_ = NULL; + gradient_block_ = false; + convolution_param_ = NULL; ::memset(_has_bits_, 0, sizeof(_has_bits_)); } @@ -546,8 +508,7 @@ void OperatorParameter::SharedDtor() { delete type_; } if (this != default_instance_) { - delete add_op_param_; - delete convolution_op_param_; + delete convolution_param_; } } @@ -573,7 +534,7 @@ OperatorParameter* OperatorParameter::New() const { } void OperatorParameter::Clear() { - if (_has_bits_[0 / 32] & 195) { + if (_has_bits_[0 / 32] & 83) { if (has_name()) { if (name_ != &::google::protobuf::internal::GetEmptyStringAlreadyInited()) { name_->clear(); @@ -584,16 +545,13 @@ void OperatorParameter::Clear() { type_->clear(); } } - if (has_add_op_param()) { - if (add_op_param_ != NULL) add_op_param_->::sita::AddOpParameter::Clear(); - } - if (has_convolution_op_param()) { - if (convolution_op_param_ != NULL) convolution_op_param_->::sita::ConvolutionOpParameter::Clear(); + gradient_block_ = false; + if (has_convolution_param()) { + if (convolution_param_ != NULL) convolution_param_->::sita::ConvolutionParameter::Clear(); } } input_.Clear(); output_.Clear(); - loss_weight_.Clear(); param_.Clear(); ::memset(_has_bits_, 0, sizeof(_has_bits_)); mutable_unknown_fields()->Clear(); @@ -676,25 +634,21 @@ bool OperatorParameter::MergePartialFromCodedStream( goto handle_unusual; } if (input->ExpectTag(34)) goto parse_output; - if (input->ExpectTag(45)) goto parse_loss_weight; + if (input->ExpectTag(40)) goto parse_gradient_block; break; } - // repeated float loss_weight = 5; + // optional 
bool gradient_block = 5 [default = false]; case 5: { - if (tag == 45) { - parse_loss_weight: - DO_((::google::protobuf::internal::WireFormatLite::ReadRepeatedPrimitive< - float, ::google::protobuf::internal::WireFormatLite::TYPE_FLOAT>( - 1, 45, input, this->mutable_loss_weight()))); - } else if (tag == 42) { - DO_((::google::protobuf::internal::WireFormatLite::ReadPackedPrimitiveNoInline< - float, ::google::protobuf::internal::WireFormatLite::TYPE_FLOAT>( - input, this->mutable_loss_weight()))); + if (tag == 40) { + parse_gradient_block: + DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< + bool, ::google::protobuf::internal::WireFormatLite::TYPE_BOOL>( + input, &gradient_block_))); + set_has_gradient_block(); } else { goto handle_unusual; } - if (input->ExpectTag(45)) goto parse_loss_weight; if (input->ExpectTag(50)) goto parse_param; break; } @@ -709,29 +663,16 @@ bool OperatorParameter::MergePartialFromCodedStream( goto handle_unusual; } if (input->ExpectTag(50)) goto parse_param; - if (input->ExpectTag(802)) goto parse_add_op_param; + if (input->ExpectTag(810)) goto parse_convolution_param; break; } - // optional .sita.AddOpParameter add_op_param = 100; - case 100: { - if (tag == 802) { - parse_add_op_param: - DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( - input, mutable_add_op_param())); - } else { - goto handle_unusual; - } - if (input->ExpectTag(810)) goto parse_convolution_op_param; - break; - } - - // optional .sita.ConvolutionOpParameter convolution_op_param = 101; + // optional .sita.ConvolutionParameter convolution_param = 101; case 101: { if (tag == 810) { - parse_convolution_op_param: + parse_convolution_param: DO_(::google::protobuf::internal::WireFormatLite::ReadMessageNoVirtual( - input, mutable_convolution_op_param())); + input, mutable_convolution_param())); } else { goto handle_unusual; } @@ -804,10 +745,9 @@ void OperatorParameter::SerializeWithCachedSizes( 4, this->output(i), output); } - // repeated float loss_weight = 5; - for (int i = 0; i < this->loss_weight_size(); i++) { - ::google::protobuf::internal::WireFormatLite::WriteFloat( - 5, this->loss_weight(i), output); + // optional bool gradient_block = 5 [default = false]; + if (has_gradient_block()) { + ::google::protobuf::internal::WireFormatLite::WriteBool(5, this->gradient_block(), output); } // repeated .sita.ParamConfig param = 6; @@ -816,16 +756,10 @@ void OperatorParameter::SerializeWithCachedSizes( 6, this->param(i), output); } - // optional .sita.AddOpParameter add_op_param = 100; - if (has_add_op_param()) { - ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( - 100, this->add_op_param(), output); - } - - // optional .sita.ConvolutionOpParameter convolution_op_param = 101; - if (has_convolution_op_param()) { + // optional .sita.ConvolutionParameter convolution_param = 101; + if (has_convolution_param()) { ::google::protobuf::internal::WireFormatLite::WriteMessageMaybeToArray( - 101, this->convolution_op_param(), output); + 101, this->convolution_param(), output); } if (!unknown_fields().empty()) { @@ -880,10 +814,9 @@ ::google::protobuf::uint8* OperatorParameter::SerializeWithCachedSizesToArray( WriteStringToArray(4, this->output(i), target); } - // repeated float loss_weight = 5; - for (int i = 0; i < this->loss_weight_size(); i++) { - target = ::google::protobuf::internal::WireFormatLite:: - WriteFloatToArray(5, this->loss_weight(i), target); + // optional bool gradient_block = 5 [default = false]; + if (has_gradient_block()) { + target = 
::google::protobuf::internal::WireFormatLite::WriteBoolToArray(5, this->gradient_block(), target); } // repeated .sita.ParamConfig param = 6; @@ -893,18 +826,11 @@ ::google::protobuf::uint8* OperatorParameter::SerializeWithCachedSizesToArray( 6, this->param(i), target); } - // optional .sita.AddOpParameter add_op_param = 100; - if (has_add_op_param()) { - target = ::google::protobuf::internal::WireFormatLite:: - WriteMessageNoVirtualToArray( - 100, this->add_op_param(), target); - } - - // optional .sita.ConvolutionOpParameter convolution_op_param = 101; - if (has_convolution_op_param()) { + // optional .sita.ConvolutionParameter convolution_param = 101; + if (has_convolution_param()) { target = ::google::protobuf::internal::WireFormatLite:: WriteMessageNoVirtualToArray( - 101, this->convolution_op_param(), target); + 101, this->convolution_param(), target); } if (!unknown_fields().empty()) { @@ -933,18 +859,16 @@ int OperatorParameter::ByteSize() const { this->type()); } - // optional .sita.AddOpParameter add_op_param = 100; - if (has_add_op_param()) { - total_size += 2 + - ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( - this->add_op_param()); + // optional bool gradient_block = 5 [default = false]; + if (has_gradient_block()) { + total_size += 1 + 1; } - // optional .sita.ConvolutionOpParameter convolution_op_param = 101; - if (has_convolution_op_param()) { + // optional .sita.ConvolutionParameter convolution_param = 101; + if (has_convolution_param()) { total_size += 2 + ::google::protobuf::internal::WireFormatLite::MessageSizeNoVirtual( - this->convolution_op_param()); + this->convolution_param()); } } @@ -962,13 +886,6 @@ int OperatorParameter::ByteSize() const { this->output(i)); } - // repeated float loss_weight = 5; - { - int data_size = 0; - data_size = 4 * this->loss_weight_size(); - total_size += 1 * this->loss_weight_size() + data_size; - } - // repeated .sita.ParamConfig param = 6; total_size += 1 * this->param_size(); for (int i = 0; i < this->param_size(); i++) { @@ -1004,7 +921,6 @@ void OperatorParameter::MergeFrom(const OperatorParameter& from) { GOOGLE_CHECK_NE(&from, this); input_.MergeFrom(from.input_); output_.MergeFrom(from.output_); - loss_weight_.MergeFrom(from.loss_weight_); param_.MergeFrom(from.param_); if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { if (from.has_name()) { @@ -1013,11 +929,11 @@ void OperatorParameter::MergeFrom(const OperatorParameter& from) { if (from.has_type()) { set_type(from.type()); } - if (from.has_add_op_param()) { - mutable_add_op_param()->::sita::AddOpParameter::MergeFrom(from.add_op_param()); + if (from.has_gradient_block()) { + set_gradient_block(from.gradient_block()); } - if (from.has_convolution_op_param()) { - mutable_convolution_op_param()->::sita::ConvolutionOpParameter::MergeFrom(from.convolution_op_param()); + if (from.has_convolution_param()) { + mutable_convolution_param()->::sita::ConvolutionParameter::MergeFrom(from.convolution_param()); } } mutable_unknown_fields()->MergeFrom(from.unknown_fields()); @@ -1046,10 +962,9 @@ void OperatorParameter::Swap(OperatorParameter* other) { std::swap(type_, other->type_); input_.Swap(&other->input_); output_.Swap(&other->output_); - loss_weight_.Swap(&other->loss_weight_); + std::swap(gradient_block_, other->gradient_block_); param_.Swap(&other->param_); - std::swap(add_op_param_, other->add_op_param_); - std::swap(convolution_op_param_, other->convolution_op_param_); + std::swap(convolution_param_, other->convolution_param_); std::swap(_has_bits_[0], 
other->_has_bits_[0]); _unknown_fields_.Swap(&other->_unknown_fields_); std::swap(_cached_size_, other->_cached_size_); @@ -1068,517 +983,87 @@ ::google::protobuf::Metadata OperatorParameter::GetMetadata() const { // =================================================================== #ifndef _MSC_VER -const int AddOpParameter::kPadHFieldNumber; -const int AddOpParameter::kPadWFieldNumber; -const int AddOpParameter::kKernelHFieldNumber; -const int AddOpParameter::kKernelWFieldNumber; -const int AddOpParameter::kStrideHFieldNumber; -const int AddOpParameter::kStrideWFieldNumber; +const int ConvolutionParameter::kNumOutputFieldNumber; +const int ConvolutionParameter::kBiasTermFieldNumber; +const int ConvolutionParameter::kPadFieldNumber; +const int ConvolutionParameter::kKernelSizeFieldNumber; +const int ConvolutionParameter::kStrideFieldNumber; +const int ConvolutionParameter::kPadHFieldNumber; +const int ConvolutionParameter::kPadWFieldNumber; +const int ConvolutionParameter::kKernelHFieldNumber; +const int ConvolutionParameter::kKernelWFieldNumber; +const int ConvolutionParameter::kStrideHFieldNumber; +const int ConvolutionParameter::kStrideWFieldNumber; +const int ConvolutionParameter::kGroupFieldNumber; #endif // !_MSC_VER -AddOpParameter::AddOpParameter() +ConvolutionParameter::ConvolutionParameter() : ::google::protobuf::Message() { SharedCtor(); - // @@protoc_insertion_point(constructor:sita.AddOpParameter) + // @@protoc_insertion_point(constructor:sita.ConvolutionParameter) } -void AddOpParameter::InitAsDefaultInstance() { +void ConvolutionParameter::InitAsDefaultInstance() { } -AddOpParameter::AddOpParameter(const AddOpParameter& from) +ConvolutionParameter::ConvolutionParameter(const ConvolutionParameter& from) : ::google::protobuf::Message() { SharedCtor(); MergeFrom(from); - // @@protoc_insertion_point(copy_constructor:sita.AddOpParameter) + // @@protoc_insertion_point(copy_constructor:sita.ConvolutionParameter) } -void AddOpParameter::SharedCtor() { - _cached_size_ = 0; - pad_h_ = 0u; - pad_w_ = 0u; - kernel_h_ = 0u; - kernel_w_ = 0u; - stride_h_ = 0u; - stride_w_ = 0u; - ::memset(_has_bits_, 0, sizeof(_has_bits_)); -} - -AddOpParameter::~AddOpParameter() { - // @@protoc_insertion_point(destructor:sita.AddOpParameter) - SharedDtor(); -} - -void AddOpParameter::SharedDtor() { - if (this != default_instance_) { - } -} - -void AddOpParameter::SetCachedSize(int size) const { - GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); - _cached_size_ = size; - GOOGLE_SAFE_CONCURRENT_WRITES_END(); -} -const ::google::protobuf::Descriptor* AddOpParameter::descriptor() { - protobuf_AssignDescriptorsOnce(); - return AddOpParameter_descriptor_; -} - -const AddOpParameter& AddOpParameter::default_instance() { - if (default_instance_ == NULL) protobuf_AddDesc_sita_5foperators_2eproto(); - return *default_instance_; -} - -AddOpParameter* AddOpParameter::default_instance_ = NULL; - -AddOpParameter* AddOpParameter::New() const { - return new AddOpParameter; -} - -void AddOpParameter::Clear() { -#define OFFSET_OF_FIELD_(f) (reinterpret_cast( \ - &reinterpret_cast(16)->f) - \ - reinterpret_cast(16)) - -#define ZR_(first, last) do { \ - size_t f = OFFSET_OF_FIELD_(first); \ - size_t n = OFFSET_OF_FIELD_(last) - f + sizeof(last); \ - ::memset(&first, 0, n); \ - } while (0) - - if (_has_bits_[0 / 32] & 63) { - ZR_(pad_h_, stride_w_); - } - -#undef OFFSET_OF_FIELD_ -#undef ZR_ - - ::memset(_has_bits_, 0, sizeof(_has_bits_)); - mutable_unknown_fields()->Clear(); -} - -bool AddOpParameter::MergePartialFromCodedStream( - 
::google::protobuf::io::CodedInputStream* input) { -#define DO_(EXPRESSION) if (!(EXPRESSION)) goto failure - ::google::protobuf::uint32 tag; - // @@protoc_insertion_point(parse_start:sita.AddOpParameter) - for (;;) { - ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(127); - tag = p.first; - if (!p.second) goto handle_unusual; - switch (::google::protobuf::internal::WireFormatLite::GetTagFieldNumber(tag)) { - // optional uint32 pad_h = 1 [default = 0]; - case 1: { - if (tag == 8) { - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &pad_h_))); - set_has_pad_h(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(16)) goto parse_pad_w; - break; - } - - // optional uint32 pad_w = 2 [default = 0]; - case 2: { - if (tag == 16) { - parse_pad_w: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &pad_w_))); - set_has_pad_w(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(24)) goto parse_kernel_h; - break; - } - - // optional uint32 kernel_h = 3; - case 3: { - if (tag == 24) { - parse_kernel_h: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &kernel_h_))); - set_has_kernel_h(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(32)) goto parse_kernel_w; - break; - } - - // optional uint32 kernel_w = 4; - case 4: { - if (tag == 32) { - parse_kernel_w: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &kernel_w_))); - set_has_kernel_w(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(40)) goto parse_stride_h; - break; - } - - // optional uint32 stride_h = 5; - case 5: { - if (tag == 40) { - parse_stride_h: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &stride_h_))); - set_has_stride_h(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(48)) goto parse_stride_w; - break; - } - - // optional uint32 stride_w = 6; - case 6: { - if (tag == 48) { - parse_stride_w: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &stride_w_))); - set_has_stride_w(); - } else { - goto handle_unusual; - } - if (input->ExpectAtEnd()) goto success; - break; - } - - default: { - handle_unusual: - if (tag == 0 || - ::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) == - ::google::protobuf::internal::WireFormatLite::WIRETYPE_END_GROUP) { - goto success; - } - DO_(::google::protobuf::internal::WireFormat::SkipField( - input, tag, mutable_unknown_fields())); - break; - } - } - } -success: - // @@protoc_insertion_point(parse_success:sita.AddOpParameter) - return true; -failure: - // @@protoc_insertion_point(parse_failure:sita.AddOpParameter) - return false; -#undef DO_ -} - -void AddOpParameter::SerializeWithCachedSizes( - ::google::protobuf::io::CodedOutputStream* output) const { - // @@protoc_insertion_point(serialize_start:sita.AddOpParameter) - // optional uint32 pad_h = 1 [default = 0]; - if (has_pad_h()) { - 
::google::protobuf::internal::WireFormatLite::WriteUInt32(1, this->pad_h(), output); - } - - // optional uint32 pad_w = 2 [default = 0]; - if (has_pad_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(2, this->pad_w(), output); - } - - // optional uint32 kernel_h = 3; - if (has_kernel_h()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(3, this->kernel_h(), output); - } - - // optional uint32 kernel_w = 4; - if (has_kernel_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(4, this->kernel_w(), output); - } - - // optional uint32 stride_h = 5; - if (has_stride_h()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(5, this->stride_h(), output); - } - - // optional uint32 stride_w = 6; - if (has_stride_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(6, this->stride_w(), output); - } - - if (!unknown_fields().empty()) { - ::google::protobuf::internal::WireFormat::SerializeUnknownFields( - unknown_fields(), output); - } - // @@protoc_insertion_point(serialize_end:sita.AddOpParameter) -} - -::google::protobuf::uint8* AddOpParameter::SerializeWithCachedSizesToArray( - ::google::protobuf::uint8* target) const { - // @@protoc_insertion_point(serialize_to_array_start:sita.AddOpParameter) - // optional uint32 pad_h = 1 [default = 0]; - if (has_pad_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(1, this->pad_h(), target); - } - - // optional uint32 pad_w = 2 [default = 0]; - if (has_pad_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(2, this->pad_w(), target); - } - - // optional uint32 kernel_h = 3; - if (has_kernel_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(3, this->kernel_h(), target); - } - - // optional uint32 kernel_w = 4; - if (has_kernel_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(4, this->kernel_w(), target); - } - - // optional uint32 stride_h = 5; - if (has_stride_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(5, this->stride_h(), target); - } - - // optional uint32 stride_w = 6; - if (has_stride_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(6, this->stride_w(), target); - } - - if (!unknown_fields().empty()) { - target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( - unknown_fields(), target); - } - // @@protoc_insertion_point(serialize_to_array_end:sita.AddOpParameter) - return target; -} - -int AddOpParameter::ByteSize() const { - int total_size = 0; - - if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { - // optional uint32 pad_h = 1 [default = 0]; - if (has_pad_h()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->pad_h()); - } - - // optional uint32 pad_w = 2 [default = 0]; - if (has_pad_w()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->pad_w()); - } - - // optional uint32 kernel_h = 3; - if (has_kernel_h()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->kernel_h()); - } - - // optional uint32 kernel_w = 4; - if (has_kernel_w()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->kernel_w()); - } - - // optional uint32 stride_h = 5; - if (has_stride_h()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->stride_h()); - } - - // optional uint32 stride_w = 6; - if 
(has_stride_w()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->stride_w()); - } - - } - if (!unknown_fields().empty()) { - total_size += - ::google::protobuf::internal::WireFormat::ComputeUnknownFieldsSize( - unknown_fields()); - } - GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); - _cached_size_ = total_size; - GOOGLE_SAFE_CONCURRENT_WRITES_END(); - return total_size; -} - -void AddOpParameter::MergeFrom(const ::google::protobuf::Message& from) { - GOOGLE_CHECK_NE(&from, this); - const AddOpParameter* source = - ::google::protobuf::internal::dynamic_cast_if_available<const AddOpParameter*>( - &from); - if (source == NULL) { - ::google::protobuf::internal::ReflectionOps::Merge(from, this); - } else { - MergeFrom(*source); - } -} - -void AddOpParameter::MergeFrom(const AddOpParameter& from) { - GOOGLE_CHECK_NE(&from, this); - if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { - if (from.has_pad_h()) { - set_pad_h(from.pad_h()); - } - if (from.has_pad_w()) { - set_pad_w(from.pad_w()); - } - if (from.has_kernel_h()) { - set_kernel_h(from.kernel_h()); - } - if (from.has_kernel_w()) { - set_kernel_w(from.kernel_w()); - } - if (from.has_stride_h()) { - set_stride_h(from.stride_h()); - } - if (from.has_stride_w()) { - set_stride_w(from.stride_w()); - } - } - mutable_unknown_fields()->MergeFrom(from.unknown_fields()); -} - -void AddOpParameter::CopyFrom(const ::google::protobuf::Message& from) { - if (&from == this) return; - Clear(); - MergeFrom(from); -} - -void AddOpParameter::CopyFrom(const AddOpParameter& from) { - if (&from == this) return; - Clear(); - MergeFrom(from); -} - -bool AddOpParameter::IsInitialized() const { - - return true; -} - -void AddOpParameter::Swap(AddOpParameter* other) { - if (other != this) { - std::swap(pad_h_, other->pad_h_); - std::swap(pad_w_, other->pad_w_); - std::swap(kernel_h_, other->kernel_h_); - std::swap(kernel_w_, other->kernel_w_); - std::swap(stride_h_, other->stride_h_); - std::swap(stride_w_, other->stride_w_); - std::swap(_has_bits_[0], other->_has_bits_[0]); - _unknown_fields_.Swap(&other->_unknown_fields_); - std::swap(_cached_size_, other->_cached_size_); - } -} - -::google::protobuf::Metadata AddOpParameter::GetMetadata() const { - protobuf_AssignDescriptorsOnce(); - ::google::protobuf::Metadata metadata; - metadata.descriptor = AddOpParameter_descriptor_; - metadata.reflection = AddOpParameter_reflection_; - return metadata; -} - - -// =================================================================== - -#ifndef _MSC_VER -const int ConvolutionOpParameter::kNumOutputFieldNumber; -const int ConvolutionOpParameter::kBiasTermFieldNumber; -const int ConvolutionOpParameter::kPadFieldNumber; -const int ConvolutionOpParameter::kKernelSizeFieldNumber; -const int ConvolutionOpParameter::kStrideFieldNumber; -const int ConvolutionOpParameter::kDilationFieldNumber; -const int ConvolutionOpParameter::kPadHFieldNumber; -const int ConvolutionOpParameter::kPadWFieldNumber; -const int ConvolutionOpParameter::kKernelHFieldNumber; -const int ConvolutionOpParameter::kKernelWFieldNumber; -const int ConvolutionOpParameter::kStrideHFieldNumber; -const int ConvolutionOpParameter::kStrideWFieldNumber; -const int ConvolutionOpParameter::kGroupFieldNumber; -#endif // !_MSC_VER - -ConvolutionOpParameter::ConvolutionOpParameter() - : ::google::protobuf::Message() { - SharedCtor(); - // @@protoc_insertion_point(constructor:sita.ConvolutionOpParameter) -} - -void ConvolutionOpParameter::InitAsDefaultInstance() { -} -
-ConvolutionOpParameter::ConvolutionOpParameter(const ConvolutionOpParameter& from) - : ::google::protobuf::Message() { - SharedCtor(); - MergeFrom(from); - // @@protoc_insertion_point(copy_constructor:sita.ConvolutionOpParameter) -} - -void ConvolutionOpParameter::SharedCtor() { +void ConvolutionParameter::SharedCtor() { _cached_size_ = 0; num_output_ = 0u; bias_term_ = true; - pad_ = 0u; - kernel_size_ = 0u; - stride_ = 0u; - dilation_ = 1u; + pad_ = 1u; + kernel_size_ = 3u; + stride_ = 1u; pad_h_ = 0u; pad_w_ = 0u; - kernel_h_ = 3u; - kernel_w_ = 3u; + kernel_h_ = 0u; + kernel_w_ = 0u; stride_h_ = 0u; stride_w_ = 0u; group_ = 1u; ::memset(_has_bits_, 0, sizeof(_has_bits_)); } -ConvolutionOpParameter::~ConvolutionOpParameter() { - // @@protoc_insertion_point(destructor:sita.ConvolutionOpParameter) +ConvolutionParameter::~ConvolutionParameter() { + // @@protoc_insertion_point(destructor:sita.ConvolutionParameter) SharedDtor(); } -void ConvolutionOpParameter::SharedDtor() { +void ConvolutionParameter::SharedDtor() { if (this != default_instance_) { } } -void ConvolutionOpParameter::SetCachedSize(int size) const { +void ConvolutionParameter::SetCachedSize(int size) const { GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN(); _cached_size_ = size; GOOGLE_SAFE_CONCURRENT_WRITES_END(); } -const ::google::protobuf::Descriptor* ConvolutionOpParameter::descriptor() { +const ::google::protobuf::Descriptor* ConvolutionParameter::descriptor() { protobuf_AssignDescriptorsOnce(); - return ConvolutionOpParameter_descriptor_; + return ConvolutionParameter_descriptor_; } -const ConvolutionOpParameter& ConvolutionOpParameter::default_instance() { +const ConvolutionParameter& ConvolutionParameter::default_instance() { if (default_instance_ == NULL) protobuf_AddDesc_sita_5foperators_2eproto(); return *default_instance_; } -ConvolutionOpParameter* ConvolutionOpParameter::default_instance_ = NULL; +ConvolutionParameter* ConvolutionParameter::default_instance_ = NULL; -ConvolutionOpParameter* ConvolutionOpParameter::New() const { - return new ConvolutionOpParameter; +ConvolutionParameter* ConvolutionParameter::New() const { + return new ConvolutionParameter; } -void ConvolutionOpParameter::Clear() { +void ConvolutionParameter::Clear() { #define OFFSET_OF_FIELD_(f) (reinterpret_cast<char*>( \ - &reinterpret_cast<ConvolutionOpParameter*>(16)->f) - \ + &reinterpret_cast<ConvolutionParameter*>(16)->f) - \ reinterpret_cast<char*>(16)) #define ZR_(first, last) do { \ @@ -1588,16 +1073,15 @@ void ConvolutionOpParameter::Clear() { } while (0) if (_has_bits_[0 / 32] & 255) { - ZR_(pad_, stride_); - ZR_(pad_h_, pad_w_); + ZR_(pad_h_, kernel_h_); num_output_ = 0u; bias_term_ = true; - dilation_ = 1u; + pad_ = 1u; + kernel_size_ = 3u; + stride_ = 1u; } - if (_has_bits_[8 / 32] & 7936) { - ZR_(stride_h_, stride_w_); - kernel_h_ = 3u; - kernel_w_ = 3u; + if (_has_bits_[8 / 32] & 3840) { + ZR_(kernel_w_, stride_w_); group_ = 1u; } @@ -1608,11 +1092,11 @@ void ConvolutionOpParameter::Clear() { mutable_unknown_fields()->Clear(); } -bool ConvolutionOpParameter::MergePartialFromCodedStream( +bool ConvolutionParameter::MergePartialFromCodedStream( ::google::protobuf::io::CodedInputStream* input) { #define DO_(EXPRESSION) if (!(EXPRESSION)) goto failure ::google::protobuf::uint32 tag; - // @@protoc_insertion_point(parse_start:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(parse_start:sita.ConvolutionParameter) for (;;) { ::std::pair< ::google::protobuf::uint32, bool> p = input->ReadTagWithCutoff(127); tag = p.first; @@ -1647,7 +1131,7 @@ bool
ConvolutionOpParameter::MergePartialFromCodedStream( break; } - // optional uint32 pad = 3; + // optional uint32 pad = 3 [default = 1]; case 3: { if (tag == 24) { parse_pad: @@ -1662,7 +1146,7 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( break; } - // optional uint32 kernel_size = 4; + // optional uint32 kernel_size = 4 [default = 3]; case 4: { if (tag == 32) { parse_kernel_size: @@ -1677,7 +1161,7 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( break; } - // optional uint32 stride = 5; + // optional uint32 stride = 5 [default = 1]; case 5: { if (tag == 40) { parse_stride: @@ -1688,28 +1172,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(48)) goto parse_dilation; + if (input->ExpectTag(48)) goto parse_pad_h; break; } - // optional uint32 dilation = 6 [default = 1]; + // optional uint32 pad_h = 6; case 6: { if (tag == 48) { - parse_dilation: - DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< - ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( - input, &dilation_))); - set_has_dilation(); - } else { - goto handle_unusual; - } - if (input->ExpectTag(56)) goto parse_pad_h; - break; - } - - // optional uint32 pad_h = 7 [default = 0]; - case 7: { - if (tag == 56) { parse_pad_h: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1718,13 +1187,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(64)) goto parse_pad_w; + if (input->ExpectTag(56)) goto parse_pad_w; break; } - // optional uint32 pad_w = 8 [default = 0]; - case 8: { - if (tag == 64) { + // optional uint32 pad_w = 7; + case 7: { + if (tag == 56) { parse_pad_w: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1733,13 +1202,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(72)) goto parse_kernel_h; + if (input->ExpectTag(64)) goto parse_kernel_h; break; } - // optional uint32 kernel_h = 9 [default = 3]; - case 9: { - if (tag == 72) { + // optional uint32 kernel_h = 8; + case 8: { + if (tag == 64) { parse_kernel_h: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1748,13 +1217,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(80)) goto parse_kernel_w; + if (input->ExpectTag(72)) goto parse_kernel_w; break; } - // optional uint32 kernel_w = 10 [default = 3]; - case 10: { - if (tag == 80) { + // optional uint32 kernel_w = 9; + case 9: { + if (tag == 72) { parse_kernel_w: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1763,13 +1232,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(88)) goto parse_stride_h; + if (input->ExpectTag(80)) goto parse_stride_h; break; } - // optional uint32 stride_h = 11; - case 11: { - if (tag == 88) { + // optional uint32 stride_h = 10; + case 10: { + if (tag == 80) { parse_stride_h: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< 
::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1778,13 +1247,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(96)) goto parse_stride_w; + if (input->ExpectTag(88)) goto parse_stride_w; break; } - // optional uint32 stride_w = 12; - case 12: { - if (tag == 96) { + // optional uint32 stride_w = 11; + case 11: { + if (tag == 88) { parse_stride_w: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1793,13 +1262,13 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } else { goto handle_unusual; } - if (input->ExpectTag(104)) goto parse_group; + if (input->ExpectTag(96)) goto parse_group; break; } - // optional uint32 group = 13 [default = 1]; - case 13: { - if (tag == 104) { + // optional uint32 group = 12 [default = 1]; + case 12: { + if (tag == 96) { parse_group: DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive< ::google::protobuf::uint32, ::google::protobuf::internal::WireFormatLite::TYPE_UINT32>( @@ -1826,17 +1295,17 @@ bool ConvolutionOpParameter::MergePartialFromCodedStream( } } success: - // @@protoc_insertion_point(parse_success:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(parse_success:sita.ConvolutionParameter) return true; failure: - // @@protoc_insertion_point(parse_failure:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(parse_failure:sita.ConvolutionParameter) return false; #undef DO_ } -void ConvolutionOpParameter::SerializeWithCachedSizes( +void ConvolutionParameter::SerializeWithCachedSizes( ::google::protobuf::io::CodedOutputStream* output) const { - // @@protoc_insertion_point(serialize_start:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(serialize_start:sita.ConvolutionParameter) // optional uint32 num_output = 1; if (has_num_output()) { ::google::protobuf::internal::WireFormatLite::WriteUInt32(1, this->num_output(), output); @@ -1847,71 +1316,66 @@ void ConvolutionOpParameter::SerializeWithCachedSizes( ::google::protobuf::internal::WireFormatLite::WriteBool(2, this->bias_term(), output); } - // optional uint32 pad = 3; + // optional uint32 pad = 3 [default = 1]; if (has_pad()) { ::google::protobuf::internal::WireFormatLite::WriteUInt32(3, this->pad(), output); } - // optional uint32 kernel_size = 4; + // optional uint32 kernel_size = 4 [default = 3]; if (has_kernel_size()) { ::google::protobuf::internal::WireFormatLite::WriteUInt32(4, this->kernel_size(), output); } - // optional uint32 stride = 5; + // optional uint32 stride = 5 [default = 1]; if (has_stride()) { ::google::protobuf::internal::WireFormatLite::WriteUInt32(5, this->stride(), output); } - // optional uint32 dilation = 6 [default = 1]; - if (has_dilation()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(6, this->dilation(), output); - } - - // optional uint32 pad_h = 7 [default = 0]; + // optional uint32 pad_h = 6; if (has_pad_h()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(7, this->pad_h(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(6, this->pad_h(), output); } - // optional uint32 pad_w = 8 [default = 0]; + // optional uint32 pad_w = 7; if (has_pad_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(8, this->pad_w(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(7, this->pad_w(), output); } - // optional uint32 kernel_h = 9 [default = 
3]; + // optional uint32 kernel_h = 8; if (has_kernel_h()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(9, this->kernel_h(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(8, this->kernel_h(), output); } - // optional uint32 kernel_w = 10 [default = 3]; + // optional uint32 kernel_w = 9; if (has_kernel_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(10, this->kernel_w(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(9, this->kernel_w(), output); } - // optional uint32 stride_h = 11; + // optional uint32 stride_h = 10; if (has_stride_h()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(11, this->stride_h(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(10, this->stride_h(), output); } - // optional uint32 stride_w = 12; + // optional uint32 stride_w = 11; if (has_stride_w()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(12, this->stride_w(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(11, this->stride_w(), output); } - // optional uint32 group = 13 [default = 1]; + // optional uint32 group = 12 [default = 1]; if (has_group()) { - ::google::protobuf::internal::WireFormatLite::WriteUInt32(13, this->group(), output); + ::google::protobuf::internal::WireFormatLite::WriteUInt32(12, this->group(), output); } if (!unknown_fields().empty()) { ::google::protobuf::internal::WireFormat::SerializeUnknownFields( unknown_fields(), output); } - // @@protoc_insertion_point(serialize_end:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(serialize_end:sita.ConvolutionParameter) } -::google::protobuf::uint8* ConvolutionOpParameter::SerializeWithCachedSizesToArray( +::google::protobuf::uint8* ConvolutionParameter::SerializeWithCachedSizesToArray( ::google::protobuf::uint8* target) const { - // @@protoc_insertion_point(serialize_to_array_start:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(serialize_to_array_start:sita.ConvolutionParameter) // optional uint32 num_output = 1; if (has_num_output()) { target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(1, this->num_output(), target); @@ -1922,70 +1386,65 @@ ::google::protobuf::uint8* ConvolutionOpParameter::SerializeWithCachedSizesToArr target = ::google::protobuf::internal::WireFormatLite::WriteBoolToArray(2, this->bias_term(), target); } - // optional uint32 pad = 3; + // optional uint32 pad = 3 [default = 1]; if (has_pad()) { target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(3, this->pad(), target); } - // optional uint32 kernel_size = 4; + // optional uint32 kernel_size = 4 [default = 3]; if (has_kernel_size()) { target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(4, this->kernel_size(), target); } - // optional uint32 stride = 5; + // optional uint32 stride = 5 [default = 1]; if (has_stride()) { target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(5, this->stride(), target); } - // optional uint32 dilation = 6 [default = 1]; - if (has_dilation()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(6, this->dilation(), target); - } - - // optional uint32 pad_h = 7 [default = 0]; + // optional uint32 pad_h = 6; if (has_pad_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(7, this->pad_h(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(6, this->pad_h(), target); } - // optional uint32 pad_w = 8 
[default = 0]; + // optional uint32 pad_w = 7; if (has_pad_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(8, this->pad_w(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(7, this->pad_w(), target); } - // optional uint32 kernel_h = 9 [default = 3]; + // optional uint32 kernel_h = 8; if (has_kernel_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(9, this->kernel_h(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(8, this->kernel_h(), target); } - // optional uint32 kernel_w = 10 [default = 3]; + // optional uint32 kernel_w = 9; if (has_kernel_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(10, this->kernel_w(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(9, this->kernel_w(), target); } - // optional uint32 stride_h = 11; + // optional uint32 stride_h = 10; if (has_stride_h()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(11, this->stride_h(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(10, this->stride_h(), target); } - // optional uint32 stride_w = 12; + // optional uint32 stride_w = 11; if (has_stride_w()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(12, this->stride_w(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(11, this->stride_w(), target); } - // optional uint32 group = 13 [default = 1]; + // optional uint32 group = 12 [default = 1]; if (has_group()) { - target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(13, this->group(), target); + target = ::google::protobuf::internal::WireFormatLite::WriteUInt32ToArray(12, this->group(), target); } if (!unknown_fields().empty()) { target = ::google::protobuf::internal::WireFormat::SerializeUnknownFieldsToArray( unknown_fields(), target); } - // @@protoc_insertion_point(serialize_to_array_end:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(serialize_to_array_end:sita.ConvolutionParameter) return target; } -int ConvolutionOpParameter::ByteSize() const { +int ConvolutionParameter::ByteSize() const { int total_size = 0; if (_has_bits_[0 / 32] & (0xffu << (0 % 32))) { @@ -2001,79 +1460,72 @@ int ConvolutionOpParameter::ByteSize() const { total_size += 1 + 1; } - // optional uint32 pad = 3; + // optional uint32 pad = 3 [default = 1]; if (has_pad()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->pad()); } - // optional uint32 kernel_size = 4; + // optional uint32 kernel_size = 4 [default = 3]; if (has_kernel_size()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->kernel_size()); } - // optional uint32 stride = 5; + // optional uint32 stride = 5 [default = 1]; if (has_stride()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->stride()); } - // optional uint32 dilation = 6 [default = 1]; - if (has_dilation()) { - total_size += 1 + - ::google::protobuf::internal::WireFormatLite::UInt32Size( - this->dilation()); - } - - // optional uint32 pad_h = 7 [default = 0]; + // optional uint32 pad_h = 6; if (has_pad_h()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->pad_h()); } - // optional uint32 pad_w = 8 [default = 0]; + // optional uint32 pad_w = 7; if (has_pad_w()) { total_size += 1 + 
::google::protobuf::internal::WireFormatLite::UInt32Size( this->pad_w()); } - } - if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { - // optional uint32 kernel_h = 9 [default = 3]; + // optional uint32 kernel_h = 8; if (has_kernel_h()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->kernel_h()); } - // optional uint32 kernel_w = 10 [default = 3]; + } + if (_has_bits_[8 / 32] & (0xffu << (8 % 32))) { + // optional uint32 kernel_w = 9; if (has_kernel_w()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->kernel_w()); } - // optional uint32 stride_h = 11; + // optional uint32 stride_h = 10; if (has_stride_h()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->stride_h()); } - // optional uint32 stride_w = 12; + // optional uint32 stride_w = 11; if (has_stride_w()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( this->stride_w()); } - // optional uint32 group = 13 [default = 1]; + // optional uint32 group = 12 [default = 1]; if (has_group()) { total_size += 1 + ::google::protobuf::internal::WireFormatLite::UInt32Size( @@ -2092,10 +1544,10 @@ int ConvolutionOpParameter::ByteSize() const { return total_size; } -void ConvolutionOpParameter::MergeFrom(const ::google::protobuf::Message& from) { +void ConvolutionParameter::MergeFrom(const ::google::protobuf::Message& from) { GOOGLE_CHECK_NE(&from, this); - const ConvolutionOpParameter* source = - ::google::protobuf::internal::dynamic_cast_if_available<const ConvolutionOpParameter*>( + const ConvolutionParameter* source = + ::google::protobuf::internal::dynamic_cast_if_available<const ConvolutionParameter*>( &from); if (source == NULL) { ::google::protobuf::internal::ReflectionOps::Merge(from, this); @@ -2104,7 +1556,7 @@ void ConvolutionOpParameter::MergeFrom(const ::google::protobuf::Message& from) } } -void ConvolutionOpParameter::MergeFrom(const ConvolutionOpParameter& from) { +void ConvolutionParameter::MergeFrom(const ConvolutionParameter& from) { GOOGLE_CHECK_NE(&from, this); if (from._has_bits_[0 / 32] & (0xffu << (0 % 32))) { if (from.has_num_output()) { @@ -2122,20 +1574,17 @@ void ConvolutionOpParameter::MergeFrom(const ConvolutionOpParameter& from) { if (from.has_stride()) { set_stride(from.stride()); } - if (from.has_dilation()) { - set_dilation(from.dilation()); - } if (from.has_pad_h()) { set_pad_h(from.pad_h()); } if (from.has_pad_w()) { set_pad_w(from.pad_w()); } - } - if (from._has_bits_[8 / 32] & (0xffu << (8 % 32))) { if (from.has_kernel_h()) { set_kernel_h(from.kernel_h()); } + } + if (from._has_bits_[8 / 32] & (0xffu << (8 % 32))) { if (from.has_kernel_w()) { set_kernel_w(from.kernel_w()); } @@ -2152,31 +1601,30 @@ void ConvolutionOpParameter::MergeFrom(const ConvolutionOpParameter& from) { mutable_unknown_fields()->MergeFrom(from.unknown_fields()); } -void ConvolutionOpParameter::CopyFrom(const ::google::protobuf::Message& from) { +void ConvolutionParameter::CopyFrom(const ::google::protobuf::Message& from) { if (&from == this) return; Clear(); MergeFrom(from); } -void ConvolutionOpParameter::CopyFrom(const ConvolutionOpParameter& from) { +void ConvolutionParameter::CopyFrom(const ConvolutionParameter& from) { if (&from == this) return; Clear(); MergeFrom(from); } -bool ConvolutionOpParameter::IsInitialized() const { +bool ConvolutionParameter::IsInitialized() const { return true; } -void ConvolutionOpParameter::Swap(ConvolutionOpParameter* other) { +void ConvolutionParameter::Swap(ConvolutionParameter* other) { if (other != this) { std::swap(num_output_,
other->num_output_); std::swap(bias_term_, other->bias_term_); std::swap(pad_, other->pad_); std::swap(kernel_size_, other->kernel_size_); std::swap(stride_, other->stride_); - std::swap(dilation_, other->dilation_); std::swap(pad_h_, other->pad_h_); std::swap(pad_w_, other->pad_w_); std::swap(kernel_h_, other->kernel_h_); @@ -2190,11 +1638,11 @@ void ConvolutionOpParameter::Swap(ConvolutionOpParameter* other) { } } -::google::protobuf::Metadata ConvolutionOpParameter::GetMetadata() const { +::google::protobuf::Metadata ConvolutionParameter::GetMetadata() const { protobuf_AssignDescriptorsOnce(); ::google::protobuf::Metadata metadata; - metadata.descriptor = ConvolutionOpParameter_descriptor_; - metadata.reflection = ConvolutionOpParameter_reflection_; + metadata.descriptor = ConvolutionParameter_descriptor_; + metadata.reflection = ConvolutionParameter_reflection_; return metadata; } diff --git a/src/sita/proto/sita_operators.pb.h b/src/sita/proto/sita_operators.pb.h index 5431bd24..40503f6b 100644 --- a/src/sita/proto/sita_operators.pb.h +++ b/src/sita/proto/sita_operators.pb.h @@ -36,8 +36,7 @@ void protobuf_ShutdownFile_sita_5foperators_2eproto(); class GraphParameter; class OperatorParameter; -class AddOpParameter; -class ConvolutionOpParameter; +class ConvolutionParameter; // =================================================================== @@ -247,17 +246,12 @@ class OperatorParameter : public ::google::protobuf::Message { inline const ::google::protobuf::RepeatedPtrField< ::std::string>& output() const; inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_output(); - // repeated float loss_weight = 5; - inline int loss_weight_size() const; - inline void clear_loss_weight(); - static const int kLossWeightFieldNumber = 5; - inline float loss_weight(int index) const; - inline void set_loss_weight(int index, float value); - inline void add_loss_weight(float value); - inline const ::google::protobuf::RepeatedField< float >& - loss_weight() const; - inline ::google::protobuf::RepeatedField< float >* - mutable_loss_weight(); + // optional bool gradient_block = 5 [default = false]; + inline bool has_gradient_block() const; + inline void clear_gradient_block(); + static const int kGradientBlockFieldNumber = 5; + inline bool gradient_block() const; + inline void set_gradient_block(bool value); // repeated .sita.ParamConfig param = 6; inline int param_size() const; @@ -271,23 +265,14 @@ class OperatorParameter : public ::google::protobuf::Message { inline ::google::protobuf::RepeatedPtrField< ::sita::ParamConfig >* mutable_param(); - // optional .sita.AddOpParameter add_op_param = 100; - inline bool has_add_op_param() const; - inline void clear_add_op_param(); - static const int kAddOpParamFieldNumber = 100; - inline const ::sita::AddOpParameter& add_op_param() const; - inline ::sita::AddOpParameter* mutable_add_op_param(); - inline ::sita::AddOpParameter* release_add_op_param(); - inline void set_allocated_add_op_param(::sita::AddOpParameter* add_op_param); - - // optional .sita.ConvolutionOpParameter convolution_op_param = 101; - inline bool has_convolution_op_param() const; - inline void clear_convolution_op_param(); - static const int kConvolutionOpParamFieldNumber = 101; - inline const ::sita::ConvolutionOpParameter& convolution_op_param() const; - inline ::sita::ConvolutionOpParameter* mutable_convolution_op_param(); - inline ::sita::ConvolutionOpParameter* release_convolution_op_param(); - inline void set_allocated_convolution_op_param(::sita::ConvolutionOpParameter* 
convolution_op_param); + // optional .sita.ConvolutionParameter convolution_param = 101; + inline bool has_convolution_param() const; + inline void clear_convolution_param(); + static const int kConvolutionParamFieldNumber = 101; + inline const ::sita::ConvolutionParameter& convolution_param() const; + inline ::sita::ConvolutionParameter* mutable_convolution_param(); + inline ::sita::ConvolutionParameter* release_convolution_param(); + inline void set_allocated_convolution_param(::sita::ConvolutionParameter* convolution_param); // @@protoc_insertion_point(class_scope:sita.OperatorParameter) private: @@ -295,10 +280,10 @@ class OperatorParameter : public ::google::protobuf::Message { inline void clear_has_name(); inline void set_has_type(); inline void clear_has_type(); - inline void set_has_add_op_param(); - inline void clear_has_add_op_param(); - inline void set_has_convolution_op_param(); - inline void clear_has_convolution_op_param(); + inline void set_has_gradient_block(); + inline void clear_has_gradient_block(); + inline void set_has_convolution_param(); + inline void clear_has_convolution_param(); ::google::protobuf::UnknownFieldSet _unknown_fields_; @@ -308,10 +293,9 @@ class OperatorParameter : public ::google::protobuf::Message { ::std::string* type_; ::google::protobuf::RepeatedPtrField< ::std::string> input_; ::google::protobuf::RepeatedPtrField< ::std::string> output_; - ::google::protobuf::RepeatedField< float > loss_weight_; ::google::protobuf::RepeatedPtrField< ::sita::ParamConfig > param_; - ::sita::AddOpParameter* add_op_param_; - ::sita::ConvolutionOpParameter* convolution_op_param_; + ::sita::ConvolutionParameter* convolution_param_; + bool gradient_block_; friend void protobuf_AddDesc_sita_5foperators_2eproto(); friend void protobuf_AssignDesc_sita_5foperators_2eproto(); friend void protobuf_ShutdownFile_sita_5foperators_2eproto(); @@ -321,14 +305,14 @@ class OperatorParameter : public ::google::protobuf::Message { }; // ------------------------------------------------------------------- -class AddOpParameter : public ::google::protobuf::Message { +class ConvolutionParameter : public ::google::protobuf::Message { public: - AddOpParameter(); - virtual ~AddOpParameter(); + ConvolutionParameter(); + virtual ~ConvolutionParameter(); - AddOpParameter(const AddOpParameter& from); + ConvolutionParameter(const ConvolutionParameter& from); - inline AddOpParameter& operator=(const AddOpParameter& from) { + inline ConvolutionParameter& operator=(const ConvolutionParameter& from) { CopyFrom(from); return *this; } @@ -342,146 +326,17 @@ class AddOpParameter : public ::google::protobuf::Message { } static const ::google::protobuf::Descriptor* descriptor(); - static const AddOpParameter& default_instance(); + static const ConvolutionParameter& default_instance(); - void Swap(AddOpParameter* other); + void Swap(ConvolutionParameter* other); // implements Message ---------------------------------------------- - AddOpParameter* New() const; + ConvolutionParameter* New() const; void CopyFrom(const ::google::protobuf::Message& from); void MergeFrom(const ::google::protobuf::Message& from); - void CopyFrom(const AddOpParameter& from); - void MergeFrom(const AddOpParameter& from); - void Clear(); - bool IsInitialized() const; - - int ByteSize() const; - bool MergePartialFromCodedStream( - ::google::protobuf::io::CodedInputStream* input); - void SerializeWithCachedSizes( - ::google::protobuf::io::CodedOutputStream* output) const; - ::google::protobuf::uint8* 
SerializeWithCachedSizesToArray(::google::protobuf::uint8* output) const; - int GetCachedSize() const { return _cached_size_; } - private: - void SharedCtor(); - void SharedDtor(); - void SetCachedSize(int size) const; - public: - ::google::protobuf::Metadata GetMetadata() const; - - // nested types ---------------------------------------------------- - - // accessors ------------------------------------------------------- - - // optional uint32 pad_h = 1 [default = 0]; - inline bool has_pad_h() const; - inline void clear_pad_h(); - static const int kPadHFieldNumber = 1; - inline ::google::protobuf::uint32 pad_h() const; - inline void set_pad_h(::google::protobuf::uint32 value); - - // optional uint32 pad_w = 2 [default = 0]; - inline bool has_pad_w() const; - inline void clear_pad_w(); - static const int kPadWFieldNumber = 2; - inline ::google::protobuf::uint32 pad_w() const; - inline void set_pad_w(::google::protobuf::uint32 value); - - // optional uint32 kernel_h = 3; - inline bool has_kernel_h() const; - inline void clear_kernel_h(); - static const int kKernelHFieldNumber = 3; - inline ::google::protobuf::uint32 kernel_h() const; - inline void set_kernel_h(::google::protobuf::uint32 value); - - // optional uint32 kernel_w = 4; - inline bool has_kernel_w() const; - inline void clear_kernel_w(); - static const int kKernelWFieldNumber = 4; - inline ::google::protobuf::uint32 kernel_w() const; - inline void set_kernel_w(::google::protobuf::uint32 value); - - // optional uint32 stride_h = 5; - inline bool has_stride_h() const; - inline void clear_stride_h(); - static const int kStrideHFieldNumber = 5; - inline ::google::protobuf::uint32 stride_h() const; - inline void set_stride_h(::google::protobuf::uint32 value); - - // optional uint32 stride_w = 6; - inline bool has_stride_w() const; - inline void clear_stride_w(); - static const int kStrideWFieldNumber = 6; - inline ::google::protobuf::uint32 stride_w() const; - inline void set_stride_w(::google::protobuf::uint32 value); - - // @@protoc_insertion_point(class_scope:sita.AddOpParameter) - private: - inline void set_has_pad_h(); - inline void clear_has_pad_h(); - inline void set_has_pad_w(); - inline void clear_has_pad_w(); - inline void set_has_kernel_h(); - inline void clear_has_kernel_h(); - inline void set_has_kernel_w(); - inline void clear_has_kernel_w(); - inline void set_has_stride_h(); - inline void clear_has_stride_h(); - inline void set_has_stride_w(); - inline void clear_has_stride_w(); - - ::google::protobuf::UnknownFieldSet _unknown_fields_; - - ::google::protobuf::uint32 _has_bits_[1]; - mutable int _cached_size_; - ::google::protobuf::uint32 pad_h_; - ::google::protobuf::uint32 pad_w_; - ::google::protobuf::uint32 kernel_h_; - ::google::protobuf::uint32 kernel_w_; - ::google::protobuf::uint32 stride_h_; - ::google::protobuf::uint32 stride_w_; - friend void protobuf_AddDesc_sita_5foperators_2eproto(); - friend void protobuf_AssignDesc_sita_5foperators_2eproto(); - friend void protobuf_ShutdownFile_sita_5foperators_2eproto(); - - void InitAsDefaultInstance(); - static AddOpParameter* default_instance_; -}; -// ------------------------------------------------------------------- - -class ConvolutionOpParameter : public ::google::protobuf::Message { - public: - ConvolutionOpParameter(); - virtual ~ConvolutionOpParameter(); - - ConvolutionOpParameter(const ConvolutionOpParameter& from); - - inline ConvolutionOpParameter& operator=(const ConvolutionOpParameter& from) { - CopyFrom(from); - return *this; - } - - inline const 
::google::protobuf::UnknownFieldSet& unknown_fields() const { - return _unknown_fields_; - } - - inline ::google::protobuf::UnknownFieldSet* mutable_unknown_fields() { - return &_unknown_fields_; - } - - static const ::google::protobuf::Descriptor* descriptor(); - static const ConvolutionOpParameter& default_instance(); - - void Swap(ConvolutionOpParameter* other); - - // implements Message ---------------------------------------------- - - ConvolutionOpParameter* New() const; - void CopyFrom(const ::google::protobuf::Message& from); - void MergeFrom(const ::google::protobuf::Message& from); - void CopyFrom(const ConvolutionOpParameter& from); - void MergeFrom(const ConvolutionOpParameter& from); + void CopyFrom(const ConvolutionParameter& from); + void MergeFrom(const ConvolutionParameter& from); void Clear(); bool IsInitialized() const; @@ -517,84 +372,77 @@ class ConvolutionOpParameter : public ::google::protobuf::Message { inline bool bias_term() const; inline void set_bias_term(bool value); - // optional uint32 pad = 3; + // optional uint32 pad = 3 [default = 1]; inline bool has_pad() const; inline void clear_pad(); static const int kPadFieldNumber = 3; inline ::google::protobuf::uint32 pad() const; inline void set_pad(::google::protobuf::uint32 value); - // optional uint32 kernel_size = 4; + // optional uint32 kernel_size = 4 [default = 3]; inline bool has_kernel_size() const; inline void clear_kernel_size(); static const int kKernelSizeFieldNumber = 4; inline ::google::protobuf::uint32 kernel_size() const; inline void set_kernel_size(::google::protobuf::uint32 value); - // optional uint32 stride = 5; + // optional uint32 stride = 5 [default = 1]; inline bool has_stride() const; inline void clear_stride(); static const int kStrideFieldNumber = 5; inline ::google::protobuf::uint32 stride() const; inline void set_stride(::google::protobuf::uint32 value); - // optional uint32 dilation = 6 [default = 1]; - inline bool has_dilation() const; - inline void clear_dilation(); - static const int kDilationFieldNumber = 6; - inline ::google::protobuf::uint32 dilation() const; - inline void set_dilation(::google::protobuf::uint32 value); - - // optional uint32 pad_h = 7 [default = 0]; + // optional uint32 pad_h = 6; inline bool has_pad_h() const; inline void clear_pad_h(); - static const int kPadHFieldNumber = 7; + static const int kPadHFieldNumber = 6; inline ::google::protobuf::uint32 pad_h() const; inline void set_pad_h(::google::protobuf::uint32 value); - // optional uint32 pad_w = 8 [default = 0]; + // optional uint32 pad_w = 7; inline bool has_pad_w() const; inline void clear_pad_w(); - static const int kPadWFieldNumber = 8; + static const int kPadWFieldNumber = 7; inline ::google::protobuf::uint32 pad_w() const; inline void set_pad_w(::google::protobuf::uint32 value); - // optional uint32 kernel_h = 9 [default = 3]; + // optional uint32 kernel_h = 8; inline bool has_kernel_h() const; inline void clear_kernel_h(); - static const int kKernelHFieldNumber = 9; + static const int kKernelHFieldNumber = 8; inline ::google::protobuf::uint32 kernel_h() const; inline void set_kernel_h(::google::protobuf::uint32 value); - // optional uint32 kernel_w = 10 [default = 3]; + // optional uint32 kernel_w = 9; inline bool has_kernel_w() const; inline void clear_kernel_w(); - static const int kKernelWFieldNumber = 10; + static const int kKernelWFieldNumber = 9; inline ::google::protobuf::uint32 kernel_w() const; inline void set_kernel_w(::google::protobuf::uint32 value); - // optional uint32 stride_h = 11; + 
// optional uint32 stride_h = 10; inline bool has_stride_h() const; inline void clear_stride_h(); - static const int kStrideHFieldNumber = 11; + static const int kStrideHFieldNumber = 10; inline ::google::protobuf::uint32 stride_h() const; inline void set_stride_h(::google::protobuf::uint32 value); - // optional uint32 stride_w = 12; + // optional uint32 stride_w = 11; inline bool has_stride_w() const; inline void clear_stride_w(); - static const int kStrideWFieldNumber = 12; + static const int kStrideWFieldNumber = 11; inline ::google::protobuf::uint32 stride_w() const; inline void set_stride_w(::google::protobuf::uint32 value); - // optional uint32 group = 13 [default = 1]; + // optional uint32 group = 12 [default = 1]; inline bool has_group() const; inline void clear_group(); - static const int kGroupFieldNumber = 13; + static const int kGroupFieldNumber = 12; inline ::google::protobuf::uint32 group() const; inline void set_group(::google::protobuf::uint32 value); - // @@protoc_insertion_point(class_scope:sita.ConvolutionOpParameter) + // @@protoc_insertion_point(class_scope:sita.ConvolutionParameter) private: inline void set_has_num_output(); inline void clear_has_num_output(); @@ -606,8 +454,6 @@ class ConvolutionOpParameter : public ::google::protobuf::Message { inline void clear_has_kernel_size(); inline void set_has_stride(); inline void clear_has_stride(); - inline void set_has_dilation(); - inline void clear_has_dilation(); inline void set_has_pad_h(); inline void clear_has_pad_h(); inline void set_has_pad_w(); @@ -632,7 +478,6 @@ class ConvolutionOpParameter : public ::google::protobuf::Message { ::google::protobuf::uint32 pad_; ::google::protobuf::uint32 kernel_size_; ::google::protobuf::uint32 stride_; - ::google::protobuf::uint32 dilation_; ::google::protobuf::uint32 pad_h_; ::google::protobuf::uint32 pad_w_; ::google::protobuf::uint32 kernel_h_; @@ -645,7 +490,7 @@ class ConvolutionOpParameter : public ::google::protobuf::Message { friend void protobuf_ShutdownFile_sita_5foperators_2eproto(); void InitAsDefaultInstance(); - static ConvolutionOpParameter* default_instance_; + static ConvolutionParameter* default_instance_; }; // =================================================================== @@ -1024,34 +869,28 @@ OperatorParameter::mutable_output() { return &output_; } -// repeated float loss_weight = 5; -inline int OperatorParameter::loss_weight_size() const { - return loss_weight_.size(); -} -inline void OperatorParameter::clear_loss_weight() { - loss_weight_.Clear(); +// optional bool gradient_block = 5 [default = false]; +inline bool OperatorParameter::has_gradient_block() const { + return (_has_bits_[0] & 0x00000010u) != 0; } -inline float OperatorParameter::loss_weight(int index) const { - // @@protoc_insertion_point(field_get:sita.OperatorParameter.loss_weight) - return loss_weight_.Get(index); +inline void OperatorParameter::set_has_gradient_block() { + _has_bits_[0] |= 0x00000010u; } -inline void OperatorParameter::set_loss_weight(int index, float value) { - loss_weight_.Set(index, value); - // @@protoc_insertion_point(field_set:sita.OperatorParameter.loss_weight) +inline void OperatorParameter::clear_has_gradient_block() { + _has_bits_[0] &= ~0x00000010u; } -inline void OperatorParameter::add_loss_weight(float value) { - loss_weight_.Add(value); - // @@protoc_insertion_point(field_add:sita.OperatorParameter.loss_weight) +inline void OperatorParameter::clear_gradient_block() { + gradient_block_ = false; + clear_has_gradient_block(); } -inline const 
::google::protobuf::RepeatedField< float >& -OperatorParameter::loss_weight() const { - // @@protoc_insertion_point(field_list:sita.OperatorParameter.loss_weight) - return loss_weight_; +inline bool OperatorParameter::gradient_block() const { + // @@protoc_insertion_point(field_get:sita.OperatorParameter.gradient_block) + return gradient_block_; } -inline ::google::protobuf::RepeatedField< float >* -OperatorParameter::mutable_loss_weight() { - // @@protoc_insertion_point(field_mutable_list:sita.OperatorParameter.loss_weight) - return &loss_weight_; +inline void OperatorParameter::set_gradient_block(bool value) { + set_has_gradient_block(); + gradient_block_ = value; + // @@protoc_insertion_point(field_set:sita.OperatorParameter.gradient_block) } // repeated .sita.ParamConfig param = 6; @@ -1084,550 +923,337 @@ OperatorParameter::mutable_param() { return &param_; } -// optional .sita.AddOpParameter add_op_param = 100; -inline bool OperatorParameter::has_add_op_param() const { +// optional .sita.ConvolutionParameter convolution_param = 101; +inline bool OperatorParameter::has_convolution_param() const { return (_has_bits_[0] & 0x00000040u) != 0; } -inline void OperatorParameter::set_has_add_op_param() { +inline void OperatorParameter::set_has_convolution_param() { _has_bits_[0] |= 0x00000040u; } -inline void OperatorParameter::clear_has_add_op_param() { +inline void OperatorParameter::clear_has_convolution_param() { _has_bits_[0] &= ~0x00000040u; } -inline void OperatorParameter::clear_add_op_param() { - if (add_op_param_ != NULL) add_op_param_->::sita::AddOpParameter::Clear(); - clear_has_add_op_param(); +inline void OperatorParameter::clear_convolution_param() { + if (convolution_param_ != NULL) convolution_param_->::sita::ConvolutionParameter::Clear(); + clear_has_convolution_param(); } -inline const ::sita::AddOpParameter& OperatorParameter::add_op_param() const { - // @@protoc_insertion_point(field_get:sita.OperatorParameter.add_op_param) - return add_op_param_ != NULL ? *add_op_param_ : *default_instance_->add_op_param_; +inline const ::sita::ConvolutionParameter& OperatorParameter::convolution_param() const { + // @@protoc_insertion_point(field_get:sita.OperatorParameter.convolution_param) + return convolution_param_ != NULL ?
*convolution_param_ : *default_instance_->convolution_param_; } -inline ::sita::AddOpParameter* OperatorParameter::mutable_add_op_param() { - set_has_add_op_param(); - if (add_op_param_ == NULL) add_op_param_ = new ::sita::AddOpParameter; - // @@protoc_insertion_point(field_mutable:sita.OperatorParameter.add_op_param) - return add_op_param_; +inline ::sita::ConvolutionParameter* OperatorParameter::mutable_convolution_param() { + set_has_convolution_param(); + if (convolution_param_ == NULL) convolution_param_ = new ::sita::ConvolutionParameter; + // @@protoc_insertion_point(field_mutable:sita.OperatorParameter.convolution_param) + return convolution_param_; } -inline ::sita::AddOpParameter* OperatorParameter::release_add_op_param() { - clear_has_add_op_param(); - ::sita::AddOpParameter* temp = add_op_param_; - add_op_param_ = NULL; +inline ::sita::ConvolutionParameter* OperatorParameter::release_convolution_param() { + clear_has_convolution_param(); + ::sita::ConvolutionParameter* temp = convolution_param_; + convolution_param_ = NULL; return temp; } -inline void OperatorParameter::set_allocated_add_op_param(::sita::AddOpParameter* add_op_param) { - delete add_op_param_; - add_op_param_ = add_op_param; - if (add_op_param) { - set_has_add_op_param(); +inline void OperatorParameter::set_allocated_convolution_param(::sita::ConvolutionParameter* convolution_param) { + delete convolution_param_; + convolution_param_ = convolution_param; + if (convolution_param) { + set_has_convolution_param(); } else { - clear_has_add_op_param(); + clear_has_convolution_param(); } - // @@protoc_insertion_point(field_set_allocated:sita.OperatorParameter.add_op_param) -} - -// optional .sita.ConvolutionOpParameter convolution_op_param = 101; -inline bool OperatorParameter::has_convolution_op_param() const { - return (_has_bits_[0] & 0x00000080u) != 0; -} -inline void OperatorParameter::set_has_convolution_op_param() { - _has_bits_[0] |= 0x00000080u; -} -inline void OperatorParameter::clear_has_convolution_op_param() { - _has_bits_[0] &= ~0x00000080u; -} -inline void OperatorParameter::clear_convolution_op_param() { - if (convolution_op_param_ != NULL) convolution_op_param_->::sita::ConvolutionOpParameter::Clear(); - clear_has_convolution_op_param(); -} -inline const ::sita::ConvolutionOpParameter& OperatorParameter::convolution_op_param() const { - // @@protoc_insertion_point(field_get:sita.OperatorParameter.convolution_op_param) - return convolution_op_param_ != NULL ? 
*convolution_op_param_ : *default_instance_->convolution_op_param_; -} -inline ::sita::ConvolutionOpParameter* OperatorParameter::mutable_convolution_op_param() { - set_has_convolution_op_param(); - if (convolution_op_param_ == NULL) convolution_op_param_ = new ::sita::ConvolutionOpParameter; - // @@protoc_insertion_point(field_mutable:sita.OperatorParameter.convolution_op_param) - return convolution_op_param_; -} -inline ::sita::ConvolutionOpParameter* OperatorParameter::release_convolution_op_param() { - clear_has_convolution_op_param(); - ::sita::ConvolutionOpParameter* temp = convolution_op_param_; - convolution_op_param_ = NULL; - return temp; -} -inline void OperatorParameter::set_allocated_convolution_op_param(::sita::ConvolutionOpParameter* convolution_op_param) { - delete convolution_op_param_; - convolution_op_param_ = convolution_op_param; - if (convolution_op_param) { - set_has_convolution_op_param(); - } else { - clear_has_convolution_op_param(); - } - // @@protoc_insertion_point(field_set_allocated:sita.OperatorParameter.convolution_op_param) + // @@protoc_insertion_point(field_set_allocated:sita.OperatorParameter.convolution_param) } // ------------------------------------------------------------------- -// AddOpParameter - -// optional uint32 pad_h = 1 [default = 0]; -inline bool AddOpParameter::has_pad_h() const { - return (_has_bits_[0] & 0x00000001u) != 0; -} -inline void AddOpParameter::set_has_pad_h() { - _has_bits_[0] |= 0x00000001u; -} -inline void AddOpParameter::clear_has_pad_h() { - _has_bits_[0] &= ~0x00000001u; -} -inline void AddOpParameter::clear_pad_h() { - pad_h_ = 0u; - clear_has_pad_h(); -} -inline ::google::protobuf::uint32 AddOpParameter::pad_h() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.pad_h) - return pad_h_; -} -inline void AddOpParameter::set_pad_h(::google::protobuf::uint32 value) { - set_has_pad_h(); - pad_h_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.pad_h) -} - -// optional uint32 pad_w = 2 [default = 0]; -inline bool AddOpParameter::has_pad_w() const { - return (_has_bits_[0] & 0x00000002u) != 0; -} -inline void AddOpParameter::set_has_pad_w() { - _has_bits_[0] |= 0x00000002u; -} -inline void AddOpParameter::clear_has_pad_w() { - _has_bits_[0] &= ~0x00000002u; -} -inline void AddOpParameter::clear_pad_w() { - pad_w_ = 0u; - clear_has_pad_w(); -} -inline ::google::protobuf::uint32 AddOpParameter::pad_w() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.pad_w) - return pad_w_; -} -inline void AddOpParameter::set_pad_w(::google::protobuf::uint32 value) { - set_has_pad_w(); - pad_w_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.pad_w) -} - -// optional uint32 kernel_h = 3; -inline bool AddOpParameter::has_kernel_h() const { - return (_has_bits_[0] & 0x00000004u) != 0; -} -inline void AddOpParameter::set_has_kernel_h() { - _has_bits_[0] |= 0x00000004u; -} -inline void AddOpParameter::clear_has_kernel_h() { - _has_bits_[0] &= ~0x00000004u; -} -inline void AddOpParameter::clear_kernel_h() { - kernel_h_ = 0u; - clear_has_kernel_h(); -} -inline ::google::protobuf::uint32 AddOpParameter::kernel_h() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.kernel_h) - return kernel_h_; -} -inline void AddOpParameter::set_kernel_h(::google::protobuf::uint32 value) { - set_has_kernel_h(); - kernel_h_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.kernel_h) -} - -// optional uint32 kernel_w = 4; -inline bool 
AddOpParameter::has_kernel_w() const { - return (_has_bits_[0] & 0x00000008u) != 0; -} -inline void AddOpParameter::set_has_kernel_w() { - _has_bits_[0] |= 0x00000008u; -} -inline void AddOpParameter::clear_has_kernel_w() { - _has_bits_[0] &= ~0x00000008u; -} -inline void AddOpParameter::clear_kernel_w() { - kernel_w_ = 0u; - clear_has_kernel_w(); -} -inline ::google::protobuf::uint32 AddOpParameter::kernel_w() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.kernel_w) - return kernel_w_; -} -inline void AddOpParameter::set_kernel_w(::google::protobuf::uint32 value) { - set_has_kernel_w(); - kernel_w_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.kernel_w) -} - -// optional uint32 stride_h = 5; -inline bool AddOpParameter::has_stride_h() const { - return (_has_bits_[0] & 0x00000010u) != 0; -} -inline void AddOpParameter::set_has_stride_h() { - _has_bits_[0] |= 0x00000010u; -} -inline void AddOpParameter::clear_has_stride_h() { - _has_bits_[0] &= ~0x00000010u; -} -inline void AddOpParameter::clear_stride_h() { - stride_h_ = 0u; - clear_has_stride_h(); -} -inline ::google::protobuf::uint32 AddOpParameter::stride_h() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.stride_h) - return stride_h_; -} -inline void AddOpParameter::set_stride_h(::google::protobuf::uint32 value) { - set_has_stride_h(); - stride_h_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.stride_h) -} - -// optional uint32 stride_w = 6; -inline bool AddOpParameter::has_stride_w() const { - return (_has_bits_[0] & 0x00000020u) != 0; -} -inline void AddOpParameter::set_has_stride_w() { - _has_bits_[0] |= 0x00000020u; -} -inline void AddOpParameter::clear_has_stride_w() { - _has_bits_[0] &= ~0x00000020u; -} -inline void AddOpParameter::clear_stride_w() { - stride_w_ = 0u; - clear_has_stride_w(); -} -inline ::google::protobuf::uint32 AddOpParameter::stride_w() const { - // @@protoc_insertion_point(field_get:sita.AddOpParameter.stride_w) - return stride_w_; -} -inline void AddOpParameter::set_stride_w(::google::protobuf::uint32 value) { - set_has_stride_w(); - stride_w_ = value; - // @@protoc_insertion_point(field_set:sita.AddOpParameter.stride_w) -} - -// ------------------------------------------------------------------- - -// ConvolutionOpParameter +// ConvolutionParameter // optional uint32 num_output = 1; -inline bool ConvolutionOpParameter::has_num_output() const { +inline bool ConvolutionParameter::has_num_output() const { return (_has_bits_[0] & 0x00000001u) != 0; } -inline void ConvolutionOpParameter::set_has_num_output() { +inline void ConvolutionParameter::set_has_num_output() { _has_bits_[0] |= 0x00000001u; } -inline void ConvolutionOpParameter::clear_has_num_output() { +inline void ConvolutionParameter::clear_has_num_output() { _has_bits_[0] &= ~0x00000001u; } -inline void ConvolutionOpParameter::clear_num_output() { +inline void ConvolutionParameter::clear_num_output() { num_output_ = 0u; clear_has_num_output(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::num_output() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.num_output) +inline ::google::protobuf::uint32 ConvolutionParameter::num_output() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.num_output) return num_output_; } -inline void ConvolutionOpParameter::set_num_output(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_num_output(::google::protobuf::uint32 value) { set_has_num_output(); 
num_output_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.num_output) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.num_output) } // optional bool bias_term = 2 [default = true]; -inline bool ConvolutionOpParameter::has_bias_term() const { +inline bool ConvolutionParameter::has_bias_term() const { return (_has_bits_[0] & 0x00000002u) != 0; } -inline void ConvolutionOpParameter::set_has_bias_term() { +inline void ConvolutionParameter::set_has_bias_term() { _has_bits_[0] |= 0x00000002u; } -inline void ConvolutionOpParameter::clear_has_bias_term() { +inline void ConvolutionParameter::clear_has_bias_term() { _has_bits_[0] &= ~0x00000002u; } -inline void ConvolutionOpParameter::clear_bias_term() { +inline void ConvolutionParameter::clear_bias_term() { bias_term_ = true; clear_has_bias_term(); } -inline bool ConvolutionOpParameter::bias_term() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.bias_term) +inline bool ConvolutionParameter::bias_term() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.bias_term) return bias_term_; } -inline void ConvolutionOpParameter::set_bias_term(bool value) { +inline void ConvolutionParameter::set_bias_term(bool value) { set_has_bias_term(); bias_term_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.bias_term) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.bias_term) } -// optional uint32 pad = 3; -inline bool ConvolutionOpParameter::has_pad() const { +// optional uint32 pad = 3 [default = 1]; +inline bool ConvolutionParameter::has_pad() const { return (_has_bits_[0] & 0x00000004u) != 0; } -inline void ConvolutionOpParameter::set_has_pad() { +inline void ConvolutionParameter::set_has_pad() { _has_bits_[0] |= 0x00000004u; } -inline void ConvolutionOpParameter::clear_has_pad() { +inline void ConvolutionParameter::clear_has_pad() { _has_bits_[0] &= ~0x00000004u; } -inline void ConvolutionOpParameter::clear_pad() { - pad_ = 0u; +inline void ConvolutionParameter::clear_pad() { + pad_ = 1u; clear_has_pad(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::pad() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.pad) +inline ::google::protobuf::uint32 ConvolutionParameter::pad() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.pad) return pad_; } -inline void ConvolutionOpParameter::set_pad(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_pad(::google::protobuf::uint32 value) { set_has_pad(); pad_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.pad) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.pad) } -// optional uint32 kernel_size = 4; -inline bool ConvolutionOpParameter::has_kernel_size() const { +// optional uint32 kernel_size = 4 [default = 3]; +inline bool ConvolutionParameter::has_kernel_size() const { return (_has_bits_[0] & 0x00000008u) != 0; } -inline void ConvolutionOpParameter::set_has_kernel_size() { +inline void ConvolutionParameter::set_has_kernel_size() { _has_bits_[0] |= 0x00000008u; } -inline void ConvolutionOpParameter::clear_has_kernel_size() { +inline void ConvolutionParameter::clear_has_kernel_size() { _has_bits_[0] &= ~0x00000008u; } -inline void ConvolutionOpParameter::clear_kernel_size() { - kernel_size_ = 0u; +inline void ConvolutionParameter::clear_kernel_size() { + kernel_size_ = 3u; clear_has_kernel_size(); } -inline ::google::protobuf::uint32 
ConvolutionOpParameter::kernel_size() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.kernel_size) +inline ::google::protobuf::uint32 ConvolutionParameter::kernel_size() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.kernel_size) return kernel_size_; } -inline void ConvolutionOpParameter::set_kernel_size(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_kernel_size(::google::protobuf::uint32 value) { set_has_kernel_size(); kernel_size_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.kernel_size) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.kernel_size) } -// optional uint32 stride = 5; -inline bool ConvolutionOpParameter::has_stride() const { +// optional uint32 stride = 5 [default = 1]; +inline bool ConvolutionParameter::has_stride() const { return (_has_bits_[0] & 0x00000010u) != 0; } -inline void ConvolutionOpParameter::set_has_stride() { +inline void ConvolutionParameter::set_has_stride() { _has_bits_[0] |= 0x00000010u; } -inline void ConvolutionOpParameter::clear_has_stride() { +inline void ConvolutionParameter::clear_has_stride() { _has_bits_[0] &= ~0x00000010u; } -inline void ConvolutionOpParameter::clear_stride() { - stride_ = 0u; +inline void ConvolutionParameter::clear_stride() { + stride_ = 1u; clear_has_stride(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::stride() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.stride) +inline ::google::protobuf::uint32 ConvolutionParameter::stride() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.stride) return stride_; } -inline void ConvolutionOpParameter::set_stride(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_stride(::google::protobuf::uint32 value) { set_has_stride(); stride_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.stride) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.stride) } -// optional uint32 dilation = 6 [default = 1]; -inline bool ConvolutionOpParameter::has_dilation() const { +// optional uint32 pad_h = 6; +inline bool ConvolutionParameter::has_pad_h() const { return (_has_bits_[0] & 0x00000020u) != 0; } -inline void ConvolutionOpParameter::set_has_dilation() { +inline void ConvolutionParameter::set_has_pad_h() { _has_bits_[0] |= 0x00000020u; } -inline void ConvolutionOpParameter::clear_has_dilation() { +inline void ConvolutionParameter::clear_has_pad_h() { _has_bits_[0] &= ~0x00000020u; } -inline void ConvolutionOpParameter::clear_dilation() { - dilation_ = 1u; - clear_has_dilation(); -} -inline ::google::protobuf::uint32 ConvolutionOpParameter::dilation() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.dilation) - return dilation_; -} -inline void ConvolutionOpParameter::set_dilation(::google::protobuf::uint32 value) { - set_has_dilation(); - dilation_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.dilation) -} - -// optional uint32 pad_h = 7 [default = 0]; -inline bool ConvolutionOpParameter::has_pad_h() const { - return (_has_bits_[0] & 0x00000040u) != 0; -} -inline void ConvolutionOpParameter::set_has_pad_h() { - _has_bits_[0] |= 0x00000040u; -} -inline void ConvolutionOpParameter::clear_has_pad_h() { - _has_bits_[0] &= ~0x00000040u; -} -inline void ConvolutionOpParameter::clear_pad_h() { +inline void ConvolutionParameter::clear_pad_h() { pad_h_ = 0u; clear_has_pad_h(); } -inline 
::google::protobuf::uint32 ConvolutionOpParameter::pad_h() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.pad_h) +inline ::google::protobuf::uint32 ConvolutionParameter::pad_h() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.pad_h) return pad_h_; } -inline void ConvolutionOpParameter::set_pad_h(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_pad_h(::google::protobuf::uint32 value) { set_has_pad_h(); pad_h_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.pad_h) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.pad_h) } -// optional uint32 pad_w = 8 [default = 0]; -inline bool ConvolutionOpParameter::has_pad_w() const { - return (_has_bits_[0] & 0x00000080u) != 0; +// optional uint32 pad_w = 7; +inline bool ConvolutionParameter::has_pad_w() const { + return (_has_bits_[0] & 0x00000040u) != 0; } -inline void ConvolutionOpParameter::set_has_pad_w() { - _has_bits_[0] |= 0x00000080u; +inline void ConvolutionParameter::set_has_pad_w() { + _has_bits_[0] |= 0x00000040u; } -inline void ConvolutionOpParameter::clear_has_pad_w() { - _has_bits_[0] &= ~0x00000080u; +inline void ConvolutionParameter::clear_has_pad_w() { + _has_bits_[0] &= ~0x00000040u; } -inline void ConvolutionOpParameter::clear_pad_w() { +inline void ConvolutionParameter::clear_pad_w() { pad_w_ = 0u; clear_has_pad_w(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::pad_w() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.pad_w) +inline ::google::protobuf::uint32 ConvolutionParameter::pad_w() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.pad_w) return pad_w_; } -inline void ConvolutionOpParameter::set_pad_w(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_pad_w(::google::protobuf::uint32 value) { set_has_pad_w(); pad_w_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.pad_w) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.pad_w) } -// optional uint32 kernel_h = 9 [default = 3]; -inline bool ConvolutionOpParameter::has_kernel_h() const { - return (_has_bits_[0] & 0x00000100u) != 0; +// optional uint32 kernel_h = 8; +inline bool ConvolutionParameter::has_kernel_h() const { + return (_has_bits_[0] & 0x00000080u) != 0; } -inline void ConvolutionOpParameter::set_has_kernel_h() { - _has_bits_[0] |= 0x00000100u; +inline void ConvolutionParameter::set_has_kernel_h() { + _has_bits_[0] |= 0x00000080u; } -inline void ConvolutionOpParameter::clear_has_kernel_h() { - _has_bits_[0] &= ~0x00000100u; +inline void ConvolutionParameter::clear_has_kernel_h() { + _has_bits_[0] &= ~0x00000080u; } -inline void ConvolutionOpParameter::clear_kernel_h() { - kernel_h_ = 3u; +inline void ConvolutionParameter::clear_kernel_h() { + kernel_h_ = 0u; clear_has_kernel_h(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::kernel_h() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.kernel_h) +inline ::google::protobuf::uint32 ConvolutionParameter::kernel_h() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.kernel_h) return kernel_h_; } -inline void ConvolutionOpParameter::set_kernel_h(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_kernel_h(::google::protobuf::uint32 value) { set_has_kernel_h(); kernel_h_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.kernel_h) + // 
@@protoc_insertion_point(field_set:sita.ConvolutionParameter.kernel_h) } -// optional uint32 kernel_w = 10 [default = 3]; -inline bool ConvolutionOpParameter::has_kernel_w() const { - return (_has_bits_[0] & 0x00000200u) != 0; +// optional uint32 kernel_w = 9; +inline bool ConvolutionParameter::has_kernel_w() const { + return (_has_bits_[0] & 0x00000100u) != 0; } -inline void ConvolutionOpParameter::set_has_kernel_w() { - _has_bits_[0] |= 0x00000200u; +inline void ConvolutionParameter::set_has_kernel_w() { + _has_bits_[0] |= 0x00000100u; } -inline void ConvolutionOpParameter::clear_has_kernel_w() { - _has_bits_[0] &= ~0x00000200u; +inline void ConvolutionParameter::clear_has_kernel_w() { + _has_bits_[0] &= ~0x00000100u; } -inline void ConvolutionOpParameter::clear_kernel_w() { - kernel_w_ = 3u; +inline void ConvolutionParameter::clear_kernel_w() { + kernel_w_ = 0u; clear_has_kernel_w(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::kernel_w() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.kernel_w) +inline ::google::protobuf::uint32 ConvolutionParameter::kernel_w() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.kernel_w) return kernel_w_; } -inline void ConvolutionOpParameter::set_kernel_w(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_kernel_w(::google::protobuf::uint32 value) { set_has_kernel_w(); kernel_w_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.kernel_w) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.kernel_w) } -// optional uint32 stride_h = 11; -inline bool ConvolutionOpParameter::has_stride_h() const { - return (_has_bits_[0] & 0x00000400u) != 0; +// optional uint32 stride_h = 10; +inline bool ConvolutionParameter::has_stride_h() const { + return (_has_bits_[0] & 0x00000200u) != 0; } -inline void ConvolutionOpParameter::set_has_stride_h() { - _has_bits_[0] |= 0x00000400u; +inline void ConvolutionParameter::set_has_stride_h() { + _has_bits_[0] |= 0x00000200u; } -inline void ConvolutionOpParameter::clear_has_stride_h() { - _has_bits_[0] &= ~0x00000400u; +inline void ConvolutionParameter::clear_has_stride_h() { + _has_bits_[0] &= ~0x00000200u; } -inline void ConvolutionOpParameter::clear_stride_h() { +inline void ConvolutionParameter::clear_stride_h() { stride_h_ = 0u; clear_has_stride_h(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::stride_h() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.stride_h) +inline ::google::protobuf::uint32 ConvolutionParameter::stride_h() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.stride_h) return stride_h_; } -inline void ConvolutionOpParameter::set_stride_h(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_stride_h(::google::protobuf::uint32 value) { set_has_stride_h(); stride_h_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.stride_h) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.stride_h) } -// optional uint32 stride_w = 12; -inline bool ConvolutionOpParameter::has_stride_w() const { - return (_has_bits_[0] & 0x00000800u) != 0; +// optional uint32 stride_w = 11; +inline bool ConvolutionParameter::has_stride_w() const { + return (_has_bits_[0] & 0x00000400u) != 0; } -inline void ConvolutionOpParameter::set_has_stride_w() { - _has_bits_[0] |= 0x00000800u; +inline void ConvolutionParameter::set_has_stride_w() { + _has_bits_[0] |= 0x00000400u; } -inline 
void ConvolutionOpParameter::clear_has_stride_w() { - _has_bits_[0] &= ~0x00000800u; +inline void ConvolutionParameter::clear_has_stride_w() { + _has_bits_[0] &= ~0x00000400u; } -inline void ConvolutionOpParameter::clear_stride_w() { +inline void ConvolutionParameter::clear_stride_w() { stride_w_ = 0u; clear_has_stride_w(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::stride_w() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.stride_w) +inline ::google::protobuf::uint32 ConvolutionParameter::stride_w() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.stride_w) return stride_w_; } -inline void ConvolutionOpParameter::set_stride_w(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_stride_w(::google::protobuf::uint32 value) { set_has_stride_w(); stride_w_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.stride_w) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.stride_w) } -// optional uint32 group = 13 [default = 1]; -inline bool ConvolutionOpParameter::has_group() const { - return (_has_bits_[0] & 0x00001000u) != 0; +// optional uint32 group = 12 [default = 1]; +inline bool ConvolutionParameter::has_group() const { + return (_has_bits_[0] & 0x00000800u) != 0; } -inline void ConvolutionOpParameter::set_has_group() { - _has_bits_[0] |= 0x00001000u; +inline void ConvolutionParameter::set_has_group() { + _has_bits_[0] |= 0x00000800u; } -inline void ConvolutionOpParameter::clear_has_group() { - _has_bits_[0] &= ~0x00001000u; +inline void ConvolutionParameter::clear_has_group() { + _has_bits_[0] &= ~0x00000800u; } -inline void ConvolutionOpParameter::clear_group() { +inline void ConvolutionParameter::clear_group() { group_ = 1u; clear_has_group(); } -inline ::google::protobuf::uint32 ConvolutionOpParameter::group() const { - // @@protoc_insertion_point(field_get:sita.ConvolutionOpParameter.group) +inline ::google::protobuf::uint32 ConvolutionParameter::group() const { + // @@protoc_insertion_point(field_get:sita.ConvolutionParameter.group) return group_; } -inline void ConvolutionOpParameter::set_group(::google::protobuf::uint32 value) { +inline void ConvolutionParameter::set_group(::google::protobuf::uint32 value) { set_has_group(); group_ = value; - // @@protoc_insertion_point(field_set:sita.ConvolutionOpParameter.group) + // @@protoc_insertion_point(field_set:sita.ConvolutionParameter.group) } diff --git a/src/sita/proto/sita_operators.proto b/src/sita/proto/sita_operators.proto index 6b324f25..d86e2420 100644 --- a/src/sita/proto/sita_operators.proto +++ b/src/sita/proto/sita_operators.proto @@ -12,39 +12,27 @@ message OperatorParameter { optional string type = 2; // the operator type repeated string input = 3; // the name of each input repeated string output = 4; // the name of each output - repeated float loss_weight = 5; + optional bool gradient_block = 5 [default = false]; repeated ParamConfig param = 6; //param config //operators - optional AddOpParameter add_op_param = 100; - optional ConvolutionOpParameter convolution_op_param = 101; + optional ConvolutionParameter convolution_param = 101; } -message AddOpParameter{ - optional uint32 pad_h = 1 [default = 0]; // The padding height (2D only) - optional uint32 pad_w = 2 [default = 0]; // The padding width (2D only) - optional uint32 kernel_h = 3; // The kernel height (2D only) - optional uint32 kernel_w = 4; // The kernel width (2D only) - optional uint32 stride_h = 5; // The stride height (2D only) - 
optional uint32 stride_w = 6; // The stride width (2D only)
-}
-message ConvolutionOpParameter{
+message ConvolutionParameter{
 optional uint32 num_output = 1; // The number of outputs
 optional bool bias_term = 2 [default = true]; // whether to have bias terms
- optional uint32 pad = 3; // The padding size; defaults to 0
- optional uint32 kernel_size = 4; // The kernel size
- optional uint32 stride = 5; // The stride; defaults to 1
- optional uint32 dilation = 6 [default = 1]; // The dilation; defaults to 1
-
- optional uint32 pad_h = 7 [default = 0]; // The padding height (2D only)
- optional uint32 pad_w = 8 [default = 0]; // The padding width (2D only)
- optional uint32 kernel_h = 9 [default = 3]; // The kernel height (2D only)
- optional uint32 kernel_w = 10 [default = 3]; // The kernel width (2D only)
- optional uint32 stride_h = 11; // The stride height (2D only)
- optional uint32 stride_w = 12; // The stride width (2D only)
- optional uint32 group = 13 [default = 1]; // The group size for group conv
-
-
+ optional uint32 pad = 3 [default = 1]; // The padding size; defaults to 1
+ optional uint32 kernel_size = 4 [default = 3]; // The kernel size; defaults to 3
+ optional uint32 stride = 5 [default = 1]; // The stride; defaults to 1
+
+ optional uint32 pad_h = 6; // The padding height (2D only)
+ optional uint32 pad_w = 7; // The padding width (2D only)
+ optional uint32 kernel_h = 8; // The kernel height (2D only)
+ optional uint32 kernel_w = 9; // The kernel width (2D only)
+ optional uint32 stride_h = 10; // The stride height (2D only)
+ optional uint32 stride_w = 11; // The stride width (2D only)
+ optional uint32 group = 12 [default = 1]; // The group size for group conv
 }
diff --git a/src/sita/workspace.cpp b/src/sita/workspace.cpp
index 07d2454e..63909b21 100644
--- a/src/sita/workspace.cpp
+++ b/src/sita/workspace.cpp
@@ -276,7 +276,6 @@ void GlobalWorkSpace::global_init(Graph * graph, DataProvider * da
 boost::shared_ptr > op = OperatorRegistry::CreateOperator(opdef, gws);
 op->setup();
 op->init();
- op->infer_shape();
 _ops.push_back(op);
 }
 }
@@ -288,12 +287,6 @@ void GlobalWorkSpace::forward(){
 }
 }
-template
-void GlobalWorkSpace::infer_shape(){
- for(int i = 0; i < _ops.size(); i++){
- _ops[i]->infer_shape();
- }
-}
 template
 void GlobalWorkSpace::backward(){
@@ -310,8 +303,13 @@ void GlobalWorkSpace::train(){
 for(int i = 0; i < batch->product_size();i ++){
 fetch_output(batch->product_name(i))->copy_from(batch->product(i), true);
 }
- forward();
- backward();
+ LOG(INFO) << "--------------------------begin solve---------------------------";
+ int k = 0;
+ while(k != 1000){
+ forward();
+ backward();
+ k++;
+ }
 }
diff --git a/test.prototxt b/test.prototxt
index 58ae3511..6bab8ce8 100644
--- a/test.prototxt
+++ b/test.prototxt
@@ -2,13 +2,13 @@ name: "LeNet"
 operatordef {
 name: "conv1"
- type: "ConvolutionOp"
+ type: "Convolution"
 input: "data"
 output: "conv1"
- convolution_op_param{
- num_output: 20
- kernel_size: 5
- stride: 1
+ convolution_param{
+ num_output: 10
+ kernel_size: 3
+ stride: 2
 }
 param{
 filler{type:"xavier"}
@@ -21,13 +21,13 @@
 operatordef {
 name: "conv2"
- type: "ConvolutionOp"
+ type: "Convolution"
 input: "conv1"
 output: "conv2"
- convolution_op_param{
+ convolution_param{
 num_output: 20
- kernel_size: 5
- stride: 1
+ kernel_size: 3
+ stride: 2
 }
 param{
 filler{type:"xavier"}
@@ -39,13 +39,13 @@ operatordef {
 }
 operatordef {
 name: "conv3"
- type: "ConvolutionOp"
+ type: "Convolution"
 input: "conv2"
 output: "conv3"
- convolution_op_param{
- num_output: 20
- kernel_size: 5
- stride: 1
+ convolution_param{
+ num_output: 25
+ kernel_size: 3
+ stride: 2
 }
 param{
 filler{type:"xavier"}
diff --git a/tools/main.cpp b/tools/sita.cpp
similarity index 94%
rename from tools/main.cpp
rename to tools/sita.cpp
index 6fff3ec7..27b534ac 100644
--- a/tools/main.cpp
+++ b/tools/sita.cpp
@@ -24,13 +24,7 @@ int main(int argc, char** argv) {
 gws.global_init(&graph, &mnistdp);
- int k = 0;
-
- while(k != 10000000){
- gws.train();
- k++;
- LOG(INFO)<
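
The regenerated accessors above follow the usual proto2 pattern: mutable_*() lazily allocates the submessage and flips its has-bit, release_*() transfers ownership, and unset optional fields report their [default = ...] values. A minimal C++ sketch of the renamed API follows; the standalone main() and the printout are illustrative assumptions, not part of this patch, and the include path is assumed from the generated header location in src/sita/proto/.

#include <iostream>
#include "sita/proto/sita_operators.pb.h"

int main() {
    sita::OperatorParameter opdef;
    opdef.set_name("conv1");
    opdef.set_type("Convolution");

    // mutable_convolution_param() allocates the submessage on first use
    // and sets its has-bit, mirroring the generated code above.
    sita::ConvolutionParameter* conv = opdef.mutable_convolution_param();
    conv->set_num_output(10);

    // Unset optional fields fall back to their [default = ...] values:
    // pad -> 1, kernel_size -> 3, stride -> 1, group -> 1.
    std::cout << "pad=" << conv->pad()
              << " kernel_size=" << conv->kernel_size()
              << " stride=" << conv->stride()
              << " group=" << conv->group() << std::endl;

    // The new gradient_block field defaults to false.
    std::cout << "gradient_block=" << opdef.gradient_block() << std::endl;
    return 0;
}

Note that the 2D overrides (pad_h/pad_w, kernel_h/kernel_w, stride_h/stride_w) no longer carry defaults in the new schema, so an operator can distinguish unset values via has_pad_h() and friends and fall back to the square pad/kernel_size/stride settings.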