This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add bfloat16 floating-point format support based on AMP #17265

Merged: 61 commits merged on Feb 16, 2020
Commits (61)
350da26
Add Bfloat16
ZhennanQin May 20, 2019
b39d3e1
mshadow support bf16
ElaineBao Oct 31, 2019
9c44b66
rebase bf16 mkldnn1.0
ElaineBao Nov 5, 2019
da7118c
support bf16 gemm
xinyu-intel Nov 12, 2019
c220bfc
resolve fp32 ip bwd bug
xinyu-intel Nov 18, 2019
f1055e5
add other bf16 ops
ElaineBao Nov 20, 2019
b1f8b94
change func name from fp16 to lp16 (low precision 16), to include bf16
ElaineBao Nov 20, 2019
ac3edaa
add amp_cast bf16 support for ndarray
ElaineBao Nov 21, 2019
16dd430
fix executor copy_params
ElaineBao Nov 27, 2019
99e31e9
add test case for bf16
ElaineBao Nov 27, 2019
77a6a6f
remove numpy dtype hook for bf16
ElaineBao Nov 28, 2019
22b70af
add bf16 type support
ElaineBao Dec 9, 2019
e9dd678
rebase to mxnet master
ElaineBao Dec 9, 2019
97372fd
add single conv test
ElaineBao Dec 10, 2019
c4aec63
fix symbolic inference
ElaineBao Dec 10, 2019
7a7ab3a
add dtype check when copy
ElaineBao Dec 11, 2019
e3fced6
add single conv and bn test
ElaineBao Dec 11, 2019
9b68d07
skip fp16 amp_cast test in cpu
ElaineBao Dec 16, 2019
89eaba1
Fix resnet50 first convolution
ZhennanQin Dec 17, 2019
2cb8969
Skip first convolution for bfloat16
ZhennanQin Dec 20, 2019
db28e3a
support bf16 fallback compute
rongzha1 Dec 23, 2019
e6d2b69
recover origin test
rongzha1 Dec 23, 2019
704ac96
fix bf16 bn test, enhance assert_almost_equal_with_err
ElaineBao Dec 23, 2019
f931fcb
add some bf16 unittests
wuxun-zhang Dec 19, 2019
1103bb7
using assert_almost_equal_with_err for fallback bn test
rongzha1 Dec 25, 2019
98f8b07
add relu6 bf16 support
ElaineBao Dec 29, 2019
bdb2483
fix lint
rongzha1 Jan 3, 2020
b757cc4
fix subgraph conv with data=0
ElaineBao Jan 6, 2020
b27dec5
mkldnn doesn't support 0 dim tensor
rongzha1 Jan 7, 2020
f37b3fe
rm dtype check when copy
rongzha1 Jan 7, 2020
b3e5deb
using bf16 tvm
rongzha1 Jan 8, 2020
d88bc3e
rm bf16 mnist demo
rongzha1 Jan 10, 2020
bf6c727
use official tvm
rongzha1 Jan 10, 2020
a8eab98
change function name; fix lint error
rongzha1 Jan 14, 2020
f46fcdc
fix clang check error:conditional expression is ambiguous; 'float' ca…
rongzha1 Jan 14, 2020
e39e3a6
nvcc compiler build pass
rongzha1 Jan 15, 2020
b554998
fix gpu amp cast symbol error
rongzha1 Jan 16, 2020
7db4f61
fix mnist training error
rongzha1 Jan 17, 2020
c04fe99
fix cpp test: Engine.VarVersion error
rongzha1 Jan 17, 2020
b3306c9
workaround cpp failed test mkldnn fc bwd
rongzha1 Jan 19, 2020
06a32fe
to fix mkldnn test_mkldnn_ndarray_slice error
rongzha1 Jan 20, 2020
1b090b5
1. move some code from to np_broadcast_reduce_op_value.cc to np_broad…
rongzha1 Jan 21, 2020
44f133b
use official dlpack
rongzha1 Jan 21, 2020
04a1402
rename np_broadcast_reduce_op_value_part2.cc and add some description
rongzha1 Jan 21, 2020
afe1c09
1. update dlpack url in .gitmodule
rongzha1 Jan 23, 2020
1027985
fix remaining NodePtr due to tvm update
rongzha1 Feb 10, 2020
0d83536
mv some code from mxnet_op.h to mxnet_op_kernel_assign.h to avoid WIN…
rongzha1 Feb 10, 2020
592e67f
fix WIN CPU build fail:compiler is out of heap space in pass 2
rongzha1 Feb 10, 2020
e934607
fix WIN build fail
rongzha1 Feb 11, 2020
ce35638
fix lint
rongzha1 Feb 11, 2020
4dfd91b
add print for test bf16_concat
rongzha1 Feb 12, 2020
405e2aa
fix bf16 test fail
rongzha1 Feb 12, 2020
e0246ae
disable bf16 concat test
rongzha1 Feb 12, 2020
043d315
tmp skip to root cause edge test halt
rongzha1 Feb 13, 2020
052bf79
fix bf16_bn test error
rongzha1 Feb 13, 2020
1880d95
enable test_bulk
rongzha1 Feb 14, 2020
5edd735
tmp rm bf16 to locate edge error
rongzha1 Feb 14, 2020
3bad0fe
Revert "tmp rm bf16 to locate edge error"
rongzha1 Feb 15, 2020
4a43b1d
add Apache license header
rongzha1 Feb 15, 2020
9a37a0a
trigger CI
rongzha1 Feb 15, 2020
df090c1
add robust for test bf16 bn
rongzha1 Feb 15, 2020
Changes from 1 commit: change func name from fp16 to lp16 (low precision 16), to include bf16
ElaineBao authored and rongzha1 committed Feb 15, 2020
commit b1f8b945d0c3b353957d266f2c4f7901cbbaafda
44 changes: 22 additions & 22 deletions python/mxnet/contrib/amp/amp.py
@@ -18,8 +18,8 @@
# coding: utf-8
"""Functions for enabling AMP (automatic mixed precision)."""
__all__ = ['init', 'init_trainer', 'scale_loss', 'unscale', 'convert_model',
- 'convert_hybrid_block', 'list_fp16_ops', 'list_fp32_ops',
- 'list_fp16_fp32_ops', 'list_conditional_fp32_ops',
+ 'convert_hybrid_block', 'list_lp16_ops', 'list_fp32_ops',
+ 'list_lp16_fp32_ops', 'list_conditional_fp32_ops',
'list_widest_type_cast', 'list_loss_output_functions',
'convert_symbol']

@@ -162,7 +162,7 @@ def _new_fun(*args, **kwargs):
getattr(module, op_name_prefix[1:-1])

wrap_list = target_precision_ops if target_precision_ops is not None \
- else list_fp16_ops(target_dtype)
+ else list_lp16_ops(target_dtype)
for fun_name in wrap_list:
try:
fun_name, cur_module = _get_fun_to_wrap(fun_name, module, submodule_dict)
@@ -388,11 +388,11 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
list of values of the parameter that make the operator to be casted to FP32)
excluded_sym_names : list of strs, optional
A list of strings that represent the names of symbols that users want to exclude
- from being casted to FP16 or FP32.
+ from being casted to LP16 or FP32.
data_names : list of strs, optional
A list of strings that represent input data tensor names to the model
cast_optional_params : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in FP16
+ Whether to cast the arg_params and aux_params that don't require to be in LP16
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
@@ -404,7 +404,7 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
if target_dtype_ops is not None:
assert isinstance(target_dtype_ops, list), "target_dtype_ops should be a list of strs"
else:
- target_dtype_ops = list_fp16_ops(target_dtype)
+ target_dtype_ops = list_lp16_ops(target_dtype)

if fp32_ops is not None:
assert isinstance(fp32_ops, list), "fp32_ops should be a list of strs"
@@ -450,15 +450,15 @@ def convert_symbol(sym, target_dtype="float16", target_dtype_ops=None,
"Common ops in fp32_ops and conditional_fp32_ops {}".format(common_ops)

combined_ops = set(target_dtype_ops + fp32_ops + conditional_op_names)
- all_fp16_fp32_ops = set(list_fp16_ops(target_dtype) + list_fp32_ops(target_dtype)
- + list_fp16_fp32_ops(target_dtype) + original_conditional_op_names)
+ all_lp16_fp32_ops = set(list_lp16_ops(target_dtype) + list_fp32_ops(target_dtype)
+ + list_lp16_fp32_ops(target_dtype) + original_conditional_op_names)

- illegal_ops = combined_ops - all_fp16_fp32_ops
+ illegal_ops = combined_ops - all_lp16_fp32_ops
assert not illegal_ops, '''Can only choose ops from one of the three lists
- for fp16_ops and fp32_ops
- 1. amp.list_fp16_ops(target_dtype)
+ for lp16_ops and fp32_ops
+ 1. amp.list_lp16_ops(target_dtype)
2. amp.list_fp32_ops(target_dtype)
- 3. amp.list_fp16_fp32_ops(target_dtype)
+ 3. amp.list_lp16_fp32_ops(target_dtype)
4. amp.list_conditional_fp32_ops(target_dtype)
Op %s not in any of them''' % (illegal_ops)

@@ -550,7 +550,7 @@ def convert_model(sym, arg_params, aux_params, target_dtype="float16", target_dt
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_optional_params : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in FP16
+ Whether to cast the arg_params and aux_params that don't require to be in LP16
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
@@ -607,7 +607,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
block : HybridBlock or SymbolBlock object
FP32 HybridBlock or SymbolBlock object
target_dtype : str or numpy
- currently only supports fp16. The target dtype indicates to add cast layers
+ currently only supports lp16. The target dtype indicates to add cast layers
when possible so that lower precision computation can be leveraged.
target_precision_ops : list of strs
Override the list of operator names casted to target_dtype.
@@ -623,7 +623,7 @@ def convert_hybrid_block(block, target_dtype="float16", target_dtype_ops=None,
ctx : Context
Context on which model parameters should live
cast_optional_params : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in FP16
+ Whether to cast the arg_params and aux_params that don't require to be in LP16
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
@@ -709,7 +709,7 @@ def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype
A list of strings that represent the names of symbols that users want to exclude
from being executed in lower precision.
cast_optional_params : bool, default False
- Whether to cast the arg_params and aux_params that don't require to be in FP16
+ Whether to cast the arg_params and aux_params that don't require to be in LP16
because of a cast layer following it, but will reduce the computation and memory
overhead of the model if casted.
"""
@@ -744,13 +744,13 @@ def convert_bucketing_module(bucketing_mod, target_dtype="float16", target_dtype
compression_params=bucketing_mod._compression_params)
return result_mod

- def list_fp16_ops(target_dtype):
- """Get the default list of FP16 ops for AMP
+ def list_lp16_ops(target_dtype):
+ """Get the default list of LP16 ops for AMP
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.FP16_FUNCS
elif target_dtype == bfloat16:
- return lists.symbol_bf16.FP16_FUNCS
+ return lists.symbol_bf16.BF16_FUNCS

def list_fp32_ops(target_dtype):
"""Get the default list of FP32 ops for AMP
@@ -760,13 +760,13 @@ def list_fp32_ops(target_dtype):
elif target_dtype in [bfloat16]:
return lists.symbol_bf16.FP32_FUNCS

- def list_fp16_fp32_ops(target_dtype):
- """Get the default list of ops which run in both FP16 and FP32
+ def list_lp16_fp32_ops(target_dtype):
+ """Get the default list of ops which run in both LP16 and FP32
"""
if target_dtype in ['float16', np.float16]:
return lists.symbol_fp16.FP16_FP32_FUNCS
elif target_dtype in [bfloat16]:
- return lists.symbol_bf16.FP16_FP32_FUNCS
+ return lists.symbol_bf16.BF16_FP32_FUNCS

def list_conditional_fp32_ops(target_dtype):
"""Get the conditional fp32 ops list
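For orientation, the renamed helpers keep the behaviour of the old fp16-only functions but now dispatch on target_dtype. The snippet below is only a sketch of the API visible in this diff: it assumes the helpers are exported from mxnet.contrib.amp, and it uses a structurally equivalent stand-in for the bfloat16 dtype object, whose real definition lives elsewhere in this PR.

```python
# Illustrative sketch of the renamed helpers (not an official example).
# Assumptions: helpers are exported from mxnet.contrib.amp; the stand-in dtype
# below compares equal to the bfloat16 object referenced in amp.py.
import numpy as np
from mxnet.contrib import amp

bfloat16 = np.dtype([('bfloat16', np.uint16)])  # stand-in for the dtype used above

print(amp.list_lp16_ops('float16'))      # formerly list_fp16_ops -> symbol_fp16.FP16_FUNCS
print(amp.list_lp16_ops(bfloat16))       # bf16 whitelist         -> symbol_bf16.BF16_FUNCS
print(amp.list_lp16_fp32_ops(bfloat16))  # dtype-neutral ops      -> symbol_bf16.BF16_FP32_FUNCS
```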
6 changes: 3 additions & 3 deletions python/mxnet/contrib/amp/lists/symbol_bf16.py
@@ -19,16 +19,16 @@
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""

# Functions that should be cast to lower precision
- FP16_FUNCS = [
+ BF16_FUNCS = [
'Convolution',
'FullyConnected',
]

# Functions that should not be casted, either because
# they are irrelevant (not used in the network itself
# like image transformations or optimizers) or they
- # are dtype neutral (can work in both fp16 and fp32)
- FP16_FP32_FUNCS = [
+ # are dtype neutral (can work in both bf16 and fp32)
+ BF16_FP32_FUNCS = [
'abs',
'_add',
'BatchNorm',
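BF16_FUNCS and BF16_FP32_FUNCS are the bf16 counterparts of symbol_fp16's FP16_FUNCS and FP16_FP32_FUNCS, and they are what convert_symbol falls back to when no explicit op list is given. A hedged sketch of overriding the default whitelist via the target_dtype_ops parameter shown in the amp.py diff follows; the symbol construction and the bfloat16 stand-in are illustrative only.

```python
# Sketch: cast only Convolution to bfloat16 using the convert_symbol signature
# shown above. The bfloat16 stand-in mirrors the dtype compared against
# target_dtype in amp.py; its real definition lives elsewhere in this PR.
import numpy as np
import mxnet as mx
from mxnet.contrib import amp

bfloat16 = np.dtype([('bfloat16', np.uint16)])  # stand-in, see note above

data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, kernel=(3, 3), num_filter=8, name='conv0')
fc = mx.sym.FullyConnected(data=conv, num_hidden=10, name='fc0')

# Only ops listed in target_dtype_ops are cast; fc0 stays in FP32 here.
bf16_sym = amp.convert_symbol(fc,
                              target_dtype=bfloat16,
                              target_dtype_ops=['Convolution'])
```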
18 changes: 9 additions & 9 deletions tests/python/gpu/test_contrib_amp.py
@@ -37,22 +37,22 @@
set_default_context(mx.gpu(0))

def test_amp_coverage():
- conditional = [item[0] for item in amp.lists.symbol.CONDITIONAL_FP32_FUNCS]
+ conditional = [item[0] for item in amp.lists.symbol_fp16.CONDITIONAL_FP32_FUNCS]

# Check for duplicates
- for a in [amp.lists.symbol.FP16_FUNCS,
- amp.lists.symbol.FP16_FP32_FUNCS,
- amp.lists.symbol.FP32_FUNCS,
- amp.lists.symbol.WIDEST_TYPE_CASTS,
+ for a in [amp.lists.symbol_fp16.FP16_FUNCS,
+ amp.lists.symbol_fp16.FP16_FP32_FUNCS,
+ amp.lists.symbol_fp16.FP32_FUNCS,
+ amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
conditional]:
ret = [item for item, count in collections.Counter(a).items() if count > 1]
assert ret == [], "Elements " + str(ret) + " are duplicated in the AMP lists."

t = []
- for a in [amp.lists.symbol.FP16_FUNCS,
- amp.lists.symbol.FP16_FP32_FUNCS,
- amp.lists.symbol.FP32_FUNCS,
- amp.lists.symbol.WIDEST_TYPE_CASTS,
+ for a in [amp.lists.symbol_fp16.FP16_FUNCS,
+ amp.lists.symbol_fp16.FP16_FP32_FUNCS,
+ amp.lists.symbol_fp16.FP32_FUNCS,
+ amp.lists.symbol_fp16.WIDEST_TYPE_CASTS,
conditional]:
t += a
ret = [item for item, count in collections.Counter(t).items() if count > 1]
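A natural bf16 analogue of this duplicate check would iterate over the renamed bf16 lists. The sketch below is not part of this commit; it uses only list names visible in this diff and assumes symbol_bf16 is exposed under amp.lists the same way symbol_fp16 is, while leaving out WIDEST_TYPE_CASTS and CONDITIONAL_FP32_FUNCS because the diff does not show bf16 versions of them.

```python
# Hypothetical bf16 version of the duplicate check above (illustrative only).
# Assumption: amp.lists.symbol_bf16 is importable like amp.lists.symbol_fp16.
import collections
from mxnet.contrib import amp

for ops in [amp.lists.symbol_bf16.BF16_FUNCS,
            amp.lists.symbol_bf16.BF16_FP32_FUNCS,
            amp.lists.symbol_bf16.FP32_FUNCS]:
    dup = [op for op, n in collections.Counter(ops).items() if n > 1]
    assert dup == [], "Elements " + str(dup) + " are duplicated in the bf16 AMP lists."
```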