Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Improve bf16 support (#21002)
Browse files Browse the repository at this point in the history
* AMP improvements + enable bf16 input for quantize_v2

* Fix sanity

* Improve tests, AMP conversion interface, fix forward hooks

* Fix tests

* Fix imports in tests

* Use different lp16_fp32 op in test

* Add amp.disable_amp() context, fix tests

* Add tests, generalize optimization disabling

* Fix sanity

* Review fixes

* Use is_integral<>::value

* Review fixes
Change flag type to unsigned int
Add a warning for an incorrect flag attribute value

* Extend bf16 support

* Combine enable_float_output and amp_out_dtype parameters

* Add bf16 support to _dnnl_batch_dot

* Fix sanity

* Add bf16 support to all dnnl ops, add tests

* Add license

* Fix conv activation fuse, disable masked_softmax bf16 support

* Fix sanity, add softmax test cases

* Compare bf16 outputs with fp32 reference

Co-authored-by: Bartlomiej Gawrych <[email protected]>
  • Loading branch information
PawelGlomski-Intel and Bartlomiej Gawrych committed Jul 15, 2022
1 parent ded6096 commit f6d1ed1
Show file tree
Hide file tree
Showing 30 changed files with 1,152 additions and 352 deletions.
20 changes: 19 additions & 1 deletion 3rdparty/mshadow/mshadow/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ struct DataType<half::half_t> {
#endif
#endif
};
template<>
template <>
struct DataType<bfloat::bf16_t> {
static const int kFlag = kBfloat16;
static const int kLanes = 1;
Expand Down Expand Up @@ -769,6 +769,10 @@ namespace isnan_typed {
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
}
/*! \brief NaN test for bfloat16.
 *
 * Masks off the sign bit and compares the remaining magnitude bits against the
 * exponent mask (0x7f80): a value is NaN exactly when the exponent field is all
 * ones AND the mantissa is non-zero, i.e. the magnitude bits are strictly
 * greater than the exponent mask.
 */
template <>
MSHADOW_XINLINE bool IsNan(volatile mshadow::bfloat::bf16_t val) {
return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) > MSHADOW_BF16_EXPONENT_BITS;
}
} // namespace isnan_typed

/*! \brief
Expand All @@ -795,6 +799,10 @@ namespace isinf_typed {
MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
}
/*! \brief Infinity test for bfloat16 (matches both +inf and -inf).
 *
 * Masks off the sign bit and compares against the exponent mask (0x7f80): a
 * value is infinite exactly when the exponent field is all ones AND the
 * mantissa is zero, i.e. the magnitude bits equal the exponent mask.
 */
template <>
MSHADOW_XINLINE bool IsInf(volatile mshadow::bfloat::bf16_t val) {
return (val.bf16_ & (~MSHADOW_BF16_SIGN_BIT)) == MSHADOW_BF16_EXPONENT_BITS;
}
} // namespace isinf_typed

/*! \brief namespace for potential reducer operations */
Expand Down Expand Up @@ -881,6 +889,11 @@ MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
return half::half_t::Binary(
MSHADOW_HALF_SIGN_BIT | MSHADOW_HALF_EXPONENT_BITS);
}
/*! \brief negative infinity value of bfloat16 */
template <>
MSHADOW_XINLINE bfloat::bf16_t NegInfValue<bfloat::bf16_t>(void) {
  // -inf bit pattern: fully-set exponent field combined with the sign bit.
  return bfloat::bf16_t::Binary(
      MSHADOW_BF16_EXPONENT_BITS | MSHADOW_BF16_SIGN_BIT);
}

/*!
* \brief maximum value of certain types
Expand Down Expand Up @@ -962,6 +975,11 @@ template<>
MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
}
/*! \brief positive infinity value of bfloat16
 *
 * +inf bit pattern: sign bit clear, exponent field all ones (0x7f80),
 * mantissa zero.
 */
template <>
MSHADOW_XINLINE bfloat::bf16_t PosInfValue<bfloat::bf16_t>(void) {
return bfloat::bf16_t::Binary(MSHADOW_BF16_EXPONENT_BITS);
}

} // namespace limits

Expand Down
2 changes: 2 additions & 0 deletions 3rdparty/mshadow/mshadow/bfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,8 @@ MSHADOW_BF16_OPERATOR(bool, <=)

// Extreme finite values of bfloat16.
// NOTE(review): these two macros expand with a trailing ';', so they are only
// safe where an extra empty statement is harmless (e.g. `return MACRO;`);
// confirm all call sites before using them inside a larger expression.
#define MSHADOW_BF16_MIN mshadow::bfloat::bf16_t::Binary(0xFF7F);
#define MSHADOW_BF16_MAX mshadow::bfloat::bf16_t::Binary(0x7F7F);
// Bit masks over the raw 16-bit pattern: 1 sign bit, then 8 exponent bits.
#define MSHADOW_BF16_SIGN_BIT 0x8000
#define MSHADOW_BF16_EXPONENT_BITS 0x7f80
} // namespace bfloat
} // namespace mshadow
#endif // MSHADOW_BFLOAT_H_
1 change: 1 addition & 0 deletions include/mxnet/imperative.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ enum class OptConstraint : unsigned int {
DisableAMP = 1 << 0
// DisableQuantization = 1 << 1
};
using OptConstraint_int_t = std::underlying_type_t<OptConstraint>;

/*! \brief there are three numpy shape flags based on priority.
* GlobalOn
Expand Down
Loading

0 comments on commit f6d1ed1

Please sign in to comment.