Skip to content

Commit

Permalink
pooling support
Browse files Browse the repository at this point in the history
  • Loading branch information
unsky committed Nov 6, 2018
1 parent bd52b2a commit e61b173
Show file tree
Hide file tree
Showing 9 changed files with 1,187 additions and 26 deletions.
17 changes: 9 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,16 @@ set(SITA_FILES
src/sita/memory_control.cpp
src/sita/tensor.cpp
src/sita/context.cpp
src/sita/dlflow/operators/convolution.cpp
src/sita/dlflow/operators/convolution.cu
src/sita/dlflow/operators/batch_norm.cpp
src/sita/dlflow/operators/batch_norm.cu
src/sita/dlflow/operators/relu.cpp
src/sita/dlflow/operators/relu.cu
src/sita/operators/convolution.cpp
src/sita/operators/convolution.cu
src/sita/operators/batch_norm.cpp
src/sita/operators/batch_norm.cu
src/sita/operators/relu.cpp
src/sita/operators/relu.cu
# src/sita/operators/pooling.cpp
src/sita/workspace.cpp
src/sita/dlflow/operator.cpp
src/sita/dlflow/graph.cpp
src/sita/operator.cpp
src/sita/graph.cpp
src/sita/io_protobuff.cpp
src/sita/internal_thread.cpp
src/sita/blocking_queue.cpp
Expand Down
39 changes: 39 additions & 0 deletions include/sita/operators/pooling.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
//
// Created by unsky on 15/08/18.
//

#ifndef SITA_DLFLOW_POOLING_H
#define SITA_DLFLOW_POOLING_H
#include "sita/operator.h"
#include "sita/proto/sita.h"
namespace sita{
template<typename Dtype>
class Pooling: public Operator<Dtype>{
public:
Pooling(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws, std::string phase):Operator<Dtype>(opdef, gws, phase){
_op_param = opdef.pooling_param();
_handles_setup = false;
}
~Pooling();
void init();
void infer_shape();
void forward();
void backward();

bool inline has_param(){ return _has_param;}

protected:
bool _has_param = false;
PoolingParameter _op_param;

private:
cudnnHandle_t* _handle;
cudaStream_t* _stream;
// cuDNN descriptors / handles
cudnnTensorDescriptor_t _input_desc, _output_desc;
cudnnPoolingDescriptor_t _pooling_desc;
cudnnPoolingMode_t _mode;
bool _handles_setup;
};

}
4 changes: 2 additions & 2 deletions include/sita/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
#include "types.h"
#include "sita/dataprovider/mnist_dataprovider.h"
#include "sita/dataprovider/dataprovider.h"
#include "sita/dlflow/registry.h"
#include "sita/dlflow/operator.h"
#include "sita/registry.h"
#include "sita/operator.h"
namespace sita{


Expand Down
41 changes: 41 additions & 0 deletions src/sita/operators/pooling.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "sita/operators/pooling.h"

namespace sita{
template <typename Dtype>
Pooling<Dtype>::~Pooling() {
if (!_handles_setup) return;
cudnnDestroyTensorDescriptor(_input_desc);
cudnnDestroyTensorDescriptor(_output_desc);
cudnnDestroyPoolingDescriptor(_pooling_desc);
}

template<typename Dtype>
void Pooling<Dtype>::init(){
Context::create_tensor4d_descriptor<Dtype>(&_input_desc);
Context::create_tensor4d_descriptor<Dtype>(&_output_desc);

LOG(INFO) << "Params:";
int kernel_h;
int kernel_w;
if(this->_op_param.has_kernel_h() && this->_op_param.has_kernel_w()){
kernel_h = this->_op_param.kernel_h();
kernel_w = this->_op_param.kernel_w();
}else{
kernel_h = this->_op_param.kernel_size();
kernel_w = this->_op_param.kernel_size();
}

Context::create_tensor4d_descriptor<Dtype>(&_pooling_desc,
this->layer_param_.pooling_param().pool(), &_mode,
this->_kernel_h, this->_kernel_w, this->_pad_h, this->_pad_w,
this->_stride_h, this->_stride_w);
handles_setup_ = true;
}

template<typename Dtype>
void Pooling<Dtype>::infer_shape(){

}
INSTANTIATE_CLASS(Pooling);
REGISTER_OPERATOR_CLASS(Pooling);
}
Empty file added src/sita/operators/pooling.cu
Empty file.
Loading

0 comments on commit e61b173

Please sign in to comment.