Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support basic operations on batch constant #955

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 141 additions & 0 deletions include/xsimd/types/xsimd_batch_constant.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ namespace xsimd
template <class batch_type, bool... Values>
struct batch_bool_constant
{

public:
static constexpr std::size_t size = sizeof...(Values);
using arch_type = typename batch_type::arch_type;
using value_type = bool;
Expand All @@ -47,11 +49,67 @@ namespace xsimd

private:
static constexpr int mask_helper(int acc) noexcept { return acc; }

template <class... Tys>
static constexpr int mask_helper(int acc, int mask, Tys... masks) noexcept
{
return mask_helper(acc | mask, (masks << 1)...);
}

struct logical_or
{
constexpr bool operator()(bool x, bool y) const { return x || y; }
};
struct logical_and
{
constexpr bool operator()(bool x, bool y) const { return x && y; }
};
struct logical_xor
{
constexpr bool operator()(bool x, bool y) const { return x ^ y; }
};

template <class F, class SelfPack, class OtherPack, size_t... Indices>
static constexpr batch_bool_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
apply(detail::index_sequence<Indices...>)
{
return {};
}

template <class F, bool... OtherValues>
static constexpr auto apply(batch_bool_constant<batch_type, Values...>, batch_bool_constant<batch_type, OtherValues...>)
-> decltype(apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
{
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
return apply<F, std::tuple<std::integral_constant<bool, Values>...>, std::tuple<std::integral_constant<bool, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
}

public:
#define MAKE_BINARY_OP(OP, NAME) \
template <bool... OtherValues> \
constexpr auto operator OP(batch_bool_constant<batch_type, OtherValues...> other) const \
->decltype(apply<NAME>(*this, other)) \
{ \
return apply<NAME>(*this, other); \
}

MAKE_BINARY_OP(|, logical_or)
MAKE_BINARY_OP(||, logical_or)
MAKE_BINARY_OP(&, logical_and)
MAKE_BINARY_OP(&&, logical_and)
MAKE_BINARY_OP(^, logical_xor)

#undef MAKE_BINARY_OP

constexpr batch_bool_constant<batch_type, !Values...> operator!() const
{
return {};
}

constexpr batch_bool_constant<batch_type, !Values...> operator~() const
{
return {};
}
};

/**
Expand Down Expand Up @@ -88,6 +146,89 @@ namespace xsimd
{
return values[i];
}

struct arithmetic_add
{
constexpr value_type operator()(value_type x, value_type y) const { return x + y; }
};
struct arithmetic_sub
{
constexpr value_type operator()(value_type x, value_type y) const { return x - y; }
};
struct arithmetic_mul
{
constexpr value_type operator()(value_type x, value_type y) const { return x * y; }
};
struct arithmetic_div
{
constexpr value_type operator()(value_type x, value_type y) const { return x / y; }
};
struct arithmetic_mod
{
constexpr value_type operator()(value_type x, value_type y) const { return x % y; }
};
struct binary_and
{
constexpr value_type operator()(value_type x, value_type y) const { return x & y; }
};
struct binary_or
{
constexpr value_type operator()(value_type x, value_type y) const { return x | y; }
};
struct binary_xor
{
constexpr value_type operator()(value_type x, value_type y) const { return x ^ y; }
};

template <class F, class SelfPack, class OtherPack, size_t... Indices>
static constexpr batch_constant<batch_type, F()(std::tuple_element<Indices, SelfPack>::type::value, std::tuple_element<Indices, OtherPack>::type::value)...>
apply(detail::index_sequence<Indices...>)
{
return {};
}

template <class F, value_type... OtherValues>
static constexpr auto apply(batch_constant<batch_type, Values...>, batch_constant<batch_type, OtherValues...>)
-> decltype(apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>()))
{
static_assert(sizeof...(Values) == sizeof...(OtherValues), "compatible constant batches");
return apply<F, std::tuple<std::integral_constant<value_type, Values>...>, std::tuple<std::integral_constant<value_type, OtherValues>...>>(detail::make_index_sequence<sizeof...(Values)>());
}

public:
#define MAKE_BINARY_OP(OP, NAME) \
template <value_type... OtherValues> \
constexpr auto operator OP(batch_constant<batch_type, OtherValues...> other) const \
->decltype(apply<NAME>(*this, other)) \
{ \
return apply<NAME>(*this, other); \
}

MAKE_BINARY_OP(+, arithmetic_add)
MAKE_BINARY_OP(-, arithmetic_sub)
MAKE_BINARY_OP(*, arithmetic_mul)
MAKE_BINARY_OP(/, arithmetic_div)
MAKE_BINARY_OP(%, arithmetic_mod)
MAKE_BINARY_OP(&, binary_and)
MAKE_BINARY_OP(|, binary_or)
MAKE_BINARY_OP(^, binary_xor)

#undef MAKE_BINARY_OP

constexpr batch_constant<batch_type, (value_type)-Values...> operator-() const
{
return {};
}

constexpr batch_constant<batch_type, (value_type) + Values...> operator+() const
{
return {};
}

constexpr batch_constant<batch_type, (value_type)~Values...> operator~() const
{
return {};
}
};

namespace detail
Expand Down
109 changes: 106 additions & 3 deletions test/test_batch_constant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,69 @@ struct constant_batch_test
CHECK_BATCH_EQ((batch_type)b, expected);
}

template <value_type V>
struct constant
{
static constexpr value_type get(size_t /*index*/, size_t /*size*/)
{
return 3;
return V;
}
};

void test_init_from_constant() const
{
array_type expected;
std::fill(expected.begin(), expected.end(), constant::get(0, 0));
constexpr auto b = xsimd::make_batch_constant<batch_type, constant>();
std::fill(expected.begin(), expected.end(), constant<3>::get(0, 0));
constexpr auto b = xsimd::make_batch_constant<batch_type, constant<3>>();
INFO("batch(value_type)");
CHECK_BATCH_EQ((batch_type)b, expected);
}

void test_ops() const
{
constexpr auto n12 = xsimd::make_batch_constant<batch_type, constant<12>>();
constexpr auto n3 = xsimd::make_batch_constant<batch_type, constant<3>>();

constexpr auto n12_add_n3 = n12 + n3;
constexpr auto n15 = xsimd::make_batch_constant<batch_type, constant<15>>();
static_assert(std::is_same<decltype(n12_add_n3), decltype(n15)>::value, "n12 + n3 == n15");

constexpr auto n12_sub_n3 = n12 - n3;
constexpr auto n9 = xsimd::make_batch_constant<batch_type, constant<9>>();
static_assert(std::is_same<decltype(n12_sub_n3), decltype(n9)>::value, "n12 - n3 == n9");

constexpr auto n12_mul_n3 = n12 * n3;
constexpr auto n36 = xsimd::make_batch_constant<batch_type, constant<36>>();
static_assert(std::is_same<decltype(n12_mul_n3), decltype(n36)>::value, "n12 * n3 == n36");

constexpr auto n12_div_n3 = n12 / n3;
constexpr auto n4 = xsimd::make_batch_constant<batch_type, constant<4>>();
static_assert(std::is_same<decltype(n12_div_n3), decltype(n4)>::value, "n12 / n3 == n4");

constexpr auto n12_mod_n3 = n12 % n3;
constexpr auto n0 = xsimd::make_batch_constant<batch_type, constant<0>>();
static_assert(std::is_same<decltype(n12_mod_n3), decltype(n0)>::value, "n12 % n3 == n0");

constexpr auto n12_land_n3 = n12 & n3;
static_assert(std::is_same<decltype(n12_land_n3), decltype(n0)>::value, "n12 & n3 == n0");

constexpr auto n12_lor_n3 = n12 | n3;
static_assert(std::is_same<decltype(n12_lor_n3), decltype(n15)>::value, "n12 | n3 == n15");

constexpr auto n12_lxor_n3 = n12 ^ n3;
static_assert(std::is_same<decltype(n12_lxor_n3), decltype(n15)>::value, "n12 ^ n3 == n15");

constexpr auto n12_uadd = +n12;
static_assert(std::is_same<decltype(n12_uadd), decltype(n12)>::value, "+n12 == n12");

constexpr auto n12_inv = ~n12;
constexpr auto n12_inv_ = xsimd::make_batch_constant<batch_type, constant<(value_type)~12>>();
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "~n12 == n12_inv");

constexpr auto n12_usub = -n12;
constexpr auto n12_usub_ = xsimd::make_batch_constant<batch_type, constant<(value_type)-12>>();
static_assert(std::is_same<decltype(n12_inv), decltype(n12_inv_)>::value, "-n12 == n12_usub");
}
};

TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
Expand All @@ -93,6 +140,11 @@ TEST_CASE_TEMPLATE("[constant batch]", B, BATCH_INT_TYPES)
}

SUBCASE("init_from_constant") { Test.test_init_from_constant(); }

SUBCASE("operators")
{
Test.test_ops();
}
}

template <class B>
Expand Down Expand Up @@ -144,6 +196,53 @@ struct constant_bool_batch_test
INFO("batch_bool_constant(value_type)");
CHECK_BATCH_EQ((batch_bool_type)b, expected);
}

struct inv_split
{
static constexpr bool get(size_t index, size_t size)
{
return !split().get(index, size);
}
};

template <bool Val>
struct constant
{
static constexpr bool get(size_t /*index*/, size_t /*size*/)
{
return Val;
}
};

void test_ops() const
{
constexpr auto all_true = xsimd::make_batch_bool_constant<batch_type, constant<true>>();
constexpr auto all_false = xsimd::make_batch_bool_constant<batch_type, constant<false>>();

constexpr auto x = xsimd::make_batch_bool_constant<batch_type, split>();
constexpr auto y = xsimd::make_batch_bool_constant<batch_type, inv_split>();

constexpr auto x_or_y = x | y;
static_assert(std::is_same<decltype(x_or_y), decltype(all_true)>::value, "x | y == true");

constexpr auto x_lor_y = x || y;
static_assert(std::is_same<decltype(x_lor_y), decltype(all_true)>::value, "x || y == true");

constexpr auto x_and_y = x & y;
static_assert(std::is_same<decltype(x_and_y), decltype(all_false)>::value, "x & y == false");

constexpr auto x_land_y = x && y;
static_assert(std::is_same<decltype(x_land_y), decltype(all_false)>::value, "x && y == false");

constexpr auto x_xor_y = x ^ y;
static_assert(std::is_same<decltype(x_xor_y), decltype(all_true)>::value, "x ^ y == true");

constexpr auto not_x = !x;
static_assert(std::is_same<decltype(not_x), decltype(y)>::value, "!x == y");

constexpr auto inv_x = ~x;
static_assert(std::is_same<decltype(inv_x), decltype(y)>::value, "~x == y");
}
};

TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
Expand All @@ -155,5 +254,9 @@ TEST_CASE_TEMPLATE("[constant bool batch]", B, BATCH_INT_TYPES)
{
Test.test_init_from_generator_split();
}
SUBCASE("operators")
{
Test.test_ops();
}
}
#endif
Loading