Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] Fix floor divide #21096

Merged
merged 4 commits into from
Jul 18, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Made separate version for half_t with division made on float
  • Loading branch information
Kacper-Pietkun committed Jul 14, 2022
commit 099dca44c0e906bef7adcacd276583222bc2ea8f
15 changes: 10 additions & 5 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,10 @@ struct rtrue_divide : public mxnet_op::tunable {
/***** floor_divide ******/

struct floor_divide : public mxnet_op::tunable {
template <typename DType,
typename std::enable_if<!std::is_same<DType, bool>::value &&
(std::is_integral<DType>::value ||
std::is_same<DType, mshadow::half::half_t>::value),
int>::type = 0>
template <
typename DType,
typename std::enable_if<!std::is_same<DType, bool>::value && std::is_integral<DType>::value,
int>::type = 0>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
return static_cast<DType>(::floor(static_cast<double>(a) / static_cast<double>(b)));
}
Expand All @@ -249,6 +248,12 @@ struct floor_divide : public mxnet_op::tunable {
return static_cast<bool>(::floor(a / b));
}

MSHADOW_XINLINE static mshadow::half::half_t Map(mshadow::half::half_t a,
mshadow::half::half_t b) {
return static_cast<mshadow::half::half_t>(
::floor(static_cast<float>(a) / static_cast<float>(b)));
}

template <typename DType,
typename std::enable_if<!std::is_integral<DType>::value &&
!std::is_same<DType, float>::value &&
Expand Down