Skip to content

Commit

Permalink
Support outer simd on ndarray (#253)
Browse files Browse the repository at this point in the history
* fix outer simd padding evaluation

* update outer simd indexing to support ndarray

* testing logging using newline after label instead of tab

* add outer simd ndarray tests for sse and avx

* add outer simd ndarray tests for simde avx512

* add outer simd ndarray tests for gcc vector extensions
  • Loading branch information
alifahrri committed Oct 26, 2023
1 parent fb43d8a commit 223e859
Show file tree
Hide file tree
Showing 19 changed files with 6,698 additions and 23 deletions.
3 changes: 2 additions & 1 deletion include/nmtools/array/eval/simd/evaluator/ufunc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ namespace nmtools::array
if (static_cast<int>(out_tag) == n_pad) {
auto n_pad = static_cast<int>(out_tag);
for (size_t i=0; i<(n_simd_pack - n_pad); i++) {
auto lhs = lhs_data_ptr[lhs_offset+i];
// lhs is always broadcasted
auto lhs = lhs_data_ptr[lhs_offset];
auto rhs = rhs_data_ptr[rhs_offset+i];
out_data_ptr[out_offset+i] = view.op(lhs,rhs);
}
Expand Down
119 changes: 100 additions & 19 deletions include/nmtools/array/eval/simd/index/ufunc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#include "nmtools/array/index/product.hpp"
#include "nmtools/array/shape.hpp"
#include "nmtools/array/eval/simd/index/common.hpp"
#include "nmtools/array/index/compute_offset.hpp"
#include "nmtools/array/index/compute_strides.hpp"
#include "nmtools/array/index/compute_indices.hpp"

namespace nmtools::index
{
Expand Down Expand Up @@ -266,8 +269,10 @@ namespace nmtools::index
}

template <auto N_ELEM_PACK, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t>
constexpr auto outer_simd_shape(meta::as_type<N_ELEM_PACK>, const out_shape_t& out_shape, const lhs_shape_t& lhs_shape, const rhs_shape_t&)
constexpr auto outer_simd_shape(meta::as_type<N_ELEM_PACK>, const out_shape_t& out_shape, const lhs_shape_t& lhs_shape, const rhs_shape_t& rhs_shape)
{
// TODO: compile-time inference
// same dim as out
using result_t = out_shape_t;
auto result = result_t {};

Expand All @@ -279,33 +284,89 @@ namespace nmtools::index
const auto n_ops = at(out_shape,meta::ct_v<-1>);
const auto n_packed_ops = n_ops / N_ELEM_PACK;

// assume lhs is 1D, rhs is 1D, and out is 2D
at(result,meta::ct_v<0>) = at(lhs_shape,meta::ct_v<0>);
at(result,meta::ct_v<1>) = n_packed_ops + (n_ops % N_ELEM_PACK ? 1 : 0);
auto lhs_dim = len(lhs_shape);
for (size_t i=0; i<lhs_dim; i++) {
at(result,i) = at(lhs_shape,i);
}
auto rhs_dim = len(rhs_shape);
for (size_t i=0; i<rhs_dim; i++) {
at(result,lhs_dim+i) = at(rhs_shape,i);
}
at(result,meta::ct_v<-1>) = n_packed_ops + (n_ops % N_ELEM_PACK ? 1 : 0);

return result;
} // outer_simd_shape

template <typename index_t=size_t, auto N_ELEM_PACK, typename simd_index_t, typename simd_shape_t, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t>
constexpr auto outer_simd(meta::as_type<N_ELEM_PACK>, const simd_index_t& simd_index, const simd_shape_t&, const out_shape_t& out_shape, const lhs_shape_t&, const rhs_shape_t&)
template <typename index_t=size_t, auto N_ELEM_PACK, typename simd_index_t, typename simd_shape_t, typename out_strides_t, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t, typename lhs_strides_t, typename rhs_strides_t>
constexpr auto outer_simd(meta::as_type<N_ELEM_PACK>, const simd_index_t& simd_index, const simd_shape_t&, const out_strides_t& out_strides, const out_shape_t& out_shape, const lhs_shape_t& lhs_shape, const rhs_shape_t& rhs_shape, const lhs_strides_t& lhs_strides, const rhs_strides_t& rhs_strides)
{
using tagged_index_t = nmtools_tuple<SIMD,index_t>;
using result_t = nmtools_array<tagged_index_t,3>;

const auto compute_outer_simd_offset = [](const auto& indices, const auto& strides){
index_t offset = 0;
auto m = len(indices)-1;
for (index_t i=0; i<m; i++) {
offset += static_cast<index_t>(at(strides,i)) * static_cast<index_t>(at(indices,i));
}
return offset;
};

const auto n_ops = at(out_shape,meta::ct_v<-1>);
const auto n_packed_ops = n_ops / N_ELEM_PACK;

// assume simd_index is 2D
const auto s_i = at(simd_index,meta::ct_v<0>);
const auto s_j = at(simd_index,meta::ct_v<1>);
const auto s_j = at(simd_index,meta::ct_v<-1>);

const auto out_tag = s_j * N_ELEM_PACK + N_ELEM_PACK > n_ops ? static_cast<SIMD>(N_ELEM_PACK - (n_ops - (n_packed_ops * N_ELEM_PACK))) : SIMD::PACKED;
const auto lhs_tag = SIMD::BROADCAST;
const auto rhs_tag = out_tag;

const auto out_offset = (out_tag == SIMD::PACKED ? (s_i * n_ops) + (s_j * N_ELEM_PACK) : (s_i * n_ops) + (s_j) * N_ELEM_PACK);
const auto lhs_offset = s_i;
const auto rhs_offset = (rhs_tag == SIMD::PACKED ? (s_j * N_ELEM_PACK) : (s_j) * N_ELEM_PACK);
const auto outer_offset = compute_outer_simd_offset(simd_index,out_strides);

const auto compute_offset = [](const auto& indices, const auto& strides, index_t start_dim, index_t N){
index_t offset = 0;
for (index_t i=0; i<N; i++) {
index_t stride = at(strides,i);
index_t index = at(indices,i+start_dim);
offset += stride * index;
}
return offset;
};
const auto out_offset = outer_offset + (s_j * N_ELEM_PACK);

auto lhs_dim = len(lhs_shape);
auto rhs_dim = len(rhs_shape);
const auto lhs_offset = [&]()->index_t{
switch (lhs_dim) {
case 1: {
return at(simd_index,meta::ct_v<0>);
} break;
case 2: {
// only works for 2-dim
return at(simd_index,meta::ct_v<0>) * at(out_shape,meta::ct_v<1>) + at(simd_index,meta::ct_v<1>);
} break;
default: {
return compute_offset(simd_index,lhs_strides,0,lhs_dim);
}
}
}();
const auto rhs_offset = [&]()->index_t{
switch (rhs_dim) {
case 1: {
return s_j * N_ELEM_PACK;
} break;
case 2: {
// only works for 2-dim
return at(simd_index,meta::ct_v<-2>) * at(out_shape,meta::ct_v<-1>) + (s_j * N_ELEM_PACK);
} break;
default: {
auto rhs_offset = compute_offset(simd_index,rhs_strides,lhs_dim,rhs_dim-1);
return rhs_offset + (s_j * N_ELEM_PACK);
}
}
// (rhs_tag == SIMD::PACKED ? (s_j * N_ELEM_PACK) : (s_j) * N_ELEM_PACK);
}();

auto result = result_t {};
at(result,0) = tagged_index_t{out_tag,out_offset};
Expand All @@ -315,6 +376,15 @@ namespace nmtools::index
return result;
} // outer_simd

/**
 * @brief Convenience overload of outer_simd that derives the strides itself.
 *
 * Calls compute_strides on out_shape, lhs_shape and rhs_shape, then forwards
 * all arguments together with those strides to the stride-taking outer_simd overload.
 *
 * @tparam index_t      index type requested by the caller (defaults to size_t);
 *                      forwarded explicitly to the stride-taking overload
 * @tparam N_ELEM_PACK  value carried by the meta::as_type tag
 * @param n_elem_pack   tag object, forwarded unchanged
 * @param simd_index    index to resolve, forwarded unchanged
 * @param simd_shape    forwarded unchanged
 * @param out_shape     shape of the output; also used to compute out_strides
 * @param lhs_shape     shape of the lhs operand; also used to compute lhs_strides
 * @param rhs_shape     shape of the rhs operand; also used to compute rhs_strides
 * @return whatever the stride-taking outer_simd overload returns for these arguments
 */
template <typename index_t=size_t, auto N_ELEM_PACK, typename simd_index_t, typename simd_shape_t, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t>
constexpr auto outer_simd(meta::as_type<N_ELEM_PACK> n_elem_pack, const simd_index_t& simd_index, const simd_shape_t& simd_shape, const out_shape_t& out_shape, const lhs_shape_t& lhs_shape, const rhs_shape_t& rhs_shape)
{
    const auto out_strides = compute_strides(out_shape);
    const auto lhs_strides = compute_strides(lhs_shape);
    const auto rhs_strides = compute_strides(rhs_shape);
    // Pass index_t explicitly: without it the callee falls back to its own
    // default (size_t) and the index type requested here would be ignored.
    return outer_simd<index_t>(n_elem_pack,simd_index,simd_shape,out_strides,out_shape,lhs_shape,rhs_shape,lhs_strides,rhs_strides);
} // outer_simd

template <typename index_t, auto N_ELEM_PACK, typename out_shape_t, typename lhs_shape_t, typename rhs_shape_t>
struct outer_simd_enumerator_t
{
Expand All @@ -325,21 +395,33 @@ namespace nmtools::index
using simd_shape_type = out_shape_t;
using index_type = index_t;
using size_type = index_t;
using simd_index_type = nmtools_array<index_type,2>;
using simd_index_type = simd_shape_type;

using out_strides_type = meta::remove_cvref_t<decltype(compute_strides(meta::declval<out_shape_type>()))>;
using lhs_strides_type = meta::remove_cvref_t<decltype(compute_strides(meta::declval<lhs_shape_type>()))>;
using rhs_strides_type = meta::remove_cvref_t<decltype(compute_strides(meta::declval<rhs_shape_type>()))>;
using simd_strides_type = meta::remove_cvref_t<decltype(compute_strides(meta::declval<simd_shape_type>()))>;

meta::as_type<N_ELEM_PACK> n_elem_pack;
out_shape_type out_shape;
lhs_shape_type lhs_shape;
rhs_shape_type rhs_shape
;
out_shape_type out_shape;
lhs_shape_type lhs_shape;
rhs_shape_type rhs_shape;
simd_shape_type simd_shape;
out_strides_type out_strides;
lhs_strides_type lhs_strides;
rhs_strides_type rhs_strides;
simd_strides_type simd_strides;

outer_simd_enumerator_t(meta::as_type<N_ELEM_PACK>, const out_shape_t& out_shape_, const lhs_shape_t& lhs_shape_, const rhs_shape_t& rhs_shape_)
: n_elem_pack{}
, out_shape(out_shape_)
, lhs_shape(lhs_shape_)
, rhs_shape(rhs_shape_)
, simd_shape(outer_simd_shape(n_elem_pack,out_shape_,lhs_shape_,rhs_shape_))
, out_strides(compute_strides(out_shape))
, lhs_strides(compute_strides(lhs_shape))
, rhs_strides(compute_strides(rhs_shape))
, simd_strides(compute_strides(simd_shape))
{}

constexpr auto size() const noexcept
Expand All @@ -349,9 +431,8 @@ namespace nmtools::index

constexpr auto operator[](index_type i) const
{
auto index_i = i / at(simd_shape,meta::ct_v<1>);
auto index_j = i % at(simd_shape,meta::ct_v<1>);
return outer_simd(n_elem_pack,simd_index_type{index_i,index_j},simd_shape,out_shape,lhs_shape,rhs_shape);
auto index = compute_indices(i,simd_shape,simd_strides);
return outer_simd(n_elem_pack,index,simd_shape,out_strides,out_shape,lhs_shape,rhs_shape,lhs_strides,rhs_strides);
}
}; // outer_simd_enumerator_t

Expand Down
4 changes: 2 additions & 2 deletions include/nmtools/testing/testing.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ EXPECT_TRUE(isequal(result,expect)) \
auto result_ = isclose(result,expect,NMTOOLS_TESTING_OUTPUT_PRECISION); \
std::string message {}; \
message = message + \
+ "\n\tActual : " + STRINGIFY(result) \
+ "\n\tExpected: " + STRINGIFY(expect); \
+ "\n\tActual :\n" + STRINGIFY(result) \
+ "\n\tExpected:\n" + STRINGIFY(expect); \
NMTOOLS_CHECK_MESSAGE( result_, message ); \
}

Expand Down
4 changes: 4 additions & 0 deletions tests/simd/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ if (NMTOOLS_SIMD_TEST_SSE)
x86/binary_sse.cpp
x86/matmul_sse.cpp
x86/outer_sse.cpp
x86/outer_2d_sse.cpp
x86/outer_nd_sse.cpp
)
if (NMTOOLS_SIMD_TEST_REDUCTION)
set(NMTOOLS_SIMD_TEST_SOURCES ${NMTOOLS_SIMD_TEST_SOURCES}
Expand All @@ -54,6 +56,8 @@ if (NMTOOLS_SIMD_TEST_AVX)
x86/binary_avx_broadcast.cpp
x86/matmul_avx.cpp
x86/outer_avx.cpp
x86/outer_2d_avx.cpp
x86/outer_nd_avx.cpp
)
if (NMTOOLS_SIMD_TEST_REDUCTION)
set(NMTOOLS_SIMD_TEST_SOURCES ${NMTOOLS_SIMD_TEST_SOURCES}
Expand Down
Loading

0 comments on commit 223e859

Please sign in to comment.