This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add bfloat16 floating-point format support based on AMP #17265

Merged: 61 commits, merged Feb 16, 2020
Changes shown from 1 commit.

Commits (61)
350da26
Add Bfloat16
ZhennanQin May 20, 2019
b39d3e1
mshadow support bf16
ElaineBao Oct 31, 2019
9c44b66
rebase bf16 mkldnn1.0
ElaineBao Nov 5, 2019
da7118c
support bf16 gemm
xinyu-intel Nov 12, 2019
c220bfc
resolve fp32 ip bwd bug
xinyu-intel Nov 18, 2019
f1055e5
add other bf16 ops
ElaineBao Nov 20, 2019
b1f8b94
change func name from fp16 to lp16 (low precision 16), to include bf16
ElaineBao Nov 20, 2019
ac3edaa
add amp_cast bf16 support for ndarray
ElaineBao Nov 21, 2019
16dd430
fix executor copy_params
ElaineBao Nov 27, 2019
99e31e9
add test case for bf16
ElaineBao Nov 27, 2019
77a6a6f
remove numpy dtype hook for bf16
ElaineBao Nov 28, 2019
22b70af
add bf16 type support
ElaineBao Dec 9, 2019
e9dd678
rebase to mxnet master
ElaineBao Dec 9, 2019
97372fd
add single conv test
ElaineBao Dec 10, 2019
c4aec63
fix symbolic inference
ElaineBao Dec 10, 2019
7a7ab3a
add dtype check when copy
ElaineBao Dec 11, 2019
e3fced6
add single conv and bn test
ElaineBao Dec 11, 2019
9b68d07
skip fp16 amp_cast test in cpu
ElaineBao Dec 16, 2019
89eaba1
Fix resnet50 first convolution
ZhennanQin Dec 17, 2019
2cb8969
Skip first convolution for bfloat16
ZhennanQin Dec 20, 2019
db28e3a
support bf16 fallback compute
rongzha1 Dec 23, 2019
e6d2b69
recover origin test
rongzha1 Dec 23, 2019
704ac96
fix bf16 bn test, enhance assert_almost_equal_with_err
ElaineBao Dec 23, 2019
f931fcb
add some bf16 unittests
wuxun-zhang Dec 19, 2019
1103bb7
using assert_almost_equal_with_err for fallback bn test
rongzha1 Dec 25, 2019
98f8b07
add relu6 bf16 support
ElaineBao Dec 29, 2019
bdb2483
fix lint
rongzha1 Jan 3, 2020
b757cc4
fix subgraph conv with data=0
ElaineBao Jan 6, 2020
b27dec5
mkldnn doesn't support 0 dim tensor
rongzha1 Jan 7, 2020
f37b3fe
rm dtype check when copy
rongzha1 Jan 7, 2020
b3e5deb
using bf16 tvm
rongzha1 Jan 8, 2020
d88bc3e
rm bf16 mnist demo
rongzha1 Jan 10, 2020
bf6c727
use official tvm
rongzha1 Jan 10, 2020
a8eab98
change function name; fix lint error
rongzha1 Jan 14, 2020
f46fcdc
fix clang check error:conditional expression is ambiguous; 'float' ca…
rongzha1 Jan 14, 2020
e39e3a6
nvcc compiler build pass
rongzha1 Jan 15, 2020
b554998
fix gpu amp cast symbol error
rongzha1 Jan 16, 2020
7db4f61
fix mnist training error
rongzha1 Jan 17, 2020
c04fe99
fix cpp test: Engine.VarVersion error
rongzha1 Jan 17, 2020
b3306c9
workaround cpp failed test mkldnn fc bwd
rongzha1 Jan 19, 2020
06a32fe
to fix mkldnn test_mkldnn_ndarray_slice error
rongzha1 Jan 20, 2020
1b090b5
1. move some code from to np_broadcast_reduce_op_value.cc to np_broad…
rongzha1 Jan 21, 2020
44f133b
use official dlpack
rongzha1 Jan 21, 2020
04a1402
rename np_broadcast_reduce_op_value_part2.cc and add some description
rongzha1 Jan 21, 2020
afe1c09
1. update dlpack url in .gitmodule
rongzha1 Jan 23, 2020
1027985
fix remaining NodePtr due to tvm update
rongzha1 Feb 10, 2020
0d83536
mv some code from mxnet_op.h to mxnet_op_kernel_assign.h to avoid WIN…
rongzha1 Feb 10, 2020
592e67f
fix WIN CPU build fail:compiler is out of heap space in pass 2
rongzha1 Feb 10, 2020
e934607
fix WIN build fail
rongzha1 Feb 11, 2020
ce35638
fix lint
rongzha1 Feb 11, 2020
4dfd91b
add print for test bf16_concat
rongzha1 Feb 12, 2020
405e2aa
fix bf16 test fail
rongzha1 Feb 12, 2020
e0246ae
disable bf16 concat test
rongzha1 Feb 12, 2020
043d315
tmp skip to root cause edge test halt
rongzha1 Feb 13, 2020
052bf79
fix bf16_bn test error
rongzha1 Feb 13, 2020
1880d95
enable test_bulk
rongzha1 Feb 14, 2020
5edd735
tmp rm bf16 to locate edge error
rongzha1 Feb 14, 2020
3bad0fe
Revert "tmp rm bf16 to locate edge error"
rongzha1 Feb 15, 2020
4a43b1d
add Apache license header
rongzha1 Feb 15, 2020
9a37a0a
trigger CI
rongzha1 Feb 15, 2020
df090c1
add robust for test bf16 bn
rongzha1 Feb 15, 2020
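
Several of the early commits above ("Add Bfloat16", "mshadow support bf16", "add bf16 type support") introduce the bfloat16 storage type itself. bfloat16 is simply the upper 16 bits of an IEEE-754 float32 (same 8-bit exponent, truncated mantissa), which is why most of the plumbing is type registration rather than new math. As a self-contained illustration, a sketch with hypothetical helper names, not the mshadow implementation:

#include <cstdint>
#include <cstring>

// float32 -> bfloat16: keep the high 16 bits, rounding to nearest even.
// (NaN inputs are not special-cased in this sketch.)
static inline uint16_t FloatToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));    // type-pun without UB
  bits += 0x7FFFu + ((bits >> 16) & 1u);   // round to nearest even
  return static_cast<uint16_t>(bits >> 16);
}

// bfloat16 -> float32: zero-fill the low 16 bits; exact, no rounding.
static inline float Bf16ToFloat(uint16_t h) {
  uint32_t bits = static_cast<uint32_t>(h) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

Because the exponent width matches float32, bf16 keeps float32's dynamic range and trades mantissa precision for it, which is the property AMP exploits when it casts non-sensitive ops to the low-precision type.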
resolve fp32 ip bwd bug
xinyu-intel authored and rongzha1 committed Feb 15, 2020
commit c220bfc474525e97f7cd5ae42098ecc6918558c5
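
Taken together, the hunks below show three related changes: the fatal log in get_mkldnn_type is simplified, the data-gradient branch of MKLDNNFCBackward is moved so it runs after the weight/bias-gradient branch (the reorder that resolves the fp32 inner-product backward bug named in the title), and a leftover debug print plus a commented-out FallBackCompute call are removed from elemwise_sum.cc.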
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h

@@ -229,7 +229,7 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) {
     case mshadow::kUint8:
       return mkldnn::memory::data_type::u8;
     default:
-      LOG(FATAL) << "unknown type for MKLDNN:" << dtype;
+      LOG(FATAL) << "unknown type for MKLDNN";
       return mkldnn::memory::data_type::undef;
   }
 }
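
For context, get_mkldnn_type is the switch that maps MXNet dtype flags onto MKL-DNN data types, and the bf16 work elsewhere in this PR presumably adds a bfloat16 arm to it (not visible in this hunk). A minimal standalone sketch of its shape, with hypothetical stand-in enums rather than the real mshadow/MKL-DNN headers:

#include <iostream>

// Hypothetical stand-ins for mshadow type flags and MKL-DNN data types.
enum DType { kFloat32, kBfloat16, kInt8, kUint8 };
enum class MklType { f32, bf16, s8, u8, undef };

MklType GetMklType(int dtype) {
  switch (dtype) {
    case kFloat32:  return MklType::f32;
    case kBfloat16: return MklType::bf16;  // the entry bf16 support adds
    case kInt8:     return MklType::s8;
    case kUint8:    return MklType::u8;
    default:
      std::cerr << "unknown type for MKLDNN: " << dtype << '\n';
      return MklType::undef;
  }
}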
36 changes: 18 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc

@@ -273,24 +273,6 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad));

   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
-  if (req[fullc::kData]) {
-    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
-        data, weight, out_grad, fwd_pd);
-    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
-    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_desc(),
-                                       req[fullc::kData]);
-    mkldnn_args_map_t args = {
-      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
-      {MKLDNN_ARG_WEIGHTS, *weight_mem},
-      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
-    };
-
-    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
-    CommitOutput(in_grad[fullc::kData], in_grad_mem);
-  }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
         = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],

@@ -319,6 +301,24 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
+  if (req[fullc::kData]) {
+    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
+        data, weight, out_grad, fwd_pd);
+    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
+                                       ipBwdData_pd.diff_src_desc(),
+                                       req[fullc::kData]);
+    mkldnn_args_map_t args = {
+      {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
+    CommitOutput(in_grad[fullc::kData], in_grad_mem);
+  }
   MKLDNNStream::Get()->Submit();
 }
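
The reorder above is the substance of the fix: both backward primitives are still registered on the same MKLDNNStream before the single Submit(), but registering inner_product_backward_data after inner_product_backward_weights appears to sidestep the fp32 inner-product backward failure the commit title refers to. The bodies of the two branches are unchanged.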
2 changes: 0 additions & 2 deletions src/operator/tensor/elemwise_sum.cc

@@ -123,8 +123,6 @@ void ElementWiseSumComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 ResourceRequest(ResourceRequest::kTempSpace));
     NDArray out_nd = outputs[0];
     mxnet::ndarray::ElementwiseSum<cpu>(s, rsc, inputs, &out_nd);
-    std::cout << "src/operator/tensor/elemwise_sum.cc: not fallback";
-    // FallBackCompute(ElementWiseSumCompute<cpu>, attrs, ctx, inputs, req, outputs);
 #if MXNET_USE_MKLDNN == 1
   } else if (IsMKLDNNData(inputs)) {
     MKLDNNRun(MKLDNNSumForward, attrs, ctx, inputs, req, outputs);
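
This last hunk is pure cleanup: it drops a stray std::cout debug line and the commented-out FallBackCompute call, leaving the native ElementwiseSum<cpu> path as the only code in that branch.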