Reduce overhead in sg_onednn_fully_connected for floats (#21092)
* dispatch float fc to simpler function

* apply review comments

* remove unnecessary variables

* fix sanity
bgawrych committed Jul 19, 2022
1 parent ef0415d commit cca8f4e
Showing 5 changed files with 135 additions and 73 deletions.
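At a high level, the change adds equality operators and std::hash specializations to the oneDNN fully-connected parameter structs (so the full parameter set, not just FullyConnectedParam, can key the primitive and cached-weight lookups), exposes the float forward path as a reusable DNNLFCForwardImpl, and lets the subgraph FC operator dispatch to it early instead of walking through the quantization bookkeeping. The core of that dispatch, reproduced from the dnnl_fc.cc hunk further down:

    if (!dnnl_param.quantized) {
      // dispatch to float version
      DNNLFCForwardImpl(full_param_, ctx, in_data, req, {output});
      return;
    }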
20 changes: 20 additions & 0 deletions src/operator/nn/dnnl/dnnl_base-inl.h
@@ -769,6 +769,11 @@ struct DNNLPostEltwiseParam {
float scale = 1.f;
float alpha = 0.f;
float beta = 1.f;

bool operator==(const DNNLPostEltwiseParam& other) const {
return this->alg == other.alg && this->scale == other.scale && this->alpha == other.alpha &&
this->beta == other.beta;
}
};

void DNNLRun(mxnet::FComputeEx fn,
@@ -792,5 +797,20 @@ void DNNLRun(FComputeExUnary fn,
const mxnet::NDArray& outputs_);

} // namespace mxnet

namespace std {
template <>
struct hash<mxnet::DNNLPostEltwiseParam> {
size_t operator()(const mxnet::DNNLPostEltwiseParam& val) {
size_t ret = dmlc::HashCombine(0, static_cast<int>(val.alg));
ret = dmlc::HashCombine(ret, val.scale);
ret = dmlc::HashCombine(ret, val.alpha);
ret = dmlc::HashCombine(ret, val.beta);

return ret;
}
};
} // namespace std

#endif
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_BASE_INL_H_
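The operator== and the std::hash<mxnet::DNNLPostEltwiseParam> specialization added above are the two pieces C++ requires before a struct can serve as the key of a hash-based cache. A minimal, self-contained sketch of that requirement with toy types (not MXNet code; the mixing step is similar to what dmlc::HashCombine does):

    #include <cstddef>
    #include <functional>
    #include <unordered_map>

    // Toy parameter struct: equality decides when two keys mean the same primitive.
    struct ToyParam {
      int alg;
      float scale;
      bool operator==(const ToyParam& o) const { return alg == o.alg && scale == o.scale; }
    };

    // Hash specialization: combine the per-field hashes, boost-style.
    namespace std {
    template <>
    struct hash<ToyParam> {
      size_t operator()(const ToyParam& p) const {
        size_t h = std::hash<int>()(p.alg);
        return h ^ (std::hash<float>()(p.scale) + 0x9e3779b9 + (h << 6) + (h >> 2));
      }
    };
    }  // namespace std

    int main() {
      std::unordered_map<ToyParam, int> primitive_cache;   // compiles only with both pieces
      primitive_cache[{1, 0.5f}] = 42;
      return primitive_cache.count({1, 0.5f}) == 1 ? 0 : 1;  // equal fields find the same entry
    }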
63 changes: 60 additions & 3 deletions src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -72,14 +72,30 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
.describe("Whether support channel-wise-quantize for weight.");
DNNL_DECLARE_ENABLED_FLOAT_OUTPUT_PARAMETER();
}

bool operator==(const DNNLFCParam& other) const {
return this->quantized == other.quantized &&
this->enabled_float_output == other.enabled_float_output &&
this->with_eltwise == other.with_eltwise && this->with_sum == other.with_sum &&
this->min_calib_range == other.min_calib_range &&
this->max_calib_range == other.max_calib_range &&
this->channel_wise_quantize == other.channel_wise_quantize;
}
};

struct DNNLFCFullParam {
struct DNNLFCFullParam : public dmlc::Parameter<DNNLFCFullParam> {
FullyConnectedParam default_param;
DNNLFCParam dnnl_param;
DNNLPostEltwiseParam eltwise_param;
float sum_scale = {1.0f};
std::vector<float> output_scales = {0.0f};
DMLC_DECLARE_PARAMETER(DNNLFCFullParam) {}

bool operator==(const DNNLFCFullParam& other) const {
return this->default_param == other.default_param && this->dnnl_param == other.dnnl_param &&
this->eltwise_param == other.eltwise_param && this->sum_scale == other.sum_scale &&
this->output_scales == other.output_scales;
}
};

static inline size_t GetInSumIndex(const DNNLFCFullParam& param) {
@@ -187,9 +203,9 @@ class DNNLFullyConnectedForward {
std::shared_ptr<dnnl::inner_product_forward> fwd_;
};

typedef ParamOpSign<FullyConnectedParam> DNNLFullyconSignature;
typedef ParamOpSign<DNNLFCFullParam> DNNLFullyconSignature;

DNNLFullyConnectedForward& GetFCFwd(const FullyConnectedParam& param,
DNNLFullyConnectedForward& GetFCFwd(const DNNLFCFullParam& param,
const bool is_train,
const NDArray& data,
const NDArray& weight,
@@ -207,6 +223,12 @@ void DNNLFCForward(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data);

void DNNLFCForwardImpl(const DNNLFCFullParam& full_param,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data);

void DNNLFCForwardFullFeature(const DNNLFCFullParam& param,
const OpContext& ctx,
DNNLFullyConnectedForward* fwd,
@@ -223,5 +245,40 @@ void DNNLFCBackward(const nnvm::NodeAttrs& attrs,
} // namespace op
} // namespace mxnet

namespace std {

template <>
struct hash<mxnet::op::DNNLFCParam> {
size_t operator()(const mxnet::op::DNNLFCParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.min_calib_range.has_value() ? val.min_calib_range.value() : 0);
ret = dmlc::HashCombine(ret, val.max_calib_range.has_value() ? val.max_calib_range.value() : 0);
ret = dmlc::HashCombine(
ret, val.channel_wise_quantize.has_value() ? val.channel_wise_quantize.value() : 0);
ret = dmlc::HashCombine(ret, val.quantized);
ret = dmlc::HashCombine(
ret, val.enabled_float_output.has_value() ? val.enabled_float_output.value() : -1);
ret = dmlc::HashCombine(ret, val.with_eltwise);
ret = dmlc::HashCombine(ret, val.with_sum);

return ret;
}
};

template <>
struct hash<mxnet::op::DNNLFCFullParam> {
size_t operator()(const mxnet::op::DNNLFCFullParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.default_param);
ret = dmlc::HashCombine(ret, val.dnnl_param);
ret = dmlc::HashCombine(ret, val.eltwise_param);
ret = dmlc::HashCombine(ret, val.sum_scale);
for (const auto& v : val.output_scales)
ret = dmlc::HashCombine(ret, v);
return ret;
}
};
} // namespace std
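Several of the fields hashed above are dmlc::optional values, so the code substitutes a fixed sentinel (0 or -1) when no value is set. A toy version of that pattern using std::optional (illustrative only; a sentinel colliding with a real value only costs a hash collision, since operator== still disambiguates):

    #include <cstddef>
    #include <functional>
    #include <optional>

    // Boost-style combine, similar to dmlc::HashCombine.
    inline std::size_t combine(std::size_t seed, float v) {
      return seed ^ (std::hash<float>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    // Fold an optional calibration bound into the running hash, using 0.0f
    // as the "not set" sentinel, mirroring the specializations above.
    inline std::size_t hash_params(bool quantized, std::optional<float> min_calib_range) {
      std::size_t h = std::hash<bool>()(quantized);
      h = combine(h, min_calib_range.has_value() ? min_calib_range.value() : 0.0f);
      return h;
    }

    int main() {
      // Different keys usually hash differently, but correctness never depends on it.
      return hash_params(true, 0.5f) != hash_params(true, std::nullopt) ? 0 : 1;
    }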

#endif // MXNET_USE_ONEDNN == 1
#endif // MXNET_OPERATOR_NN_DNNL_DNNL_FULLY_CONNECTED_INL_H_
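The new DNNLFCForwardImpl declared above is what makes the float dispatch cheap: the attribute-parsing entry point builds DNNLFCFullParam once and delegates, while a caller that already owns a prepared struct (the subgraph FC operator keeps one as a member) calls the Impl directly. A toy sketch of that layering, with invented names rather than the real MXNet signatures:

    #include <map>
    #include <string>

    struct FullParam {
      bool quantized = false;
    };

    // Lower layer: runs the primitive given an already-prepared parameter struct.
    void ForwardImpl(const FullParam& param) {
      (void)param;  // ... set up memory descriptors, run the oneDNN primitive ...
    }

    // Upper layer: parses string attributes, then delegates.
    void Forward(const std::map<std::string, std::string>& attrs) {
      FullParam param;
      param.quantized = attrs.count("quantized") && attrs.at("quantized") == "True";
      ForwardImpl(param);
    }

    int main() {
      Forward({{"quantized", "False"}});  // generic operator: parse on every call
      FullParam cached;                   // subgraph operator: struct built once at init
      ForwardImpl(cached);                // so the per-forward parsing overhead disappears
      return 0;
    }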
16 changes: 10 additions & 6 deletions src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -123,7 +123,7 @@ inline static dnnl::inner_product_backward_weights::primitive_desc GetFCBwdWeigh
}
}

DNNLFullyConnectedForward& GetFCFwd(const FullyConnectedParam& param,
DNNLFullyConnectedForward& GetFCFwd(const DNNLFCFullParam& param,
const bool is_train,
const NDArray& data,
const NDArray& weight,
@@ -146,10 +146,7 @@ DNNLFullyConnectedForward& GetFCFwd(const FullyConnectedParam& param,

auto it = fcFwds.find(key);
if (it == fcFwds.end()) {
DNNLFCFullParam full_param;
full_param.default_param = param;
full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
DNNLFullyConnectedForward fcFwd(full_param, is_train, data, weight, bias, out_md);
DNNLFullyConnectedForward fcFwd(param, is_train, data, weight, bias, out_md);
it = AddToCache(&fcFwds, key, fcFwd);
}
return it->second;
@@ -230,11 +227,18 @@ void DNNLFCForward(const nnvm::NodeAttrs& attrs,
DNNLFCFullParam full_param;
full_param.default_param = nnvm::get<FullyConnectedParam>(attrs.parsed);
full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
DNNLFCForwardImpl(full_param, ctx, in_data, req, out_data);
}

void DNNLFCForwardImpl(const DNNLFCFullParam& full_param,
const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
NDArray data = in_data[fullc::kData];
dnnl::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]);
DNNLFCFlattenData(full_param.default_param, out_data[fullc::kOut], &data, &out_md);
auto& fwd = GetFCFwd(full_param.default_param,
auto& fwd = GetFCFwd(full_param,
ctx.is_train,
data,
in_data[fullc::kWeight],
@@ -37,7 +37,10 @@ void DNNLQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& out_data) {
TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
FullyConnectedParam param = nnvm::get<FullyConnectedParam>(attrs.parsed);
const size_t num_inputs = param.no_bias ? 2 : 3;
DNNLFCFullParam full_param;
full_param.default_param = param;
full_param.dnnl_param.Init(std::unordered_map<std::string, std::string>());
const size_t num_inputs = param.no_bias ? 2 : 3;

CHECK_EQ(in_data.size(), static_cast<size_t>(num_inputs * 3));
CHECK_EQ(out_data.size(), 3U);
@@ -87,8 +90,8 @@ void DNNLQuantizedFullyConnectedForward(const nnvm::NodeAttrs& attrs,
bool is_train = false;
dnnl::memory::desc out_md = GetMemDesc(out_data[fullc::kOut]);
DNNLFCFlattenData(param, out_data[fullc::kOut], &data, &out_md);
auto& fwd =
GetFCFwd(param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);
auto& fwd = GetFCFwd(
full_param, is_train, data, weight, param.no_bias ? nullptr : &quantized_bias, out_md);

auto fwd_src_desc = fwd.fwd_pd.src_desc();
auto data_mem = in_data[fullc::kData].GetDNNLDataReorder(&fwd_src_desc);
100 changes: 39 additions & 61 deletions src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -112,22 +112,31 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
const FCInputIndex idx(full_param_);
CHECK_EQ(in_data.size(), idx.GetTotal());
const int out_index = 0;
const int out_min_index = 1;
const int out_max_index = 2;

const auto& default_param = full_param_.default_param;
const auto& dnnl_param = full_param_.dnnl_param;
const bool has_bias = !default_param.no_bias;
const bool quantized = dnnl_param.quantized;
const bool out_quantized = dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value();
const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize.value();

const FCInputIndex idx(full_param_);
CHECK_EQ(in_data.size(), idx.GetTotal());
NDArray output;
if (dnnl_param.with_sum) {
output = PrepareOutputWithSum(in_data[idx.sum], out_data[out_index]);
} else {
output = out_data[out_index];
}

if (!dnnl_param.quantized) {
// dispatch to float version
DNNLFCForwardImpl(full_param_, ctx, in_data, req, {output});
return;
}

int index = 0;
const int out_index = index++;
const int out_min_index = out_quantized ? index++ : 0;
const int out_max_index = out_quantized ? index++ : 0;
CHECK_EQ(out_data.size(), index); // index is equal to total number of outputs
const bool has_bias = !default_param.no_bias;
const bool channel_wise =
dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize.value();

std::vector<float> min_max_vec(MIN_MAX_COUNT);
min_max_vec[kDataMin] = 0.0f;
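Because the float path now returns before any calibration handling, everything after the early return only runs for quantized inputs, which is why the later hunks can drop their per-block if (dnnl_param.quantized) guards. A toy before/after of that guard-hoisting pattern (illustrative only, not MXNet code):

    #include <iostream>

    // Before: every quantization-specific step carries its own guard.
    void forward_before(bool quantized) {
      // ... shared setup ...
      if (quantized) std::cout << "read calibration ranges\n";
      // ... shared work ...
      if (quantized) std::cout << "rescale output\n";
    }

    // After: one early return; the rest of the body assumes quantized == true.
    void forward_after(bool quantized) {
      if (!quantized) return;  // simple float path handled elsewhere
      std::cout << "read calibration ranges\n";
      std::cout << "rescale output\n";
    }

    int main() {
      forward_before(true);
      forward_after(true);
      return 0;
    }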
@@ -141,26 +150,17 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
min_max_vec[kSumMax] = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0f;
NDArray data = in_data[idx.data];
const NDArray& weight = in_data[idx.weight];
NDArray output;

if (dnnl_param.with_sum) {
output = PrepareOutputWithSum(in_data[idx.sum], out_data[out_index]);
} else {
output = out_data[out_index];
}

if (dnnl_param.quantized) {
if (!channel_wise) {
min_max_vec[kWeightMin] = in_data[idx.weight_min].data().dptr<float>()[0];
min_max_vec[kWeightMax] = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
min_max_vec[kBiasMin] = in_data[idx.bias_min].data().dptr<float>()[0];
min_max_vec[kBiasMax] = in_data[idx.bias_max].data().dptr<float>()[0];
}
if (!channel_wise) {
min_max_vec[kWeightMin] = in_data[idx.weight_min].data().dptr<float>()[0];
min_max_vec[kWeightMax] = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
min_max_vec[kBiasMin] = in_data[idx.bias_min].data().dptr<float>()[0];
min_max_vec[kBiasMax] = in_data[idx.bias_max].data().dptr<float>()[0];
}
min_max_vec[kDataMin] = in_data[idx.data_min].data().dptr<float>()[0];
min_max_vec[kDataMax] = in_data[idx.data_max].data().dptr<float>()[0];
}
min_max_vec[kDataMin] = in_data[idx.data_min].data().dptr<float>()[0];
min_max_vec[kDataMax] = in_data[idx.data_max].data().dptr<float>()[0];

initialized_ = CheckInitializationConditions(in_data, min_max_vec, channel_wise);

@@ -201,11 +201,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
dnnl::memory::desc out_md = CreateOutputMemoryDesc(output.shape(), output.dtype());
cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine);

bool support_channelwise_scale = false;

if (dnnl_param.quantized) {
support_channelwise_scale = PrepareQuantization(ctx, in_data, output, min_max_vec);
}
bool support_channelwise_scale = PrepareQuantization(ctx, in_data, output, min_max_vec);

fwd_.reset(new DNNLFullyConnectedForward(full_param_,
ctx.is_train,
@@ -226,34 +222,17 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
initialized_ = true;
}

if (dnnl_param.with_sum) {
const auto& output_mem = output.GetDNNLData();
const auto& out_mem_desc = output_mem->get_desc();
auto dst_mem_desc = fwd_->fwd_pd.dst_desc();
if (out_mem_desc != dst_mem_desc) {
auto tmp_out_mem = output.GetDNNLDataReorder(&dst_mem_desc);
dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
dnnl_mem_ptr new_out_mem(new dnnl::memory(
dst_mem_desc, CpuEngine::Get()->get_engine(), output_mem->get_data_handle()));
DNNLStream::Get()->RegisterMem(new_out_mem);
DNNLMemoryCopy(*tmp_out_mem, new_out_mem.get());
output = NDArray(new_out_mem);
}
}

if (reorder_data_) {
data = data.Reorder2Default();
}
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
cached_data_mem_->set_data_handle(reinterpret_cast<void*>(data.data().dptr<DType>()));
});
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
cached_out_mem_->set_data_handle(reinterpret_cast<void*>(output.data().dptr<DType>()));
});

cached_data_mem_->set_data_handle(reinterpret_cast<void*>(data.data().dptr_));
cached_out_mem_->set_data_handle(reinterpret_cast<void*>(output.data().dptr_));

DNNLStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
DNNLStream::Get()->Submit();

if (dnnl_param.quantized && !dnnl_param.enabled_float_output.has_value()) {
if (!dnnl_param.enabled_float_output.has_value()) {
float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
float* output_max_ptr = out_data[out_max_index].data().dptr<float>();

@@ -322,8 +301,7 @@ NDArray SgDNNLFCOp::PrepareOutputWithSum(const NDArray& sum_input, const NDArray
bool SgDNNLFCOp::CheckInitializationConditions(const std::vector<NDArray>& inputs,
const std::vector<float>& min_max_vec,
bool is_channel_wise) {
if (initialized_ && full_param_.dnnl_param.quantized &&
dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
if (initialized_ && dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
bool has_bias = !full_param_.default_param.no_bias;
if (cached_data_min_ != min_max_vec[kDataMin] || cached_data_max_ != min_max_vec[kDataMax] ||
cached_sum_min_ != min_max_vec[kSumMin] || cached_sum_max_ != min_max_vec[kSumMax]) {
@@ -381,8 +359,8 @@ bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
bool has_bias = !full_param_.default_param.no_bias;
const NDArray& data = in_data[fullc::kData];
const NDArray& weight = in_data[fullc::kWeight];
const bool channel_wise = dnnl_param.quantized && dnnl_param.channel_wise_quantize.has_value() &&
dnnl_param.channel_wise_quantize.value();
const bool channel_wise =
dnnl_param.channel_wise_quantize.has_value() && dnnl_param.channel_wise_quantize.value();

CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
@@ -541,7 +519,7 @@ void SgDNNLFCOp::GetCachedWeightsAndBias(const NDArray& weight,
const bool has_id = attrs.dict.count("__identifier__");
bool read_from_cache = false;

DNNLFullyconSignature key(full_param_.default_param);
DNNLFullyconSignature key(full_param_);
if (use_cache && has_id) {
key.AddSign(fwd_->fwd_pd.weights_desc());
key.AddSign(attrs.dict["__identifier__"]);
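With DNNLFullyconSignature now typedef'd over DNNLFCFullParam, this cache key carries the quantization and post-op settings rather than only the basic FC attributes. In general, keying a cache on less than everything the cached value depends on risks handing back an entry prepared under different settings; a toy illustration of that failure mode (not MXNet code):

    #include <cassert>
    #include <string>
    #include <unordered_map>

    struct Prepared {
      float scale;  // stand-in for weights prepared with a particular scale
    };

    int main() {
      std::unordered_map<std::string, Prepared> cache;  // key omits the scale
      auto get_or_create = [&](const std::string& name, float scale) {
        auto it = cache.find(name);
        if (it == cache.end()) it = cache.emplace(name, Prepared{scale}).first;
        return it->second;
      };
      get_or_create("fc_weight", 2.0f);               // first caller fills the entry
      Prepared p = get_or_create("fc_weight", 8.0f);  // second caller wants scale 8 ...
      assert(p.scale == 2.0f);                        // ... but silently receives scale 2
      return 0;
    }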