Skip to content

Commit

Permalink
#2240: Working RecursiveDoubling with multiple allreduce in flight
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 7aeaa8e commit 593132f
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 188 deletions.
116 changes: 48 additions & 68 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,6 @@
namespace vt::collective::reduce::allreduce {

/**
 * \brief Message that carries one value of type \c DataT together with the
 * step index it is associated with.
 *
 * The message is layered on
 * \c SerializeIfNeeded<vt::Message, AllreduceDblMsg<DataT>, DataT>; that
 * helper (declared elsewhere) decides how the serializer is engaged.
 *
 * \tparam DataT type of the payload stored in \c val_
 */
template <typename DataT>
struct AllreduceDblMsg
  : SerializeIfNeeded<vt::Message, AllreduceDblMsg<DataT>, DataT> {
  using MessageParentType =
    SerializeIfNeeded<::vt::Message, AllreduceDblMsg<DataT>, DataT>;

  AllreduceDblMsg() = default;
  AllreduceDblMsg(AllreduceDblMsg const&) = default;
  AllreduceDblMsg(AllreduceDblMsg&&) = default;

  /**
   * \brief Construct by taking over an rvalue payload.
   *
   * \param in_val payload to move into the message
   * \param step step index stored alongside the payload (defaults to 0)
   */
  AllreduceDblMsg(DataT&& in_val, int step = 0)
    : MessageParentType(),
      // DataT is the class template parameter, so DataT&& is a plain rvalue
      // reference rather than a forwarding reference: std::move is the
      // correct cast here, not std::forward.
      val_(std::move(in_val)),
      step_(step) { }

  /**
   * \brief Construct by copying the payload.
   *
   * \param in_val payload to copy into the message
   * \param step step index stored alongside the payload (defaults to 0)
   */
  AllreduceDblMsg(DataT const& in_val, int step = 0)
    : MessageParentType(),
      val_(in_val),
      step_(step) { }

  /**
   * \brief Serialize the parent part first, then the payload and the step.
   *
   * \param s serializer to read from / write to
   */
  template <typename SerializeT>
  void serialize(SerializeT& s) {
    MessageParentType::serialize(s);
    s | val_;
    s | step_;
  }

  DataT val_ = {};    ///< payload, value-initialized by default
  int32_t step_ = {}; ///< step index, 0 by default
};

template <typename Scalar>
struct AllreduceDblRawMsg
: Message {
using MessageParentType = vt::Message;
Expand All @@ -99,34 +69,32 @@ struct AllreduceDblRawMsg
AllreduceDblRawMsg(AllreduceDblRawMsg&&) = default;
~AllreduceDblRawMsg() {
if (owning_) {
delete[] val_;
delete val_;
}
}

AllreduceDblRawMsg(std::vector<Scalar>& in_val, int step = 0)
AllreduceDblRawMsg(DataT const& in_val, size_t id, int step = 0)
: MessageParentType(),
val_(in_val.data()),
size_(in_val.size()),
val_(&in_val),
id_(id),
step_(step) { }

template <typename SerializeT>
void serialize(SerializeT& s) {
MessageParentType::serialize(s);

s | size_;

if (s.isUnpacking()) {
owning_ = true;
val_ = new Scalar[size_];
val_ = new DataT();
}

checkpoint::dispatch::serializeArray(s, val_, size_);

s | *val_;
s | id_;
s | step_;
}

Scalar* val_ = {};
size_t size_ = {};
const DataT* val_ = {};
size_t id_ = {};
int32_t step_ = {};
bool owning_ = false;
};
Expand Down Expand Up @@ -158,103 +126,126 @@ struct RecursiveDoubling {
* \param num_nodes The number of nodes.
* \param data Additional arguments for data initialization.
*/
template <typename... Args>
RecursiveDoubling(
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
const DataT& data);
Args&&... data);

/**
* \brief Start the allreduce operation.
*/
void allreduce();
void allreduce(size_t id);

/**
* \brief Initialize the RecursiveDoubling object.
*
* \param data Additional arguments for data initialization.
*/
void initialize(const DataT& data);
template <typename... Args>
void initialize(size_t id, Args&&... data);
void initializeState(size_t id);

size_t generateNewId() { return id_++; }

/**
* \brief Adjust for power of two nodes.
*/
void adjustForPowerOfTwo();
void adjustForPowerOfTwo(size_t id);

/**
* \brief Handler for adjusting for power of two nodes.
*
* \param msg Pointer to the message.
*/
void adjustForPowerOfTwoHandler(AllreduceDblRawMsg<Scalar>* msg);
void adjustForPowerOfTwoHandler(AllreduceDblRawMsg<DataT>* msg);

/**
* \brief Check if the allreduce operation is done.
*
* \return True if the operation is done, otherwise false.
*/
bool done();
bool isDone(size_t id);

/**
* \brief Check if the current state is valid for allreduce.
*
* \return True if the state is valid, otherwise false.
*/
bool isValid();
bool isValid(size_t id);

/**
* \brief Check if all messages are received for the current step.
*
* \return True if all messages are received, otherwise false.
*/
bool allMessagesReceived();
bool allMessagesReceived(size_t id);

/**
* \brief Check if the object is ready for the next step of allreduce.
*
* \return True if ready, otherwise false.
*/
bool isReady();
bool isReady(size_t id);

/**
* \brief Perform the next step of the allreduce operation.
*/
void reduceIter();
void reduceIter(size_t id);

/**
* \brief Try to reduce the message at the specified step.
*
* \param step The step at which to try reduction.
*/
void tryReduce(int32_t step);
void tryReduce(size_t id, int32_t step);

/**
* \brief Handler for the reduce iteration.
*
* \param msg Pointer to the message.
*/
void reduceIterHandler(AllreduceDblRawMsg<Scalar>* msg);
void reduceIterHandler(AllreduceDblRawMsg<DataT>* msg);

/**
* \brief Send data to excluded nodes for finalization.
*/
void sendToExcludedNodes();
void sendToExcludedNodes(size_t id);

/**
* \brief Handler for sending data to excluded nodes.
*
* \param msg Pointer to the message.
*/
void sendToExcludedNodesHandler(AllreduceDblRawMsg<Scalar>* msg);
void sendToExcludedNodesHandler(AllreduceDblRawMsg<DataT>* msg);

/**
* \brief Perform the final part of the allreduce operation.
*/
void finalPart();
void finalPart(size_t id);

vt::objgroup::proxy::Proxy<RecursiveDoubling> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

// DataT val_ = {};
std::vector<Scalar> val_;

// Bookkeeping bundle whose instances are stored in the states_ map, keyed by a
// size_t (presumably the allreduce id passed to the member functions --
// TODO confirm against the implementation file).
struct State{
// Value of type DataT held by this state; value-initialized by default.
DataT val_ = {};
// Starts false. NOTE(review): presumably set once the power-of-two adjustment
// phase has finished -- confirm.
bool finished_adjustment_part_ = false;
// Starts as nullptr. NOTE(review): looks like a message kept from the
// adjustment phase -- confirm where it is assigned.
MsgSharedPtr<AllreduceDblRawMsg<DataT>> adjust_message_ = nullptr;

// Starts at 1.
int32_t mask_ = 1;
// Starts at 0.
int32_t step_ = 0;
// Both flags start false.
bool initialized_ = false;
bool completed_ = false;

// Boolean flags, empty by default. NOTE(review): presumably indexed by step
// ("received" / "reduced") -- confirm sizing and indexing in the implementation.
std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_reduced_ = {};
// Shared pointers to AllreduceDblRawMsg<DataT> messages; empty by default.
std::vector<MsgSharedPtr<AllreduceDblRawMsg<DataT>>> messages_ = {};
};

size_t id_ = 0;
std::unordered_map<size_t, State> states_ = {};

NodeType num_nodes_ = {};
NodeType this_node_ = {};

Expand All @@ -265,17 +256,6 @@ struct RecursiveDoubling {

NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
bool finished_adjustment_part_ = false;

int32_t mask_ = 1;
int32_t step_ = 0;

bool completed_ = false;

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_reduced_ = {};

std::vector<MsgSharedPtr<AllreduceDblRawMsg<Scalar>>> messages_ = {};
};

} // namespace vt::collective::reduce::allreduce
Expand Down
Loading

0 comments on commit 593132f

Please sign in to comment.