Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-v1.0] Add MKL-DNN reshape&flatten&expand_dims (#16258)
Browse files Browse the repository at this point in the history
* Add mkldnn 1.0 support for reshape/flatten/expanddims ops

* improve log & modify definition location of args_map_

* fix comments

* rebase code

* trigger CI

* trigger CI

* trigger CI

* trigger CI
  • Loading branch information
wuxun-zhang authored and TaoLv committed Oct 11, 2019
1 parent 458bb73 commit 922b616
Show file tree
Hide file tree
Showing 9 changed files with 225 additions and 108 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
bool SupportMKLDNNReshape(const NDArray &in_data, const NDArray &out_data);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
70 changes: 70 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_expand_dims.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_expand_dims.cc
* \brief Implement expand_dims operator via MKL-DNN reorder primitive
* \author Wuxun Zhang
*/

#if MXNET_USE_MKLDNN == 100

#include "mkldnn_reshape-inl.h"

namespace mxnet {
namespace op {

class MKLDNNExpandDimsFwd : public MKLDNNReshapeFwd {
public:
explicit MKLDNNExpandDimsFwd(const OpReqType &req,
const NDArray &input,
const NDArray &output)
: MKLDNNReshapeFwd(req, input, output) {}
};

typedef ParamOpSign<ExpandDimParam> MKLDNNExpandDimsSignature;

/*!
 * \brief Run expand_dims via the cached MKL-DNN reshape forward primitive.
 * \param attrs  node attributes carrying the parsed ExpandDimParam
 * \param ctx    op context providing the cpu stream and temp-space resource
 * \param input  source NDArray
 * \param req    write disposition; kNullOp is a no-op, kAddTo is rejected
 * \param output destination NDArray
 */
void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
                             const OpContext &ctx,
                             const NDArray &input,
                             const OpReqType &req,
                             const NDArray &output) {
  const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
  if (req == kNullOp) return;
  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";

  // Bind by reference: GetCachedForward returns a reference into the
  // thread-local cache, and `auto` (without &) would copy the primitive
  // object (its shared_ptrs and primitive vector) on every invocation.
  auto &fwd = GetCachedForward<MKLDNNExpandDimsFwd, ExpandDimParam,
                               MKLDNNExpandDimsSignature>(param, req, input, output);

  // Some reshape paths need a scratch buffer (e.g. when input and output
  // alias); request it from the op's temp-space resource only when needed.
  auto ws_size = fwd.GetWorkspaceSize();
  void* ws_ptr = nullptr;
  if (ws_size) {
    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
        .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
    ws_ptr = static_cast<void*>(ws.dptr_);
  }

  fwd.Execute(input, output, req, ws_ptr);
}

} // namespace op
} // namespace mxnet

#endif
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_flatten-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FLATTEN_INL_H_
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100

#include "mkldnn_reshape-inl.h"

Expand Down
6 changes: 3 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@

/*!
* \file mkldnn_flatten.cc
* \brief Implement flatten operator by using mkldnn reorder primitive
* \brief Implement flatten operator using the MKL-DNN reorder primitive
* \author Wuxun Zhang
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100

#include "mkldnn_flatten-inl.h"

Expand Down Expand Up @@ -70,7 +70,7 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
ws_ptr = reinterpret_cast<void*>(ws.dptr_);
}

fwd.Execute(input, output, ws_ptr);
fwd.Execute(input, output, req, ws_ptr);
}

} // namespace op
Expand Down
28 changes: 16 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);

void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
#endif

#if MXNET_USE_MKLDNN == 100
Expand Down Expand Up @@ -142,6 +130,22 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const NDArray &data,
const OpReqType &req,
const NDArray &output);

void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
#endif

} // namespace op
Expand Down
33 changes: 28 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_reshape-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <vector>
#include "mkldnn_base-inl.h"
#include "../../tensor/matrix_op-inl.h"
Expand All @@ -36,7 +36,6 @@ namespace op {

class MKLDNNReshapeFwd {
protected:
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory> temp_;
std::vector<mkldnn::primitive> prims_;
Expand All @@ -47,15 +46,39 @@ class MKLDNNReshapeFwd {
const NDArray &input,
const NDArray &output);
int GetWorkspaceSize();
void SetNewMem(const NDArray &input,
const NDArray &output,
void* workspace = nullptr);
void Execute(const NDArray &input,
const NDArray &output,
const OpReqType &req,
void* workspace = nullptr);
};

typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;

/*!
 * \brief Look up (or create and cache) a forward primitive keyed by
 *        (param, req, input, output).
 *
 * The cache is thread-local, so primitives are never shared across threads.
 * \tparam MKLDNNOpFwdType    forward primitive type, constructible from
 *                            (req, input, output)
 * \tparam ParamType          operator parameter type used in the signature
 * \tparam MKLDNNSignatureType ParamOpSign specialization used as cache key
 * \return reference to the cached forward primitive (valid for the thread's
 *         lifetime; do not copy it out — bind with auto&)
 */
template<typename MKLDNNOpFwdType, typename ParamType, typename MKLDNNSignatureType>
MKLDNNOpFwdType &GetCachedForward(const ParamType& param,
                                  const OpReqType &req,
                                  const NDArray &input,
                                  const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNSignatureType,
                                         MKLDNNOpFwdType, OpHash> fwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSignatureType,
                                            MKLDNNOpFwdType, OpHash> fwds;
#endif
  // The signature must capture everything that affects primitive creation.
  MKLDNNSignatureType key(param);
  key.AddSign(req);
  key.AddSign(input);
  key.AddSign(output);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    MKLDNNOpFwdType fwd(req, input, output);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}

MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
const OpReqType &req,
const NDArray &input,
Expand Down
Loading

0 comments on commit 922b616

Please sign in to comment.