Skip to content

Commit

Permalink
#2240: Update Rabenseifner to use ID for each allreduce and update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 593132f commit da46b94
Show file tree
Hide file tree
Showing 11 changed files with 417 additions and 359 deletions.
14 changes: 7 additions & 7 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

#include <vector>

#ifdef VT_KOKKOS_ENABLED
#ifdef KOKKOS_ENABLED_CHECKPOINT
#include <Kokkos_Core.hpp>
#endif

Expand All @@ -57,21 +57,22 @@ template <typename DataType, typename Enable = void>
// Fallback used when no specialization of DataHandler matches the data type:
// it exposes `void` as the scalar type and reports a size of zero.
class DataHandler {
public:
  using Scalar = void;

  // No data can be handled by the fallback, so the size is always zero.
  static size_t size() { return 0; }
};

template <typename Scalar>
class DataHandler<Scalar, typename std::enable_if<std::is_arithmetic<Scalar>::value>::type> {
template <typename ScalarType>
class DataHandler<ScalarType, typename std::enable_if<std::is_arithmetic<ScalarType>::value>::type> {
public:
using ScalarType = Scalar;
using Scalar = ScalarType;

static std::vector<ScalarType> toVec(const ScalarType& data) { return std::vector<ScalarType>{data}; }
static ScalarType fromVec(const std::vector<ScalarType>& data) { return data[0]; }
static ScalarType fromMemory(ScalarType* data, size_t count) {
static ScalarType fromMemory(ScalarType* data, size_t) {
return *data;
}

// static const ScalarType* data(const ScalarType& data) { return &data; }
// static size_t size(const ScalarType&) { return 1; }
static size_t size(const ScalarType&) { return 1; }
// static ScalarType& at(ScalarType& data, size_t) { return data; }
// static void set(ScalarType& data, size_t, const ScalarType& value) { data = value; }
// static ScalarType split(ScalarType&, size_t, size_t) { return ScalarType{}; }
Expand All @@ -80,7 +81,6 @@ class DataHandler<Scalar, typename std::enable_if<std::is_arithmetic<Scalar>::va
template <typename T>
class DataHandler<std::vector<T>> {
public:
using UnderlyingType = std::vector<T>;
using Scalar = T;

static const std::vector<T>& toVec(const std::vector<T>& data) { return data; }
Expand Down
138 changes: 63 additions & 75 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,36 +56,6 @@

namespace vt::collective::reduce::allreduce {

/**
 * \brief Message carrying a value of type \c DataT together with a step index.
 *
 * The value can be supplied either as an rvalue (moved into the message) or as
 * a const reference (copied into the message). Serialization writes the parent
 * message first, then the value, then the step.
 *
 * \tparam DataT Type of the payload stored in \c val_.
 */
template <typename DataT>
struct AllreduceRbnMsg
  : SerializeIfNeeded<vt::Message, AllreduceRbnMsg<DataT>, DataT> {
  using MessageParentType =
    SerializeIfNeeded<::vt::Message, AllreduceRbnMsg<DataT>, DataT>;

  AllreduceRbnMsg() = default;
  AllreduceRbnMsg(AllreduceRbnMsg const&) = default;
  AllreduceRbnMsg(AllreduceRbnMsg&&) = default;

  /**
   * \brief Construct by taking ownership of \c value.
   *
   * \param value Payload to move into the message.
   * \param step_index Step to record in the message (defaults to 0).
   */
  AllreduceRbnMsg(DataT&& value, int step_index = 0)
    : MessageParentType(),
      // DataT is not a deduced forwarding reference here, so a plain move is
      // the same cast the original std::forward<DataT> performed.
      val_(std::move(value)),
      step_(step_index)
  { }

  /**
   * \brief Construct by copying \c value.
   *
   * \param value Payload to copy into the message.
   * \param step_index Step to record in the message (defaults to 0).
   */
  AllreduceRbnMsg(DataT const& value, int step_index = 0)
    : MessageParentType(),
      val_(value),
      step_(step_index)
  { }

  /**
   * \brief Serialize the parent message, then the payload, then the step.
   *
   * \param archive Serializer the fields are streamed through.
   */
  template <typename SerializerT>
  void serialize(SerializerT& archive) {
    MessageParentType::serialize(archive);
    archive | val_;
    archive | step_;
  }

  DataT val_ = {};
  int32_t step_ = {};
};

template <typename Scalar>
struct AllreduceRbnRawMsg
: Message {
Expand All @@ -102,10 +72,11 @@ struct AllreduceRbnRawMsg
}
}

AllreduceRbnRawMsg(Scalar* in_val, size_t size, int step = 0)
AllreduceRbnRawMsg(Scalar* in_val, size_t size, size_t id, int step = 0)
: MessageParentType(),
val_(in_val),
size_(size),
id_(id),
step_(step) { }

template <typename SerializeT>
Expand All @@ -121,11 +92,13 @@ struct AllreduceRbnRawMsg

checkpoint::dispatch::serializeArray(s, val_, size_);

s | id_;
s | step_;
}

Scalar* val_ = {};
size_t size_ = {};
size_t id_ = {};
int32_t step_ = {};
bool owning_ = false;
};
Expand Down Expand Up @@ -169,27 +142,30 @@ struct Rabenseifner {
* \param args Additional arguments for initializing the data value.
*/
template <typename ...Args>
void initialize(Args&&... args);
void initialize(size_t id, Args&&... args);

void initializeState(size_t id);
// Hands out the current value of id_ and advances the counter by one, so each
// call on this object yields a different id (until size_t wraps around).
size_t generateNewId() {
  const size_t issued = id_;
  ++id_;
  return issued;
}

/**
* \brief Execute the final handler callback with the reduced result.
*/
void executeFinalHan();
void executeFinalHan(size_t id);

/**
* \brief Perform the allreduce operation.
*
* This function starts the allreduce operation, adjusting for non-power-of-two process counts if necessary.
*/
void allreduce();
void allreduce(size_t id);

/**
* \brief Adjust the process count to the nearest power-of-two.
*
* This function performs additional steps to handle non-power-of-two process counts, ensuring that the
* main scatter-reduce and gather-allgather phases can proceed with a power-of-two number of processes.
*/
void adjustForPowerOfTwo();
void adjustForPowerOfTwo(size_t id);

/**
* \brief Handler for adjusting the right half of the process group.
Expand Down Expand Up @@ -223,35 +199,35 @@ struct Rabenseifner {
*
* \return True if all scatter messages have been received, false otherwise.
*/
bool scatterAllMessagesReceived();
bool scatterAllMessagesReceived(size_t id);

/**
* \brief Check if the scatter phase is complete.
*
* \return True if the scatter phase is complete, false otherwise.
*/
bool scatterIsDone();
bool scatterIsDone(size_t id);

/**
* \brief Check if the scatter phase is ready to proceed.
*
* \return True if the scatter phase is ready to proceed, false otherwise.
*/
bool scatterIsReady();
bool scatterIsReady(size_t id);

/**
* \brief Try to reduce the received scatter messages.
*
* \param id Identifier of the allreduce this call operates on.
* \param step The current step in the scatter phase.
*/
void scatterTryReduce(int32_t step);
void scatterTryReduce(size_t id, int32_t step);

/**
* \brief Perform the scatter-reduce iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the scatter phase.
*/
void scatterReduceIter();
void scatterReduceIter(size_t id);

/**
* \brief Handler for receiving scatter-reduce messages.
Expand All @@ -267,35 +243,35 @@ struct Rabenseifner {
*
* \return True if all gather messages have been received, false otherwise.
*/
bool gatherAllMessagesReceived();
bool gatherAllMessagesReceived(size_t id);

/**
* \brief Check if the gather phase is complete.
*
* \return True if the gather phase is complete, false otherwise.
*/
bool gatherIsDone();
bool gatherIsDone(size_t id);

/**
* \brief Check if the gather phase is ready to proceed.
*
* \return True if the gather phase is ready to proceed, false otherwise.
*/
bool gatherIsReady();
bool gatherIsReady(size_t id);

/**
* \brief Try to reduce the received gather messages.
*
* \param id Identifier of the allreduce this call operates on.
* \param step The current step in the gather phase.
*/
void gatherTryReduce(int32_t step);
void gatherTryReduce(size_t id, int32_t step);

/**
* \brief Perform the gather iteration.
*
* This function sends data to the appropriate partner process and proceeds to the next step in the gather phase.
*/
void gatherIter();
void gatherIter(size_t id);

/**
* \brief Handler for receiving gather messages.
Expand All @@ -311,14 +287,14 @@ struct Rabenseifner {
*
* This function completes the allreduce operation, handling any remaining steps and invoking the final handler.
*/
void finalPart();
void finalPart(size_t id);

/**
* \brief Send the result to excluded nodes.
*
* This function handles the final step for non-power-of-two process counts, sending the reduced result to excluded nodes.
*/
void sendToExcludedNodes();
void sendToExcludedNodes(size_t id);

/**
* \brief Handler for receiving the final result on excluded nodes.
Expand All @@ -332,9 +308,46 @@ struct Rabenseifner {
vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

// DataT val_ = {};
std::vector<Scalar> val_;
size_t size_ = {};
// Bookkeeping for a single allreduce. Instances are stored in `states_`,
// keyed by a size_t id, so several allreduces can keep separate progress.
struct State {
// NOTE(review): presumably the working data buffer and its element count;
// confirm against the implementation.
std::vector<Scalar> val_ = {};
size_t size_ = {};

// Adjustment part. NOTE(review): presumably the non-power-of-two adjustment
// described on adjustForPowerOfTwo; the two messages look like payloads held
// for the left/right halves of the adjustment group. Confirm.
bool finished_adjustment_part_ = false;
MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> left_adjust_message_ = nullptr;
MsgSharedPtr<AllreduceRbnRawMsg<Scalar>> right_adjust_message_ = nullptr;

// General progress. Mask starts at 1 and step at 0; neither flag is set
// until something outside this struct changes it.
int32_t mask_ = 1;
int32_t step_ = 0;
bool initialized_ = false;
bool completed_ = false;

// Scatter phase: current mask/step, number of messages received, per-step
// received/reduced flags, and the messages themselves.
// NOTE(review): the vectors are presumably indexed by step; confirm.
int32_t scatter_mask_ = 1;
int32_t scatter_step_ = 0;
int32_t scatter_num_recv_ = 0;
std::vector<bool> scatter_steps_recv_ = {};
std::vector<bool> scatter_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ =
{};
bool finished_scatter_part_ = false;

// Gather phase: same bookkeeping shape as the scatter phase above.
int32_t gather_step_ = 0;
int32_t gather_mask_ = 1;
int32_t gather_num_recv_ = 0;
std::vector<bool> gather_steps_recv_ = {};
std::vector<bool> gather_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ =
{};

// NOTE(review): presumably receive (r_) and send (s_) start indices and
// element counts; not evident from the declarations alone. Confirm.
std::vector<uint32_t> r_index_ = {};
std::vector<uint32_t> r_count_ = {};
std::vector<uint32_t> s_index_ = {};
std::vector<uint32_t> s_count_ = {};
};

size_t id_ = 0;
std::unordered_map<size_t, State> states_ = {};
NodeType num_nodes_ = {};
NodeType this_node_ = {};

Expand All @@ -343,33 +356,8 @@ struct Rabenseifner {
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};

std::vector<uint32_t> r_index_ = {};
std::vector<uint32_t> r_count_ = {};
std::vector<uint32_t> s_index_ = {};
std::vector<uint32_t> s_count_ = {};

NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
bool finished_adjustment_part_ = false;

bool completed_ = false;

// Scatter
int32_t scatter_mask_ = 1;
int32_t scatter_step_ = 0;
int32_t scatter_num_recv_ = 0;
std::vector<bool> scatter_steps_recv_ = {};
std::vector<bool> scatter_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> scatter_messages_ = {};
bool finished_scatter_part_ = false;

// Gather
int32_t gather_step_ = 0;
int32_t gather_mask_ = 1;
int32_t gather_num_recv_ = 0;
std::vector<bool> gather_steps_recv_ = {};
std::vector<bool> gather_steps_reduced_ = {};
std::vector<MsgSharedPtr<AllreduceRbnRawMsg<Scalar>>> gather_messages_ = {};
};

} // namespace vt::collective::reduce::allreduce
Expand Down
Loading

0 comments on commit da46b94

Please sign in to comment.