Skip to content

Commit

Permalink
#2240: Working allreduce perf test with Kokkos
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 51ea74e commit 7aeaa8e
Show file tree
Hide file tree
Showing 10 changed files with 468 additions and 162 deletions.
89 changes: 66 additions & 23 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,38 +45,60 @@
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H

namespace vt::collective::reduce::allreduce {
#include <vector>

#ifdef VT_KOKKOS_ENABLED
#include <Kokkos_Core.hpp>
#endif

template <typename Container>
namespace vt::collective::reduce::allreduce {

/**
 * \brief Adaptor used by the allreduce algorithms to access payload data in a
 *        uniform way, independent of the concrete container type.
 *
 * The primary template is intentionally inert (Scalar = void) and acts as a
 * "type not supported" fallback; real functionality lives in specializations.
 */
template <typename DataType, typename Enable = void>
class DataHandler {
public:
  using Scalar = void;
};

/**
 * \brief Specialization for arithmetic scalars (int, float, double, ...).
 *
 * A single scalar is treated as a one-element payload.
 */
template <typename Scalar>
class DataHandler<
  Scalar, typename std::enable_if<std::is_arithmetic<Scalar>::value>::type> {
public:
  using ScalarType = Scalar;

  /// Wrap a single scalar value into a one-element vector.
  static std::vector<ScalarType> toVec(const ScalarType& data) {
    return std::vector<ScalarType>{data};
  }

  /// Recover the scalar from vector form; only element 0 is read, so the
  /// vector must be non-empty.
  static ScalarType fromVec(const std::vector<ScalarType>& data) {
    return data[0];
  }

  /// Read a single scalar from raw memory; `count` is unused (always one value).
  static ScalarType fromMemory(ScalarType* data, size_t /*count*/) {
    return *data;
  }
};

/**
 * \brief Specialization for std::vector<T> payloads.
 */
template <typename T>
class DataHandler<std::vector<T>> {
public:
  using UnderlyingType = std::vector<T>;
  using Scalar = T;

  /// Identity conversion: the payload already is a vector (no copy made).
  static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }

  /// Conversion back from vector form (copies the data).
  static std::vector<T> fromVec(const std::vector<T>& data) { return data; }

  /// Copy `count` elements from raw memory into a newly allocated vector.
  static std::vector<T> fromMemory(T* data, size_t count) {
    return std::vector<T>(data, data + count);
  }

  /// Number of elements in the payload.
  static size_t size(const std::vector<T>& data) { return data.size(); }
};

#if KOKKOS_ENABLED_CHECKPOINT
Expand All @@ -88,19 +110,40 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
public:
using Scalar = T;

static size_t size(const ViewType& data) { return data.extent(0); }
static std::vector<T> toVec(const ViewType& data) {
std::vector<T> vec;
vec.resize(data.extent(0));
std::memcpy(vec.data(), data.data(), data.extent(0) * sizeof(T));
return vec;
}

static T at(const ViewType& data, size_t idx) { return data(idx); }
static ViewType fromMemory(T* data, size_t size) {
return ViewType(data, size);
}

static T& at(ViewType& data, size_t idx) { return data(idx); }
static ViewType fromVec(const std::vector<T>& data) {
ViewType view("", data.size());
Kokkos::parallel_for(
"InitView", view.extent(0),
KOKKOS_LAMBDA(const int i) { view(i) = static_cast<float>(data[i]); });

static void set(ViewType& data, size_t idx, const T& value) {
data(idx) = value;
return view;
}

static ViewType split(ViewType& data, size_t start, size_t end) {
return Kokkos::subview(data, std::make_pair(start, end));
}
// static const T* data(const ViewType& data) {return data.data(); }
static size_t size(const ViewType& data) { return data.extent(0); }

// static T at(const ViewType& data, size_t idx) { return data(idx); }

// static T& at(ViewType& data, size_t idx) { return data(idx); }

// static void set(ViewType& data, size_t idx, const T& value) {
// data(idx) = value;
// }

// static ViewType split(ViewType& data, size_t start, size_t end) {
// return Kokkos::subview(data, std::make_pair(start, end));
// }
};

#endif // KOKKOS_ENABLED_CHECKPOINT
Expand Down
76 changes: 61 additions & 15 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,50 @@ struct AllreduceRbnMsg
int32_t step_ = {};
};

/**
 * \brief Message carrying a raw scalar buffer for Rabenseifner allreduce.
 *
 * On the sending side the message merely views caller-owned memory
 * (owning_ == false). On the receiving side, serialize() allocates a buffer
 * during unpacking and the message owns it (owning_ == true), releasing it in
 * the destructor.
 *
 * NOTE: copy/move were previously `= default`, which copied/duplicated the
 * owning raw pointer and led to a double delete (copy) or two owners (move).
 * They are now implemented to deep-copy / transfer ownership respectively.
 */
template <typename Scalar>
struct AllreduceRbnRawMsg
  : Message {
  using MessageParentType = vt::Message;
  vt_msg_serialize_required();

  AllreduceRbnRawMsg() = default;

  /// Deep copy: an owning source produces an owning copy of the buffer; a
  /// non-owning source produces another non-owning view.
  AllreduceRbnRawMsg(AllreduceRbnRawMsg const& other)
    : MessageParentType(other),
      size_(other.size_),
      step_(other.step_),
      owning_(other.owning_) {
    if (owning_) {
      val_ = new Scalar[size_];
      for (size_t i = 0; i < size_; ++i) {
        val_[i] = other.val_[i];
      }
    } else {
      val_ = other.val_;
    }
  }

  /// Move: steal the buffer and clear ownership on the source so only one
  /// message ever frees it.
  AllreduceRbnRawMsg(AllreduceRbnRawMsg&& other)
    : MessageParentType(),
      val_(other.val_),
      size_(other.size_),
      step_(other.step_),
      owning_(other.owning_) {
    other.val_ = nullptr;
    other.size_ = 0;
    other.owning_ = false;
  }

  ~AllreduceRbnRawMsg() {
    if (owning_) {
      delete[] val_;
    }
  }

  /**
   * \brief Construct a non-owning message over caller-managed memory.
   *
   * \param in_val pointer to the scalar buffer (must outlive the message)
   * \param size   number of scalars in the buffer
   * \param step   algorithm step this message belongs to
   */
  AllreduceRbnRawMsg(Scalar* in_val, size_t size, int step = 0)
    : MessageParentType(),
      val_(in_val),
      size_(size),
      step_(step) { }

  template <typename SerializeT>
  void serialize(SerializeT& s) {
    MessageParentType::serialize(s);

    s | size_;

    if (s.isUnpacking()) {
      // Receiver allocates and therefore owns the buffer.
      owning_ = true;
      val_ = new Scalar[size_];
    }

    checkpoint::dispatch::serializeArray(s, val_, size_);

    s | step_;
  }

  Scalar* val_ = {};      ///< buffer of scalars (owned iff owning_ is true)
  size_t size_ = {};      ///< number of scalars in val_
  int32_t step_ = {};     ///< allreduce step index
  bool owning_ = false;   ///< true iff val_ was allocated by this message
};

/**
* \struct Rabenseifner
* \brief Class implementing Rabenseifner's allreduce algorithm.
Expand All @@ -103,6 +147,7 @@ template <
>
struct Rabenseifner {
using DataType = DataHandler<DataT>;
using Scalar = typename DataType::Scalar;

/**
* \brief Constructor for Rabenseifner's allreduce algorithm.
Expand All @@ -111,7 +156,7 @@ struct Rabenseifner {
* \param num_nodes Total number of nodes involved in the allreduce operation.
* \param args Additional arguments for initializing the data value.
*/
template <typename... Args>
template <typename ...Args>
Rabenseifner(
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
Args&&... args);
Expand All @@ -123,7 +168,7 @@ struct Rabenseifner {
*
* \param args Additional arguments for initializing the data value.
*/
template <typename... Args>
template <typename ...Args>
void initialize(Args&&... args);

/**
Expand Down Expand Up @@ -153,7 +198,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
void adjustForPowerOfTwoRightHalf(AllreduceRbnMsg<DataT>* msg);
void adjustForPowerOfTwoRightHalf(AllreduceRbnRawMsg<Scalar>* msg);

/**
* \brief Handler for adjusting the left half of the process group.
Expand All @@ -162,7 +207,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
void adjustForPowerOfTwoLeftHalf(AllreduceRbnMsg<DataT>* msg);
void adjustForPowerOfTwoLeftHalf(AllreduceRbnRawMsg<Scalar>* msg);

/**
* \brief Final adjustment step for non-power-of-two process counts.
Expand All @@ -171,7 +216,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
void adjustForPowerOfTwoFinalPart(AllreduceRbnMsg<DataT>* msg);
void adjustForPowerOfTwoFinalPart(AllreduceRbnRawMsg<Scalar>* msg);

/**
* \brief Check if all scatter messages have been received.
Expand Down Expand Up @@ -215,7 +260,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
void scatterReduceIterHandler(AllreduceRbnMsg<DataT>* msg);
void scatterReduceIterHandler(AllreduceRbnRawMsg<Scalar>* msg);

/**
* \brief Check if all gather messages have been received.
Expand Down Expand Up @@ -259,7 +304,7 @@ struct Rabenseifner {
*
* \param msg Message containing the data from the partner process.
*/
void gatherIterHandler(AllreduceRbnMsg<DataT>* msg);
void gatherIterHandler(AllreduceRbnRawMsg<Scalar>* msg);

/**
* \brief Perform the final part of the allreduce operation.
Expand All @@ -282,12 +327,13 @@ struct Rabenseifner {
*
* \param msg Message containing the final result.
*/
void sendToExcludedNodesHandler(AllreduceRbnMsg<DataT>* msg);
void sendToExcludedNodesHandler(AllreduceRbnRawMsg<Scalar>* msg);

vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

DataT val_ = {};
// DataT val_ = {};
std::vector<Scalar> val_;
size_t size_ = {};
NodeType num_nodes_ = {};
NodeType this_node_ = {};
Expand All @@ -297,10 +343,10 @@ struct Rabenseifner {
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};

std::vector<int32_t> r_index_ = {};
std::vector<int32_t> r_count_ = {};
std::vector<int32_t> s_index_ = {};
std::vector<int32_t> s_count_ = {};
std::vector<uint32_t> r_index_ = {};
std::vector<uint32_t> r_count_ = {};
std::vector<uint32_t> s_index_ = {};
std::vector<uint32_t> s_count_ = {};

NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
Expand All @@ -314,7 +360,7 @@ struct Rabenseifner {
int32_t scatter_num_recv_ = 0;
std::vector<bool> scatter_steps_recv_ = {};
std::vector<bool> scatter_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT>>> scatter_messages_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ = {};
bool finished_scatter_part_ = false;

// Gather
Expand All @@ -323,7 +369,7 @@ struct Rabenseifner {
int32_t gather_num_recv_ = 0;
std::vector<bool> gather_steps_recv_ = {};
std::vector<bool> gather_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnMsg<DataT>>> gather_messages_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ = {};
};

} // namespace vt::collective::reduce::allreduce
Expand Down
Loading

0 comments on commit 7aeaa8e

Please sign in to comment.