Skip to content

Commit

Permalink
cudnn convolution support
Browse files Browse the repository at this point in the history
  • Loading branch information
unsky committed Sep 25, 2018
1 parent 3b5e94c commit ec70591
Show file tree
Hide file tree
Showing 22 changed files with 1,005 additions and 1,725 deletions.
7 changes: 4 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,12 @@ execute_process(COMMAND protoc --cpp_out=${PROJECT_SOURCE_DIR}/src/sita/proto/
set(SITA_FILES
src/sita/proto/sita_operators.pb.cc
src/sita/proto/sita_utils.pb.cc
tools/main.cpp
tools/sita.cpp
src/sita/memory_control.cpp
src/sita/tensor.cpp
# src/sita/dlflow/operators/add_op.cpp
src/sita/dlflow/operators/convolution_op.cpp
src/sita/context.cpp
src/sita/dlflow/operators/convolution.cpp
src/sita/dlflow/operators/convolution.cu
src/sita/workspace.cpp
src/sita/dlflow/operator.cpp
src/sita/dlflow/graph.cpp
Expand Down
28 changes: 12 additions & 16 deletions include/sita/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,21 @@
namespace sita{


template <typename Dtype> class dataType;
template<> class dataType<float> {
template <typename Dtype> class CudnnDataType;
template<> class CudnnDataType<float> {
public:
static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
static float oneval = 1.0;
static float zeroval = 0.0;
static const void *one = static_cast<void *>(&dataType<float>::oneval);
static const void *zero = static_cast<void *>(&dataType<float>::zeroval);
static float oneval, zeroval;
static const void *one, *zero;
};
template<> class dataType<double> {
template<> class CudnnDataType<double> {
public:
static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
static double oneval = 1.0;
static double zeroval = 0.0;
static const void *one = static_cast<void *>(&dataType<double>::oneval);
static const void *zero = static_cast<void *>(&dataType<double>::zeroval);
static double oneval, zeroval;
static const void *one, *zero;
};

class Context{
class Context{
public:
Context() {}
~Context() {}
Expand Down Expand Up @@ -111,7 +107,7 @@ template<> class dataType<double> {
inline static void set_tensor4d_descriptor(cudnnTensorDescriptor_t* desc,
int n, int c, int h, int w,
int stride_n, int stride_c, int stride_h, int stride_w) {
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, dataType<Dtype>::type,
CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(*desc, CudnnDataType<Dtype>::type,
n, c, h, w, stride_n, stride_c, stride_h, stride_w));
}

Expand All @@ -132,10 +128,10 @@ template<> class dataType<double> {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));

#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, CudnnDataType<Dtype>::type,
CUDNN_TENSOR_NCHW, n, c, h, w));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, CudnnDataType<Dtype>::type,
CUDNN_TENSOR_NCHW, n, c, h, w));
#endif
}
Expand All @@ -152,7 +148,7 @@ template<> class dataType<double> {
#if CUDNN_VERSION_MIN(6, 0, 0)
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION,
dataType<Dtype>::type));
CudnnDataType<Dtype>::type));
#else
CUDNN_CHECK(cudnnSetConvolution2dDescriptor(*conv,
pad_h, pad_w, stride_h, stride_w, 1, 1, CUDNN_CROSS_CORRELATION));
Expand Down
2 changes: 2 additions & 0 deletions include/sita/dlflow/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ class Operator{
std::vector<std::string> _params;
std::map<std::string, std::vector<int> > _input_shapes;
std::map<std::string, std::vector<int> > _output_shapes;
std::map<std::string, std::vector<int> > _param_shapes;
bool _is_shared;
std::vector<std::pair<std::string, std::string> > _shared_param_pairs;
bool _gradient_block;
};

}//namespace
Expand Down
31 changes: 0 additions & 31 deletions include/sita/dlflow/operators/add_op.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,29 @@
// Created by unsky on 15/08/18.
//

#ifndef SITA_DLFLOW_CONVOLUTION_OP_H
#define SITA_DLFLOW_CONVOLUTION_OP_H
#ifndef SITA_DLFLOW_CONVOLUTION_H
#define SITA_DLFLOW_CONVOLUTION_H
#include "sita/dlflow/operator.h"
#include "sita/proto/sita.h"
namespace sita{

template<typename Dtype>
class ConvolutionOp: public Operator<Dtype>{
class Convolution: public Operator<Dtype>{
public:
ConvolutionOp(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
_op_param = opdef.convolution_op_param();
Convolution(const OperatorParameter& opdef, GlobalWorkSpace<Dtype> *gws):Operator<Dtype>(opdef,gws){
_op_param = opdef.convolution_param();
}
~ConvolutionOp(){};
~Convolution();
void init();
void infer_shape();
void forward();
void backward(){};
void backward();

bool inline has_param(){ return _has_param;}

protected:
bool _has_param = true;
ConvolutionOpParameter _op_param;
ConvolutionParameter _op_param;

private:
bool _handles_setup;
Expand Down Expand Up @@ -54,4 +54,4 @@ class ConvolutionOp: public Operator<Dtype>{

};
}
#endif //SITA_DLFLOW_CONVOLUTION_OP_H
#endif //SITA_DLFLOW_CONVOLUTION_H
40 changes: 0 additions & 40 deletions include/sita/dlflow/operators/data_test_op.h

This file was deleted.

12 changes: 12 additions & 0 deletions include/sita/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,5 +80,17 @@ private:\
template class classname<double>; \



#define INSTANTIATE_OPERATOR_GPU_FORWARD(classname) \
template void classname<float>::forward(); \
template void classname<double>::forward();

#define INSTANTIATE_OPERATOR_GPU_BACKWARD(classname) \
template void classname<float>::backward(); \
template void classname<double>::backward()

#define INSTANTIATE_OPERATOR_GPU_FUNCS(classname) \
INSTANTIATE_OPERATOR_GPU_FORWARD(classname); \
INSTANTIATE_OPERATOR_GPU_BACKWARD(classname)
}//namespace
#endif //SITA_MACROS_H
3 changes: 3 additions & 0 deletions include/sita/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ class Tensor {
inline const int dim() const {
return _dim;
}
inline const int size(){
return _count * sizeof(Dtype);
}

inline const int count() const {
return _count;
Expand Down
1 change: 0 additions & 1 deletion include/sita/workspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class GlobalWorkSpace : public WorkSpace{
}

void global_init(Graph * graph, DataProvider<Dtype> * data_provider);
void infer_shape();
void forward();
void backward();
void train();
Expand Down
19 changes: 19 additions & 0 deletions src/sita/context.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include "sita/context.h"

namespace sita{

float CudnnDataType<float>::oneval = 1.0;
float CudnnDataType<float>::zeroval = 0.0;
const void* CudnnDataType<float>::one =
static_cast<void *>(&CudnnDataType<float>::oneval);
const void* CudnnDataType<float>::zero =
static_cast<void *>(&CudnnDataType<float>::zeroval);

double CudnnDataType<double>::oneval = 1.0;
double CudnnDataType<double>::zeroval = 0.0;
const void* CudnnDataType<double>::one =
static_cast<void *>(&CudnnDataType<double>::oneval);
const void* CudnnDataType<double>::zero =
static_cast<void *>(&CudnnDataType<double>::zeroval);

}
2 changes: 1 addition & 1 deletion src/sita/dlflow/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ void Operator<Dtype>::setup(){
_param_configs.push_back(_opdef.param(i));
}
_is_shared = false;
_gradient_block = _opdef.gradient_block();
_shared_param_pairs.clear();

}

template<typename Dtype>
Expand Down
30 changes: 0 additions & 30 deletions src/sita/dlflow/operators/add_op.cpp

This file was deleted.

Loading

0 comments on commit ec70591

Please sign in to comment.