Skip to content

Commit

Permalink
#2240: Add unit test for Rabenseifner with Kokkos::View as DataType a…
Browse files Browse the repository at this point in the history
…nd fix compile issues realted to using Kokkos::View for allreduce
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 1f86e55 commit 699a6d8
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
34 changes: 20 additions & 14 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,25 +79,31 @@ class DataHandler<std::vector<T>> {
}
};

#ifdef VT_KOKKOS_ENABLED
#if KOKKOS_ENABLED_CHECKPOINT

template <typename T, typename... Props>
class DataHandler<Kokkos::View<T*, Props...>> {
class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
using ViewType = Kokkos::View<T*, Kokkos::HostSpace, Props...>;

public:
static size_t size(const Kokkos::View<T*, Props...>& data) {
return data.extent(0);
}
static T at(const Kokkos::View<T*, Props...>& data, size_t idx) {
return data(idx);
}
static T& at(Kokkos::View<T*, Props...>& data, size_t idx) {
return data(idx);
}
static void
set(Kokkos::View<T*, Props...>& data, size_t idx, const T& value) {
using Scalar = T;

static size_t size(const ViewType& data) { return data.extent(0); }

static T at(const ViewType& data, size_t idx) { return data(idx); }

static T& at(ViewType& data, size_t idx) { return data(idx); }

static void set(ViewType& data, size_t idx, const T& value) {
data(idx) = value;
}

static ViewType split(ViewType& data, size_t start, size_t end) {
return Kokkos::subview(data, std::make_pair(start, end));
}
};
#endif // VT_KOKKOS_ENABLED

#endif // KOKKOS_ENABLED_CHECKPOINT

} // namespace vt::collective::reduce::allreduce

Expand Down
47 changes: 47 additions & 0 deletions tests/unit/objgroup/test_objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,53 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
EXPECT_EQ(MyObjA::total_verify_expected_, 6);
}

#if KOKKOS_ENABLED_CHECKPOINT
struct TestObjGroupKokkos : TestParallelHarness {
void SetUp() override {
TestParallelHarness::SetUp();

Kokkos::initialize();

SET_MIN_NUM_NODES_CONSTRAINT(2);
}

void TearDown() override {
TestParallelHarness::TearDown();

Kokkos::finalize();
}
};

TEST_F(TestObjGroupKokkos, test_proxy_allreduce_kokkos) {
using namespace vt::collective;

TestObjGroup::total_verify_expected_ = 0;
auto const my_node = vt::theContext()->getNode();

auto kokkos_proxy =
vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_reduce_kokkos");

vt::theCollective()->barrier();

runInEpochCollective([&] {
Kokkos::View<float*, Kokkos::HostSpace> view("view", 256);
Kokkos::parallel_for(
"InitView", view.extent(0),
KOKKOS_LAMBDA(const int i) { view(i) = static_cast<float>(my_node); });

using Reducer = vt::collective::reduce::allreduce::Rabenseifner<
decltype(view), PlusOp, MyObjA, &MyObjA::verifyAllredView>;

theObjGroup()
->allreduce<Reducer, &MyObjA::verifyAllredView, MyObjA, PlusOp>(
kokkos_proxy, view);
});

EXPECT_EQ(MyObjA::total_verify_expected_, 1);
}
#endif // KOKKOS_ENABLED_CHECKPOINT


TEST_F(TestObjGroup, test_proxy_invoke) {
auto const& this_node = theContext()->getNode();

Expand Down
16 changes: 16 additions & 0 deletions tests/unit/objgroup/test_objgroup_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,21 @@ struct MyObjA {

void verifyAllredVecPayload(VectorPayload vec) { verifyAllredVec(vec.vec_); }

#if KOKKOS_ENABLED_CHECKPOINT
void verifyAllredView(Kokkos::View<float*, Kokkos::HostSpace> view) {
auto final_size = view.extent(0);
EXPECT_EQ(final_size, 256);

auto n = vt::theContext()->getNumNodes();
auto const total_sum = n * (n - 1) / 2;
Kokkos::parallel_for("InitView", view.extent(0), KOKKOS_LAMBDA(const int i) {
EXPECT_EQ(view(i), total_sum);
});

total_verify_expected_++;
}
#endif

int id_ = -1;
int recv_ = 0;
static int next_id;
Expand Down Expand Up @@ -208,6 +223,7 @@ class DataHandler<tests::unit::VectorPayload> {
return DataT{UnderlyingType{data.vec_.begin() + start, data.vec_.begin() + end}};
}
};

} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_UNIT_OBJGROUP_TEST_OBJGROUP_COMMON_H*/

0 comments on commit 699a6d8

Please sign in to comment.