-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
1,187 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
// | ||
// Created by unsky on 15/08/18. | ||
// | ||
|
||
#ifndef SITA_DLFLOW_POOLING_H | ||
#define SITA_DLFLOW_POOLING_H | ||
#include "sita/operator.h" | ||
#include "sita/proto/sita.h" | ||
namespace sita{ | ||
template<typename Dtype> | ||
class Pooling: public Operator<Dtype>{ | ||
public: | ||
Pooling(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws, std::string phase):Operator<Dtype>(opdef, gws, phase){ | ||
_op_param = opdef.pooling_param(); | ||
_handles_setup = false; | ||
} | ||
~Pooling(); | ||
void init(); | ||
void infer_shape(); | ||
void forward(); | ||
void backward(); | ||
|
||
bool inline has_param(){ return _has_param;} | ||
|
||
protected: | ||
bool _has_param = false; | ||
PoolingParameter _op_param; | ||
|
||
private: | ||
cudnnHandle_t* _handle; | ||
cudaStream_t* _stream; | ||
// cuDNN descriptors / handles | ||
cudnnTensorDescriptor_t _input_desc, _output_desc; | ||
cudnnPoolingDescriptor_t _pooling_desc; | ||
cudnnPoolingMode_t _mode; | ||
bool _handles_setup; | ||
}; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
#include "sita/operators/pooling.h" | ||
|
||
namespace sita{ | ||
template <typename Dtype> | ||
Pooling<Dtype>::~Pooling() { | ||
if (!_handles_setup) return; | ||
cudnnDestroyTensorDescriptor(_input_desc); | ||
cudnnDestroyTensorDescriptor(_output_desc); | ||
cudnnDestroyPoolingDescriptor(_pooling_desc); | ||
} | ||
|
||
template<typename Dtype> | ||
void Pooling<Dtype>::init(){ | ||
Context::create_tensor4d_descriptor<Dtype>(&_input_desc); | ||
Context::create_tensor4d_descriptor<Dtype>(&_output_desc); | ||
|
||
LOG(INFO) << "Params:"; | ||
int kernel_h; | ||
int kernel_w; | ||
if(this->_op_param.has_kernel_h() && this->_op_param.has_kernel_w()){ | ||
kernel_h = this->_op_param.kernel_h(); | ||
kernel_w = this->_op_param.kernel_w(); | ||
}else{ | ||
kernel_h = this->_op_param.kernel_size(); | ||
kernel_w = this->_op_param.kernel_size(); | ||
} | ||
|
||
Context::create_tensor4d_descriptor<Dtype>(&_pooling_desc, | ||
this->layer_param_.pooling_param().pool(), &_mode, | ||
this->_kernel_h, this->_kernel_w, this->_pad_h, this->_pad_w, | ||
this->_stride_h, this->_stride_w); | ||
handles_setup_ = true; | ||
} | ||
|
||
template<typename Dtype> | ||
void Pooling<Dtype>::infer_shape(){ | ||
|
||
} | ||
INSTANTIATE_CLASS(Pooling); | ||
REGISTER_OPERATOR_CLASS(Pooling); | ||
} |
Empty file.
Oops, something went wrong.