Skip to content

Commit

Permalink
Added accessors to vector types
Browse files Browse the repository at this point in the history
  • Loading branch information
amirshavit authored and jszuppe committed Feb 10, 2019
1 parent 924ed68 commit 04f7d58
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 21 deletions.
84 changes: 63 additions & 21 deletions include/boost/compute/types/fundamental.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,33 +44,88 @@ typedef cl_double double_;
#define BOOST_COMPUTE_MAKE_VECTOR_TYPE(scalar, size) \
BOOST_PP_CAT(BOOST_PP_CAT(::boost::compute::scalar, size), _)

namespace detail {

// specialized vector_type base classes that provide the
// (x,y), (x,y,z,w), (s0..s7), (s0..sf) accessors
template<class Scalar, size_t N> class vector_type_desc;

template<class Scalar>
class vector_type_desc<Scalar, 2>
{
public:
Scalar x, y;

Scalar& operator[](size_t i)
{
return (&x)[i];
}

const Scalar operator[](size_t i) const
{
return (&x)[i];
}
};

template<class Scalar>
class vector_type_desc<Scalar, 4> : public vector_type_desc<Scalar, 2>
{
public:
Scalar z, w;
};

template<class Scalar>
class vector_type_desc<Scalar, 8>
{
public:
Scalar s0, s1, s2, s3, s4, s5, s6, s7;

Scalar& operator[](size_t i)
{
return (&s0)[i];
}

const Scalar operator[](size_t i) const
{
return (&s0)[i];
}
};

template<class Scalar>
class vector_type_desc<Scalar, 16> : public vector_type_desc<Scalar, 8>
{
public:
Scalar s8, s9, sa, sb, sc, sd, se, sf;
};

} // end detail namespace

// vector data types
template<class Scalar, size_t N>
class vector_type
class vector_type : public detail::vector_type_desc<Scalar, N>
{
public:
typedef Scalar scalar_type;

vector_type()
: m_value()
{
}

explicit vector_type(const Scalar scalar)
{
for(size_t i = 0; i < N; i++)
m_value[i] = scalar;
(*this)[i] = scalar;
}

vector_type(const vector_type<Scalar, N> &other)
{
std::memcpy(m_value, other.m_value, sizeof(m_value));
std::memcpy(this, &other, sizeof(Scalar) * N);
}

vector_type<Scalar, N>&
operator=(const vector_type<Scalar, N> &other)
{
std::memcpy(m_value, other.m_value, sizeof(m_value));
std::memcpy(this, &other, sizeof(Scalar) * N);
return *this;
}

Expand All @@ -79,38 +134,25 @@ class vector_type
return N;
}

Scalar& operator[](size_t i)
{
return m_value[i];
}

Scalar operator[](size_t i) const
{
return m_value[i];
}

bool operator==(const vector_type<Scalar, N> &other) const
{
return std::memcmp(m_value, other.m_value, sizeof(m_value)) == 0;
return std::memcmp(this, &other, sizeof(Scalar) * N) == 0;
}

bool operator!=(const vector_type<Scalar, N> &other) const
{
return !(*this == other);
}

protected:
scalar_type m_value[N];
};

#define BOOST_COMPUTE_VECTOR_TYPE_CTOR_ARG_FUNCTION(z, i, _) \
BOOST_PP_COMMA_IF(i) scalar_type BOOST_PP_CAT(arg, i)
#define BOOST_COMPUTE_VECTOR_TYPE_DECLARE_CTOR_ARGS(scalar, size) \
BOOST_PP_REPEAT(size, BOOST_COMPUTE_VECTOR_TYPE_CTOR_ARG_FUNCTION, _)
#define BOOST_COMPUTE_VECTOR_TYPE_ASSIGN_CTOR_ARG(z, i, _) \
m_value[i] = BOOST_PP_CAT(arg, i);
(*this)[i] = BOOST_PP_CAT(arg, i);
#define BOOST_COMPUTE_VECTOR_TYPE_ASSIGN_CTOR_SINGLE_ARG(z, i, _) \
m_value[i] = arg;
(*this)[i] = arg;

#define BOOST_COMPUTE_DECLARE_VECTOR_TYPE_CLASS(cl_scalar, size, class_name) \
class class_name : public vector_type<cl_scalar, size> \
Expand Down
53 changes: 53 additions & 0 deletions test/test_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,56 @@ BOOST_AUTO_TEST_CASE(vector_string)
stream << boost::compute::int2_(1, 2);
BOOST_CHECK_EQUAL(stream.str(), std::string("int2(1, 2)"));
}

BOOST_AUTO_TEST_CASE(vector_accessors_basic)
{
boost::compute::float4_ v;
v.x = 1;
v.y = 2;
v.z = 3;
v.w = 4;
BOOST_CHECK(v == boost::compute::float4_(1, 2, 3, 4));
}

BOOST_AUTO_TEST_CASE(vector_accessors_all)
{
boost::compute::int2_ i2(1, 2);
BOOST_CHECK_EQUAL(i2.x, 1);
BOOST_CHECK_EQUAL(i2.y, 2);

boost::compute::int4_ i4(1, 2, 3, 4);
BOOST_CHECK_EQUAL(i4.x, 1);
BOOST_CHECK_EQUAL(i4.y, 2);
BOOST_CHECK_EQUAL(i4.z, 3);
BOOST_CHECK_EQUAL(i4.w, 4);

boost::compute::int8_ i8(1, 2, 3, 4, 5, 6, 7, 8);
BOOST_CHECK_EQUAL(i8.s0, 1);
BOOST_CHECK_EQUAL(i8.s1, 2);
BOOST_CHECK_EQUAL(i8.s2, 3);
BOOST_CHECK_EQUAL(i8.s3, 4);
BOOST_CHECK_EQUAL(i8.s4, 5);
BOOST_CHECK_EQUAL(i8.s5, 6);
BOOST_CHECK_EQUAL(i8.s6, 7);
BOOST_CHECK_EQUAL(i8.s7, 8);

boost::compute::int16_ i16(
1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16);
BOOST_CHECK_EQUAL(i16.s0, 1);
BOOST_CHECK_EQUAL(i16.s1, 2);
BOOST_CHECK_EQUAL(i16.s2, 3);
BOOST_CHECK_EQUAL(i16.s3, 4);
BOOST_CHECK_EQUAL(i16.s4, 5);
BOOST_CHECK_EQUAL(i16.s5, 6);
BOOST_CHECK_EQUAL(i16.s6, 7);
BOOST_CHECK_EQUAL(i16.s7, 8);
BOOST_CHECK_EQUAL(i16.s8, 9);
BOOST_CHECK_EQUAL(i16.s9, 10);
BOOST_CHECK_EQUAL(i16.sa, 11);
BOOST_CHECK_EQUAL(i16.sb, 12);
BOOST_CHECK_EQUAL(i16.sc, 13);
BOOST_CHECK_EQUAL(i16.sd, 14);
BOOST_CHECK_EQUAL(i16.se, 15);
BOOST_CHECK_EQUAL(i16.sf, 16);
}

0 comments on commit 04f7d58

Please sign in to comment.