Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
oneDNN FullyConnected weight caching & refactor (#21047)
Browse files Browse the repository at this point in the history
* FC weight and bias caching

* prepare output for sum

* check initialization conditions

* create output mem desc

* PrepareQuantization

* remove unused variables

* cleanup

* Enable BRGEMM

* Reorder functions

* make minmax enum anonymous

* node identifier & env flag

* fix sanity

* fix sanity

* apply review

* rename variable
  • Loading branch information
bgawrych committed Jul 6, 2022
1 parent b713dc5 commit 84b1626
Show file tree
Hide file tree
Showing 4 changed files with 401 additions and 308 deletions.
9 changes: 1 addition & 8 deletions src/operator/nn/dnnl/dnnl_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,6 @@ inline static dnnl::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1)
return dnnl::memory::desc{dims, get_dnnl_type(dtype), dnnl::memory::format_tag::any};
}

// Heuristic gate (going by the function name, for picking the BRGEMM implementation).
// Returns true only when batch_size is at least 16384 and each of the first two
// entries of weight_dims is at least 1024 and divisible by 64.
// weight_dims must hold at least two entries; no size check is performed here.
// Conditions based on measurement results done on CLX8280
// https://github.com/apache/incubator-mxnet/pull/20533
inline static bool ChooseBRGEMMImpl(const dnnl::memory::dims& weight_dims, size_t batch_size) {
  constexpr size_t kMinBatch = 16384;
  if (batch_size < kMinBatch) {
    return false;
  }
  // Both weight axes have to satisfy the same size and alignment thresholds.
  for (size_t axis = 0; axis < 2; ++axis) {
    const auto extent = weight_dims[axis];
    if (extent < 1024 || extent % 64 != 0) {
      return false;
    }
  }
  return true;
}

inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
size_t batch_size,
int dtype = -1) {
Expand All @@ -370,7 +363,7 @@ inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
// for batch 256 alexnet benchmark test
const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
if (dims.size() == 2) {
if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
if (force_fc_ab_format || dtype != mshadow::kInt8) {
format = dnnl::memory::format_tag::ab;
}
}
Expand Down
7 changes: 6 additions & 1 deletion src/operator/operator_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class OpSignature {

#if MXNET_USE_ONEDNN == 1
void AddSign(const dnnl::memory::desc& desc) {
hash = hash * 2 + desc.data.format_kind;
hash = hash * 2 + desc.data.format_kind;
eles.push_back(desc.data.format_kind);
hash = hash * 2 + desc.data.data_type;
eles.push_back(desc.data.data_type);
Expand Down Expand Up @@ -617,6 +617,11 @@ class OpSignature {

#endif

void AddSign(const std::string& s) {
uint64_t key = static_cast<uint64_t>(std::hash<std::string>{}(s));
eles.push_back(key);
}

void AddSign(const std::vector<NDArray>& arrs) {
for (auto& arr : arrs) {
AddSign(arr);
Expand Down
Loading

0 comments on commit 84b1626

Please sign in to comment.