Support projection feature for LSTM on CPU (Only Inference) (apache#17702)

* Support projection feature for LSTM on CPU

* test solution for -Werror=maybe-uninitialized

* Check device type when creating state

* Document the projection feature of LSTM for RNN operator

* Minor fix

* Re-run CI
zixuanweeei committed Apr 13, 2020
1 parent db93398 commit d25929b
Showing 7 changed files with 156 additions and 47 deletions.
8 changes: 0 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRnn(const NDArray &input) {
if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

static inline bool SupportMKLDNNQuantize(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16;
12 changes: 12 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -441,6 +441,18 @@ class MKLDNNRnnOp {
const std::vector<NDArray> &outputs);
};

inline bool SupportMKLDNNRnn(const int input_dtype) {
if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
if (param.projection_size.has_value()) return false;
return SupportMKLDNNRnn(input_dtype);
}

} // namespace op
} // namespace mxnet

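A standalone sketch of the rule these two overloads encode: float32 input only, MXNET_USE_MKLDNN_RNN=0 as a global kill switch, and no MKL-DNN path for LSTM with projection. The stand-in types and names below are illustrative, not MXNet's:

// sketch_support_mkldnn_rnn.cc -- stand-ins only, not MXNet code.
#include <cstdlib>
#include <iostream>

enum DTypeFlag { kFloat32 = 0, kFloat16 = 2 };

struct FakeRNNParam {
  bool has_projection;  // stands in for param.projection_size.has_value()
};

bool SupportRnn(int input_dtype) {
  // Only float32 input is accepted, and MXNET_USE_MKLDNN_RNN (default 1) can
  // turn the MKL-DNN RNN path off entirely.
  const char* env = std::getenv("MXNET_USE_MKLDNN_RNN");
  const bool enabled = (env == nullptr) || (std::atoi(env) != 0);
  return input_dtype == kFloat32 && enabled;
}

bool SupportRnn(const FakeRNNParam& param, int input_dtype) {
  // LSTM with a projection layer is not handled by MKL-DNN here, so it falls
  // back to the native CPU kernels.
  if (param.has_projection) return false;
  return SupportRnn(input_dtype);
}

int main() {
  std::cout << SupportRnn(FakeRNNParam{false}, kFloat32) << "\n";  // 1 -> MKL-DNN path
  std::cout << SupportRnn(FakeRNNParam{true}, kFloat32) << "\n";   // 0 -> fallback
  std::cout << SupportRnn(FakeRNNParam{false}, kFloat16) << "\n";  // 0 -> fallback
}

The overload taking RNNParam is used below in rnn.cc when creating op state and inferring storage type, while the dtype-only overload guards the stateful compute entry points.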
2 changes: 0 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE)
void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
void* y, void* hy, void* cy,
const int dtype) {
using dims = mkldnn::memory::dims;
using desc = mkldnn::memory::desc;
using format_tag = mkldnn::memory::format_tag;
auto& cpu_engine = CpuEngine::Get()->get_engine();
@@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

// In the `autograd.record()` context, RNNOp is required to run into
33 changes: 24 additions & 9 deletions src/operator/rnn-inl.h
@@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer,
inline size_t GetRNNWorkspaceSize(int seq_length,
int batch_size,
int hidden_size,
int projection_size,
int direction,
int mode) {
size_t size = 0;
@@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws,
const int batch_size,
const int input_size,
const int state_size,
const int projection_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
@@ -336,8 +338,8 @@
switch (mode) {
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
batch_size, input_size, state_size, projection_size,
x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
@@ -511,10 +513,7 @@ class RNNOp {
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
if (param_.projection_size.has_value()) {
LOG(FATAL) <<
"hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
}

if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
@@ -843,9 +842,14 @@
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
projection_size = param_.projection_size.value();
}

// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
param_.state_size, projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
temp_cpu_space_size_ = work_cpu_space_size;
temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
@@ -856,6 +860,9 @@

if (ctx.is_train || ctx.need_grad) {
// allocate reserve space
if (param_.projection_size.has_value()) {
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
param_.seq_length_, param_.batch_size_,
@@ -896,6 +903,7 @@
param_.batch_size_,
param_.input_size_,
param_.state_size,
projection_size,
x.dptr_,
hx.dptr_,
cx_ptr,
@@ -1096,10 +1104,17 @@
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
// TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
// projection_size = param_.projection_size.value();
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

// allocate temp space
const size_t work_cpu_space_size =
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size,
projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
LOG(FATAL) << "Check temp init error";
}
44 changes: 29 additions & 15 deletions src/operator/rnn.cc
@@ -190,20 +190,19 @@ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const
return request;
}

#if MXNET_USE_MKLDNN == 1
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode = DispatchMode::kFCompute;

#if MXNET_USE_MKLDNN == 1
wanted_mode = DispatchMode::kFComputeEx;
#endif // MXNET_USE_MKLDNN == 1

return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
const bool support_mkldnn_rnn =
!param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_MKLDNN == 1

struct RNNGrad {
const char *op_name;
@@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
}

#if MXNET_USE_MKLDNN == 1
if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
&& in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
data_shape[1], data_shape[2]);
@@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Forward(ctx, inputs, req, outputs);
} else {
@@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Backward(ctx, inputs, req, outputs);
} else {
@@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. https://www.bioinf.jku.at/publications
h_t = o_t * \tanh(c_t)
\end{array}
With the projection size set, LSTM uses the projection feature to reduce the parameter
count and gain some speedup without significantly hurting accuracy.
Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128
.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}
**GRU**
Gated Recurrent Unit - Cho et al. 2014. https://arxiv.org/abs/1406.1078
@@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with
})
.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
@@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
#endif
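As a rough illustration of the parameter savings the projected-LSTM documentation above refers to, the sketch below reuses the per-layer, per-direction weight-size arithmetic that LstmForwardInference in rnn_impl.h applies (biases ignored; the helper name and sample sizes are made up for the example):

// lstm_weight_count.cc -- illustrative only.
#include <cstdio>

long LstmLayerWeights(long input_size, long H, long P) {
  // Four gates, each fed by the input concatenated with the recurrent state;
  // with projection the recurrent state has width P instead of H.
  long w_size = (input_size + (P ? P : H)) * H * 4;
  // Plus the projection matrix W_hr of shape (P, H) when projection is on.
  if (P > 0) w_size += P * H;
  return w_size;
}

int main() {
  const long I = 512, H = 2048, P = 512;  // sample sizes, not taken from the PR
  std::printf("plain LSTM layer    : %ld weights\n", LstmLayerWeights(I, H, 0));
  std::printf("projected LSTM layer: %ld weights\n", LstmLayerWeights(I, H, P));
}

With these sample sizes the projected layer holds roughly 9.4M weights against 21M for the plain one, which is the kind of reduction (and attendant speedup) the operator documentation describes.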
48 changes: 35 additions & 13 deletions src/operator/rnn_impl.h
@@ -209,6 +209,7 @@ void LstmForwardInferenceSingleLayer(DType* ws,
const int N,
const int I,
const int H,
const int P,
const Tensor<cpu, 2, DType> &x,
const Tensor<cpu, 2, DType> &hx,
const Tensor<cpu, 2, DType> &cx,
@@ -219,7 +220,9 @@
DType* cy_ptr) {
using namespace mshadow;
const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H)));
Tensor<cpu, 2, DType> whr(w_ptr, Shape2(1, 1));
if (P > 0) whr = Tensor<cpu, 2, DType>(wh.dptr_ + P * 4 * H, Shape2(P, H));
const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H));
const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H));
Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, H * 4));
@@ -228,7 +231,10 @@
const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H));
Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H));
Tensor<cpu, 2, DType> c(h.dptr_ + N * H, Shape2(N, H));
Tensor<cpu, 2, DType> r(hy_ptr, Shape2(1, 1));
if (P > 0) r = Tensor<cpu, 2, DType>(hy_ptr, Shape2(N, P));
const int offset = bid ? H : 0;
const int proj_offset = bid ? P : 0;
const DType alpha = 1.0;
const DType beta = 0.0;
const int cell_size = N * H;
@@ -237,7 +243,11 @@
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
for (int i = 0; i < T; ++i) {
int t = bid ? T - 1 - i : i;
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
if (P > 0) {
linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true);
} else {
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
}
#pragma omp parallel for num_threads(omp_threads)
for (int jk = 0; jk < cell_size; ++jk) {
int j = jk / H;
@@ -248,14 +258,21 @@
DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]);
DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt;
DType ht = ot * tanh(ct);
y[t][j][k + offset] = ht;
if (P == 0) y[t][j][k + offset] = ht;
if (i == T - 1 && state_outputs) {
hy_ptr[jk] = ht;
if (P == 0) hy_ptr[jk] = ht;
cy_ptr[jk] = ct;
} else {
h[j][k] = ht;
c[j][k] = ct;
}
h[j][k] = ht;
}
if (P > 0) {
linalg_gemm(h, whr, r, alpha, beta, false, true);
#pragma omp parallel for num_threads(omp_threads)
for (int j = 0; j < N; ++j) {
std::memcpy(y[t][j].dptr_ + proj_offset, r[j].dptr_, P * sizeof(DType));
}
}
}
}
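The three Tensor constructions at the top of this function imply the following per-direction weight layout; the helper below just restates those offsets (a sketch, not code from the repository). Note also that when P > 0 the recurrent GEMM consumes the projected state r rather than h, matching r_t = W_{hr} h_t in the operator documentation.

// lstm_weight_offsets.cc -- restates the offsets used above; illustrative only.
#include <cstddef>
#include <cstdio>

struct LstmWeightOffsets {
  std::size_t wx;   // input-to-gate weights, shape (4H, I), at w_ptr
  std::size_t wh;   // recurrent weights, shape (4H, P) if P > 0, else (4H, H)
  std::size_t whr;  // projection matrix W_hr, shape (P, H); only used if P > 0
};

LstmWeightOffsets Offsets(std::size_t I, std::size_t H, std::size_t P) {
  LstmWeightOffsets o;
  o.wx = 0;                            // w_ptr
  o.wh = 4 * H * I;                    // w_ptr + I * H * 4
  o.whr = o.wh + 4 * H * (P ? P : H);  // wh.dptr_ + P * 4 * H when P > 0
  return o;
}

int main() {
  const LstmWeightOffsets o = Offsets(512, 2048, 512);  // sample sizes
  std::printf("wx @ %zu, wh @ %zu, whr @ %zu elements\n", o.wx, o.wh, o.whr);
}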
@@ -269,6 +286,7 @@ void LstmForwardInference(DType* ws,
const int N,
const int I,
const int H,
const int P,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
@@ -278,36 +296,40 @@
DType* hy_ptr,
DType* cy_ptr) {
const int total_layers = D * L;
Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, P ? P : H));
Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
const int b_size = 2 * H * 4;
const int cell_size = N * H;
const int projection_size = (P ? P : H) * N;
DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2;
DType* y_cur_ptr = y_ptr;
int idx = 0; // state & cell state's idx;
bool flag = L % 2 ? false : true;
for (int i = 0; i < L; ++i) {
const int input_size = i ? H * D : I;
const int w_size = (input_size + H) * H * 4;
const int input_size = i ? (P ? P : H) * D : I;
int w_size = (input_size + (P ? P : H)) * H * 4;
if (P > 0) {
w_size += P * H;
}
// If bidirectional, need space to save current layer output y.
if (D == 2) {
y_cur_ptr = flag ? y_tmp_ptr : y_ptr;
flag = !flag;
}
Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, H * D));
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H,
Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, (P ? P : H) * D));
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H, P,
x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
// If bidirectional, then calculate the reverse direction's forward result.
if (D == 2) {
w_ptr += w_size;
b_ptr += b_size;
++idx;
if (state_outputs) {
hy_ptr += cell_size;
hy_ptr += projection_size;
cy_ptr += cell_size;
}
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H,
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H, P,
x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
}
// Don't need to move pointer in the last layer.
@@ -317,7 +339,7 @@
x_ptr = y_cur_ptr;
++idx;
if (state_outputs) {
hy_ptr += cell_size;
hy_ptr += projection_size;
cy_ptr += cell_size;
}
}
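One consequence of the projection that this multi-layer loop has to respect: the hidden state (hx/hy) becomes P-wide while the cell state (cx/cy) stays H-wide, so the two output pointers advance by different strides per direction. A minimal shape summary with made-up sample sizes:

// lstm_state_shapes.cc -- restates the shapes and strides used above;
// illustrative only.
#include <cstdio>

int main() {
  const int L = 2, D = 2, N = 32, H = 2048, P = 512;  // sample sizes
  const int total_layers = D * L;
  const int state_width = P ? P : H;

  // hx/hy are (total_layers, N, P or H); cx/cy are (total_layers, N, H).
  std::printf("hx/hy shape: (%d, %d, %d)\n", total_layers, N, state_width);
  std::printf("cx/cy shape: (%d, %d, %d)\n", total_layers, N, H);

  // Per-direction advances when writing the final states:
  const int projection_size = state_width * N;  // hy_ptr += projection_size
  const int cell_size = N * H;                  // cy_ptr += cell_size
  std::printf("hy_ptr stride %d, cy_ptr stride %d\n", projection_size, cell_size);
}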