Skip to content

Commit

Permalink
#2240: Provide documentation for RecursiveDoubling algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 18d5090 commit 9d88016
Show file tree
Hide file tree
Showing 4 changed files with 414 additions and 189 deletions.
303 changes: 116 additions & 187 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,199 +88,126 @@ struct AllreduceDblMsg
int32_t step_ = {};
};

/**
* \brief Class implementing the Recursive Doubling algorithm for allreduce operation.
*
* This class provides an implementation of the Recursive Doubling algorithm for the
* allreduce operation. It is parameterized by the data type to be reduced, the reduction
* operation, the object type, and the final handler.
*
* \tparam DataT The data type to be reduced.
* \tparam Op The reduction operation type.
* \tparam ObjT The object type.
* \tparam finalHandler The final handler.
*/
template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
struct DistanceDoubling {
struct RecursiveDoubling {
/**
* \brief Constructor for RecursiveDoubling class.
*
* Initializes the RecursiveDoubling object with the provided parameters.
*
* \param parentProxy The parent proxy.
* \param num_nodes The number of nodes.
* \param args Additional arguments for data initialization.
*/
template <typename... Args>
DistanceDoubling(
RecursiveDoubling(
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
Args&&... args)
: parent_proxy_(parentProxy),
num_nodes_(num_nodes),
this_node_(vt::theContext()->getNode()),
is_even_(this_node_ % 2 == 0),
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
nprocs_pof2_(1 << num_steps_),
nprocs_rem_(num_nodes_ - nprocs_pof2_),
finished_adjustment_part_(nprocs_rem_ == 0) {
initialize(std::forward<Args>(args)...);
}

Args&&... args);

/**
* \brief Start the allreduce operation.
*/
void allreduce();

/**
* \brief Initialize the RecursiveDoubling object.
*
* \param args Additional arguments for data initialization.
*/
template <typename... Args>
void initialize(Args&&... args) {
  // Build the local contribution from the forwarded constructor arguments.
  val_ = DataT(std::forward<Args>(args)...);

  // The first 2 * nprocs_rem_ nodes take part in the power-of-two adjustment.
  is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);

  if (not is_part_of_adjustment_group_) {
    // Nodes outside the adjustment group are shifted down by nprocs_rem_.
    vrt_node_ = this_node_ - nprocs_rem_;
  } else if (is_even_) {
    // Even nodes of the adjustment group keep a virtual rank.
    vrt_node_ = this_node_ / 2;
  } else {
    // Odd nodes of the adjustment group get no virtual rank.
    vrt_node_ = -1;
  }

  // One slot per step for the received message and its bookkeeping flags.
  messages_.resize(num_steps_, nullptr);
  steps_recv_.resize(num_steps_, false);
  steps_reduced_.resize(num_steps_, false);
}

void allreduce() {
  // With no remainder nodes there is nothing to adjust: go straight to the
  // iterative reduction.
  if (not nprocs_rem_) {
    reduceIter();
    return;
  }

  // Otherwise start with the power-of-two adjustment.
  adjustForPowerOfTwo();
}

void adjustForPowerOfTwo() {
  // Only odd nodes of the adjustment group send in this part.
  if (not is_part_of_adjustment_group_ or is_even_) {
    return;
  }

  // The receiver is the preceding (even) node.
  const auto partner = this_node_ - 1;

  if constexpr (isdebug) {
    fmt::print("[{}] Part1: Sending to Node {} \n", this_node_, partner);
  }

  proxy_[partner]
    .template send<&DistanceDoubling::adjustForPowerOfTwoHandler>(val_);
}

/**
 * \brief Handler for the value sent by \c adjustForPowerOfTwo().
 *
 * Combines the received value into \c val_ using \c Op, marks the adjustment
 * part as finished and then calls \c reduceIter().
 *
 * \param msg Message whose \c val_ is combined into the local value.
 */
void adjustForPowerOfTwoHandler(AllreduceDblMsg<DataT>* msg) {
  if constexpr (isdebug) {
    // Start from an empty string and only reserve capacity. Constructing the
    // string as (1024, 0x0) would put 1024 NUL characters in front of the
    // formatted values.
    std::string data;
    data.reserve(1024);
    for (const auto& val : msg->val_) {
      data.append(fmt::format("{} ", val));
    }
    fmt::print(
      "[{}] Part1 Handler: Received data ({}) "
      "from {}\n",
      this_node_, data, theContext()->getFromNodeCurrentTask());
  }

  Op<DataT>()(val_, msg->val_);

  finished_adjustment_part_ = true;

  reduceIter();
}

bool done() {
  // Not done until the step counter has reached the total number of steps.
  if (step_ != num_steps_) {
    return false;
  }
  return allMessagesReceived();
}
bool isValid() {
  // A virtual rank of -1 marks a node without a virtual rank.
  const bool has_virtual_rank = vrt_node_ != -1;
  const bool has_steps_left = step_ < num_steps_;
  return has_virtual_rank and has_steps_left;
}
bool allMessagesReceived() {
  // Inspect the receive flags of every step before the current one; the
  // result is true when none of them is still unset (also for step_ == 0).
  const auto first = steps_recv_.cbegin();
  const auto last = steps_recv_.cbegin() + step_;
  return std::none_of(
    first, last, [](const auto received) { return not received; });
}
bool isReady() {
  // Ready at step 0 when this node is in the adjustment group and its
  // adjustment part has finished ...
  const bool adjusted_at_first_step =
    is_part_of_adjustment_group_ and finished_adjustment_part_ and step_ == 0;

  // ... or whenever all messages for the preceding steps have arrived.
  return adjusted_at_first_step or allMessagesReceived();
}

// Performs one step of the iterative exchange: sends the local value to this
// step's partner, advances mask_/step_, tries to reduce the message for the
// step just sent, and then either finishes or recurses into the next step.
void reduceIter() {
// Ensure we have received all necessary messages
if (not isReady()) {
return;
}

// Partner in the virtual numbering: flip the bit selected by mask_.
auto vdest = vrt_node_ ^ mask_;
// Map the virtual partner back to a real node: virtual ranks below
// nprocs_rem_ correspond to even nodes (vdest * 2), the remaining ones are
// shifted up by nprocs_rem_.
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;
if constexpr (isdebug) {
fmt::print(
"[{}] Part2 Step {}: Sending to Node {} \n", this_node_, step_, dest);
}

// The message carries the step it belongs to, so the receiver can store it
// in the matching slot.
proxy_[dest].template send<&DistanceDoubling::reduceIterHandler>(
val_, step_);

// Advance to the next step before reducing the one that was just sent.
mask_ <<= 1;
step_++;

// Reduce the partner's message for the step just sent, if it already arrived.
tryReduce(step_ - 1);

if (done()) {
finalPart();
} else if (isReady()) {
// All messages for the preceding steps are already here: continue
// immediately with the next step.
reduceIter();
}
}

void tryReduce(int32_t step) {
  // Only steps that this node has already sent can be reduced.
  if (not (step < step_)) {
    return;
  }

  // Skip steps that were already reduced or whose message has not arrived.
  if (steps_reduced_[step] or not steps_recv_[step]) {
    return;
  }

  // Reduction happens in step order: every earlier step must be reduced first.
  const bool earlier_steps_reduced = std::all_of(
    steps_reduced_.cbegin(), steps_reduced_.cbegin() + step,
    [](const auto reduced) { return reduced; });
  if (not earlier_steps_reduced) {
    return;
  }

  Op<DataT>()(val_, messages_.at(step)->val_);
  steps_reduced_[step] = true;
}

/**
 * \brief Handler for a message sent by \c reduceIter().
 *
 * Stores the message in the slot of its step and marks that step as
 * received. Once the adjustment part has finished, it tries to reduce the
 * message and then either continues with \c reduceIter() or finishes with
 * \c finalPart().
 *
 * \param msg Message carrying the sender's value and the step it belongs to.
 */
void reduceIterHandler(AllreduceDblMsg<DataT>* msg) {
  if constexpr (isdebug) {
    // Start from an empty string and only reserve capacity. Constructing the
    // string as (1024, 0x0) would put 1024 NUL characters in front of the
    // formatted values.
    std::string data;
    data.reserve(1024);
    for (const auto& val : msg->val_) {
      data.append(fmt::format("{} ", val));
    }
    fmt::print(
      "[{}] Part2 Step {} mask_= {} nprocs_pof2_ = {}: "
      "Received data ({}) "
      "from {}\n",
      this_node_, msg->step_, mask_, nprocs_pof2_, data,
      theContext()->getFromNodeCurrentTask());
  }

  // Keep the message so it can be reduced later, in step order.
  messages_.at(msg->step_) = promoteMsg(msg);
  steps_recv_[msg->step_] = true;

  // Special case when we receive step 2 message before step 1 is done on this node
  if (not finished_adjustment_part_) {
    return;
  }

  tryReduce(msg->step_);

  if ((mask_ < nprocs_pof2_) and isReady()) {
    reduceIter();

  } else if (done()) {
    finalPart();
  }
}

void sendToExcludedNodes() {
  // Only even nodes of the adjustment group forward the result.
  const bool forwards_result = is_part_of_adjustment_group_ and is_even_;
  if (not forwards_result) {
    return;
  }

  // The receiver is the following (odd) node.
  const auto excluded = this_node_ + 1;

  if constexpr (isdebug) {
    fmt::print("[{}] Part3 : Sending to Node {} \n", this_node_, excluded);
  }

  proxy_[excluded]
    .template send<&DistanceDoubling::sendToExcludedNodesHandler>(val_);
}

// Handler for the value sent by sendToExcludedNodes(): takes over the
// received value, invokes finalHandler on this node's parent object and marks
// this node as completed.
void sendToExcludedNodesHandler(AllreduceDblMsg<DataT>* msg) {
// The received value replaces the local one (no reduction is applied here).
val_ = msg->val_;

parent_proxy_[this_node_].template invoke<finalHandler>(val_);
completed_ = true;
}

void finalPart() {
  // Run the final part at most once per node.
  if (not completed_) {
    // With remainder nodes present, forward the result before finishing.
    if (nprocs_rem_) {
      sendToExcludedNodes();
    }

    parent_proxy_[this_node_].template invoke<finalHandler>(val_);
    completed_ = true;
  }
}

vt::objgroup::proxy::Proxy<DistanceDoubling> proxy_ = {};
void initialize(Args&&... args);

/**
* \brief Adjust for power of two nodes.
*/
void adjustForPowerOfTwo();

/**
* \brief Handler for adjusting for power of two nodes.
*
* \param msg Pointer to the message.
*/
void adjustForPowerOfTwoHandler(AllreduceDblMsg<DataT>* msg);

/**
* \brief Check if the allreduce operation is done.
*
* \return True if the operation is done, otherwise false.
*/
bool done();

/**
* \brief Check if the current state is valid for allreduce.
*
* \return True if the state is valid, otherwise false.
*/
bool isValid();

/**
* \brief Check if the messages for all steps before the current one have been received.
*
* \return True if all messages are received, otherwise false.
*/
bool allMessagesReceived();

/**
* \brief Check if the object is ready for the next step of allreduce.
*
* \return True if ready, otherwise false.
*/
bool isReady();

/**
* \brief Perform the next step of the allreduce operation.
*/
void reduceIter();

/**
* \brief Try to reduce the message at the specified step.
*
* \param step The step at which to try reduction.
*/
void tryReduce(int32_t step);

/**
* \brief Handler for the reduce iteration.
*
* \param msg Pointer to the message.
*/
void reduceIterHandler(AllreduceDblMsg<DataT>* msg);

/**
* \brief Send data to excluded nodes for finalization.
*/
void sendToExcludedNodes();

/**
* \brief Handler for sending data to excluded nodes.
*
* \param msg Pointer to the message.
*/
void sendToExcludedNodesHandler(AllreduceDblMsg<DataT>* msg);

/**
* \brief Perform the final part of the allreduce operation.
*/
void finalPart();

vt::objgroup::proxy::Proxy<RecursiveDoubling> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

DataT val_ = {};
Expand Down Expand Up @@ -309,4 +236,6 @@ struct DistanceDoubling {

} // namespace vt::collective::reduce::allreduce

#include "recursive_doubling.impl.h"

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RECURSIVE_DOUBLING_H*/
Loading

0 comments on commit 9d88016

Please sign in to comment.