Skip to content

Commit

Permalink
Fixed radix sort with custom float type (#551)
Browse files Browse the repository at this point in the history
* fix(test): custom_float_type test for downstream projects reflects their usage pattern

* fix(radix sort): Build fix for custom floating point types used in downstream projects
  • Loading branch information
mfep committed Apr 16, 2024
1 parent bc3d0b0 commit 6c6328b
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ class segmented_warp_sort_helper<
storage_type& storage,
unsigned int begin_bit,
unsigned int end_bit)
-> std::enable_if_t<is_floating_point<K>::value>
-> std::enable_if_t<!is_integral<K>::value>
{
(void)begin_bit;
(void)end_bit;
Expand All @@ -597,7 +597,7 @@ class segmented_warp_sort_helper<
storage_type& storage,
unsigned int begin_bit,
unsigned int end_bit)
-> std::enable_if_t<!is_floating_point<K>::value>
-> std::enable_if_t<is_integral<K>::value>
{
if(begin_bit == 0 && end_bit == 8 * sizeof(key_type))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ auto invoke_merge_sort_block_merge(
bool debug_synchronous,
typename std::iterator_traits<KeysIterator>::value_type* keys_buffer,
typename std::iterator_traits<ValuesIterator>::value_type* values_buffer)
-> std::enable_if_t<
!is_floating_point<typename std::iterator_traits<KeysIterator>::value_type>::value,
hipError_t>
-> std::enable_if_t<is_integral<typename std::iterator_traits<KeysIterator>::value_type>::value,
hipError_t>
{
using key_type = typename std::iterator_traits<KeysIterator>::value_type;
(void)decomposer;
Expand Down Expand Up @@ -101,7 +100,7 @@ auto invoke_merge_sort_block_merge(
typename std::iterator_traits<KeysIterator>::value_type* keys_buffer,
typename std::iterator_traits<ValuesIterator>::value_type* values_buffer)
-> std::enable_if_t<
is_floating_point<typename std::iterator_traits<KeysIterator>::value_type>::value,
!is_integral<typename std::iterator_traits<KeysIterator>::value_type>::value,
hipError_t>
{
using key_type = typename std::iterator_traits<KeysIterator>::value_type;
Expand Down
3 changes: 2 additions & 1 deletion test/rocprim/test_device_radix_sort.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ auto generate_key_input(KeyIter keys_input, size_t size, engine_type& rng_engine
// Working around custom_float_test_type, which is both a float and a custom_test_type
template<class T>
constexpr bool is_custom_not_float_test_type
= test_utils::is_custom_test_type<T>::value && !rocprim::is_floating_point<T>::value;
= test_utils::is_custom_test_type<T>::value
&& !std::is_same<test_utils::custom_float_type, T>::value;

template<class Config, bool Descending, class Key>
auto invoke_sort_keys(void* d_temporary_storage,
Expand Down
8 changes: 4 additions & 4 deletions test/rocprim/test_utils_custom_float_type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,6 @@ struct inner_type<custom_float_type>
namespace rocprim
{

template<>
struct is_floating_point<test_utils::custom_float_type> : std::true_type
{};

namespace detail
{

Expand All @@ -131,6 +127,10 @@ struct radix_key_codec_base<test_utils::custom_float_type>
: radix_key_codec_floating<test_utils::custom_float_type, unsigned int>
{};

static_assert(!is_floating_point<test_utils::custom_float_type>::value,
"custom_float_type must not be rocprim::is_floating_point, "
"since that is how downstream libraries use it.");

} // namespace detail
} // namespace rocprim

Expand Down

0 comments on commit 6c6328b

Please sign in to comment.