-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add transpose & repeat kernel for c++ for opencl (#258)
* add transpose kernel for c++ for opencl * add repeat kernel for c++ for opencl * disable utl::static_vector copy constructor on opencl * fix transpose & repeat view for c++ for opencl * update c++ for opencl tests * update repeat tests * update opencl copy_buffer to handle numeric type * init size_type specification for index functions * add test data for repeat and transpose indexing * fix repeat and transpose opencl kernels * add cast index function * update repeat and transpose view for opencl * update opencl tester to cast index array * update tests * fix opencl index tests
- Loading branch information
Showing
31 changed files
with
6,219 additions
and
5,179 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
set(CMAKE_C_COMPILER clang) | ||
set(CMAKE_CXX_COMPILER clang++) | ||
|
||
add_compile_options(-W -Wall -Werror -Wextra -Wno-gnu-string-literal-operator-template) | ||
add_compile_options(-W -Wall -Werror -Wextra -Wno-gnu-string-literal-operator-template -Wno-deprecated-declarations) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
#ifndef NMTOOLS_ARRAY_EVAL_OPENCL_KERNELS_TRANSPOSE_HPP | ||
#define NMTOOLS_ARRAY_EVAL_OPENCL_KERNELS_TRANSPOSE_HPP | ||
|
||
#include "nmtools/array/ndarray.hpp" | ||
#include "nmtools/array/view/ref.hpp" | ||
#include "nmtools/array/view/repeat.hpp" | ||
#include "nmtools/array/view/mutable_ref.hpp" | ||
#include "nmtools/array/eval/kernel_helper.hpp" | ||
#include "nmtools/array/eval/opencl/kernel_helper.hpp" | ||
#include "nmtools/array/index/cast.hpp" | ||
|
||
#ifndef nm_stringify | ||
#define nm_stringify(a) #a | ||
#endif | ||
|
||
#define nmtools_cl_kernel_name(out_type,inp_type) repeat##_##out_type##_##inp_type | ||
#define nmtools_cl_kernel_name_str(out_type,inp_type) nm_stringify(repeat##_##out_type##_##inp_type) | ||
|
||
#ifdef NMTOOLS_OPENCL_BUILD_KERNELS | ||
|
||
namespace nm = nmtools; | ||
namespace na = nmtools::array; | ||
namespace view = nmtools::view; | ||
namespace meta = nmtools::meta; | ||
namespace opencl = nmtools::array::opencl; | ||
namespace detail = nmtools::view::detail; | ||
|
||
#define nmtools_cl_kernel(out_type, inp_type) \ | ||
kernel void nmtools_cl_kernel_name(out_type,inp_type) \ | ||
( global out_type* out_ptr \ | ||
, global const inp_type* inp_ptr \ | ||
, global const nm_cl_index_t* out_shape_ptr \ | ||
, global const nm_cl_index_t* inp_shape_ptr \ | ||
, global const nm_cl_index_t* repeats_ptr \ | ||
, const nm_cl_index_t out_dim \ | ||
, const nm_cl_index_t inp_dim \ | ||
, const nm_cl_index_t repeats_size \ | ||
, const nm_cl_index_t axis \ | ||
) \ | ||
{ \ | ||
auto repeats = na::create_vector(repeats_ptr,repeats_size); \ | ||
auto input = na::create_array(inp_ptr,inp_shape_ptr,inp_dim); \ | ||
auto output = na::create_mutable_array(out_ptr,out_shape_ptr,out_dim); \ | ||
auto repeated = view::repeat(input,repeats,axis); \ | ||
opencl::assign_array(output,repeated); \ | ||
} | ||
|
||
nmtools_cl_kernel(float,float) | ||
nmtools_cl_kernel(double,double) | ||
|
||
#else // NMTOOLS_OPENCL_BUILD_KERNELS | ||
|
||
#include "nmtools/array/eval/opencl/context.hpp" | ||
#include <cstring> // memcpy | ||
|
||
extern unsigned char nm_cl_repeat_spv[]; | ||
extern unsigned int nm_cl_repeat_spv_len; | ||
|
||
namespace nmtools::array::opencl | ||
{ | ||
template <typename...args_t> | ||
struct kernel_t< | ||
view::decorator_t<view::repeat_t,args_t...> | ||
> { | ||
using view_t = view::decorator_t<view::repeat_t,args_t...>; | ||
|
||
view_t view; | ||
std::shared_ptr<context_t> context; | ||
|
||
static auto get_spirv() | ||
{ | ||
using vector = nmtools_list<unsigned char>; | ||
auto spirv = vector(); | ||
spirv.resize(nm_cl_repeat_spv_len); | ||
memcpy(spirv.data(),nm_cl_repeat_spv,sizeof(unsigned char) * nm_cl_repeat_spv_len); | ||
return spirv; | ||
} | ||
|
||
template <typename inp_t, typename out_t=inp_t> | ||
static auto kernel_name() | ||
{ | ||
if constexpr (meta::is_same_v<inp_t,float> && meta::is_same_v<out_t,float>) { | ||
return nmtools_cl_kernel_name_str(float,float); | ||
} else if constexpr (meta::is_same_v<inp_t,double> && meta::is_same_v<out_t,double>) { | ||
return nmtools_cl_kernel_name_str(double,double); | ||
} | ||
} | ||
|
||
template <typename output_t> | ||
auto eval(output_t& output) | ||
{ | ||
using out_t = meta::get_element_type_t<output_t>; | ||
|
||
const auto& inp_array = *get_array(view); | ||
using inp_t = meta::get_element_type_t<meta::remove_cvref_pointer_t<decltype(inp_array)>>; | ||
|
||
auto inp_buffer = context->create_buffer(inp_array); | ||
auto out_buffer = context->create_buffer<out_t>(nmtools::size(output)); | ||
|
||
uint32_t repeats_size = nmtools::len(view.repeats); | ||
uint32_t axis = view.axis; | ||
|
||
auto kernel_name = this->kernel_name<inp_t,out_t>(); | ||
|
||
if (!context->has_kernel(kernel_name)) { | ||
context->create_kernel(get_spirv(),kernel_name); | ||
} | ||
|
||
auto kernel = context->get_kernel(kernel_name); | ||
|
||
auto out_size = nmtools::size(output); | ||
[[maybe_unused]] auto inp_size = nmtools::size(inp_array); | ||
[[maybe_unused]] auto dst_size = nmtools::size(view); | ||
|
||
auto out_shape = nmtools::shape(output); | ||
auto inp_shape = nmtools::shape(inp_array); | ||
|
||
auto out_shape_buffer = context->create_buffer(index::cast<nm_cl_index_t>(out_shape)); | ||
auto inp_shape_buffer = context->create_buffer(index::cast<nm_cl_index_t>(inp_shape)); | ||
auto repeats_buffer = context->create_buffer(index::cast<nm_cl_index_t>(view.repeats)); | ||
|
||
uint32_t out_dim = nmtools::len(out_shape); | ||
uint32_t inp_dim = nmtools::len(inp_shape); | ||
|
||
auto kernel_info = kernel.kernel_info_; | ||
auto local_size = nmtools_array{kernel_info->preferred_work_group_size_multiple}; | ||
auto global_size = nmtools_array{size_t(std::ceil(float(out_size) / local_size[0])) * local_size[0]}; | ||
|
||
auto default_args = nmtools_tuple{out_buffer,inp_buffer,out_shape_buffer,inp_shape_buffer,repeats_buffer,index::cast<nm_cl_index_t>(out_dim),index::cast<nm_cl_index_t>(inp_dim),index::cast<nm_cl_index_t>(repeats_size),index::cast<nm_cl_index_t>(axis)}; | ||
|
||
context->set_args(kernel,default_args); | ||
context->run(kernel,out_buffer,output,global_size,local_size); | ||
} | ||
}; | ||
} | ||
|
||
#endif // NMTOOLS_OPENCL_BUILD_KERNELS | ||
|
||
#undef nmtools_cl_kernel_bin | ||
#undef nmtools_cl_kernel_len | ||
#undef nmtools_cl_ufunc_name | ||
#undef nmtools_cl_ufunc_type | ||
#undef nmtools_cl_kernel_name | ||
#undef nmtools_cl_kernel_name_str | ||
|
||
#endif // NMTOOLS_ARRAY_EVAL_OPENCL_KERNELS_TRANSPOSE_HPP |
Oops, something went wrong.