Skip to content

Commit

Permalink
#2240: Initial work for adding recursive doubling allreduce algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Apr 10, 2024
1 parent 193f6ed commit ce017d9
Show file tree
Hide file tree
Showing 6 changed files with 513 additions and 325 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_H
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DISTANCE_DOUBLING_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DISTANCE_DOUBLING_H

#include "vt/config.h"
#include "vt/context/context.h"
Expand All @@ -54,22 +54,22 @@

namespace vt::collective::reduce::allreduce {

constexpr bool debug = false;
constexpr bool isdebug = false;

template <typename DataT>
struct AllreduceMsg
: SerializeIfNeeded<vt::Message, AllreduceMsg<DataT>, DataT> {
struct AllreduceDblMsg
: SerializeIfNeeded<vt::Message, AllreduceDblMsg<DataT>, DataT> {
using MessageParentType =
SerializeIfNeeded<::vt::Message, AllreduceMsg<DataT>, DataT>;
SerializeIfNeeded<::vt::Message, AllreduceDblMsg<DataT>, DataT>;

AllreduceMsg() = default;
AllreduceMsg(AllreduceMsg const&) = default;
AllreduceMsg(AllreduceMsg&&) = default;
AllreduceDblMsg() = default;
AllreduceDblMsg(AllreduceDblMsg const&) = default;
AllreduceDblMsg(AllreduceDblMsg&&) = default;

explicit AllreduceMsg(DataT&& in_val)
explicit AllreduceDblMsg(DataT&& in_val)
: MessageParentType(),
val_(std::forward<DataT>(in_val)) { }
explicit AllreduceMsg(DataT const& in_val, int step = 0)
explicit AllreduceDblMsg(DataT const& in_val, int step = 0)
: MessageParentType(),
val_(in_val),
step_(step) { }
Expand All @@ -86,9 +86,9 @@ struct AllreduceMsg
};

template <typename DataT>
struct Allreduce {
struct DistanceDoubling {
void initialize(
const DataT& data, vt::objgroup::proxy::Proxy<Allreduce> proxy,
const DataT& data, vt::objgroup::proxy::Proxy<DistanceDoubling> proxy,
uint32_t num_nodes) {
this_node_ = vt::theContext()->getNode();
is_even_ = this_node_ % 2 == 0;
Expand All @@ -108,96 +108,27 @@ struct Allreduce {
vrt_node_ = this_node_ - nprocs_rem_;
}

r_index_.resize(num_steps_, 0);
r_count_.resize(num_steps_, 0);
s_index_.resize(num_steps_, 0);
s_count_.resize(num_steps_, 0);

w_size_ = data.size();

int step = 0;
size_t wsize = data.size();
for (int mask = 1; mask < nprocs_pof2_; mask <<= 1) {
auto vdest = vrt_node_ ^ mask;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

if (this_node_ < dest) {
r_count_[step] = wsize / 2;
s_count_[step] = wsize - r_count_[step];
s_index_[step] = r_index_[step] + r_count_[step];
} else {
s_count_[step] = wsize / 2;
r_count_[step] = wsize - s_count_[step];
r_index_[step] = s_index_[step] + s_count_[step];
}

if (step + 1 < num_steps_) {
r_index_[step + 1] = r_index_[step];
s_index_[step + 1] = r_index_[step];
wsize = r_count_[step];
step++;
}
}

expected_send_ = num_steps_;
expected_recv_ = num_steps_;
steps_sent_.resize(num_steps_, false);
steps_recv_.resize(num_steps_, false);

if constexpr (debug) {
std::string str(1024, 0x0);
for (int i = 0; i < num_steps_; ++i) {
str.append(fmt::format(
"Step{}: send_idx = {} send_count = {} recieve_idx = {} "
"recieve_count "
"= {}\n",
i, s_index_[i], s_count_[i], r_index_[i], r_count_[i]));
}
fmt::print(
"[{}] Initialize with size = {} num_steps {} \n {}", this_node_,
w_size_, num_steps_, str);
}
}

void partOne() {
if (is_part_of_adjustment_group_) {
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1;

if (is_even_) {
proxy_[partner].template send<&Allreduce::partOneRightHalf>(
DataT{val_.begin() + (val_.size() / 2), val_.end()});
} else {
proxy_[partner].template send<&Allreduce::partOneLeftHalf>(
DataT{val_.begin(), val_.end() - (val_.size() / 2)});
}
}
}

void partOneRightHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] += msg->val_[i];
if (is_part_of_adjustment_group_ and is_even_) {
proxy_[this_node_ + 1].template send<&DistanceDoubling::partOneHandler>(
val_);
}

// Send to left node
proxy_[theContext()->getNode() - 1]
.template send<&Allreduce::partOneFinalPart>(
DataT{val_.begin() + (val_.size() / 2), val_.end()});
}

void partOneLeftHalf(AllreduceMsg<DataT>* msg) {
void partOneHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[i] += msg->val_[i];
}
}

void partOneFinalPart(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] = msg->val_[i];
}

// partTwo();
}

void partTwo() {
if (
vrt_node_ == -1 or (step_ >= num_steps_) or
Expand All @@ -209,18 +140,11 @@ struct Allreduce {

auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;
if constexpr (debug) {
if constexpr (isdebug) {
fmt::print(
"[{}] Part2 Step {}: Sending to Node {} starting with idx = {} and "
"count "
"{} \n",
this_node_, step_, dest, s_index_[step_], s_count_[step_]);
"[{}] Part2 Step {}: Sending to Node {} \n", this_node_, step_, dest);
}
proxy_[dest].template send<&Allreduce::partTwoHandler>(
DataT{
val_.begin() + (s_index_[step_]),
val_.begin() + (s_index_[step_]) + s_count_[step_]},
step_);
proxy_[dest].template send<&DistanceDoubling::partTwoHandler>(val_, step_);

mask_ <<= 1;
num_send_++;
Expand All @@ -234,19 +158,19 @@ struct Allreduce {
}
}

void partTwoHandler(AllreduceMsg<DataT>* msg) {
void partTwoHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[r_index_[msg->step_] + i] += msg->val_[i];
val_[i] += msg->val_[i];
}
if constexpr (debug) {
if constexpr (isdebug) {
std::string data(128, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part2 Step {} mask_= {} nprocs_pof2_ = {}: Received data ({}) "
"idx = {} from {}\n",
this_node_, msg->step_, mask_, nprocs_pof2_, data, r_index_[msg->step_],
this_node_, msg->step_, mask_, nprocs_pof2_, data,
theContext()->getFromNodeCurrentTask());
}
steps_recv_[msg->step_] = true;
Expand Down Expand Up @@ -286,27 +210,19 @@ struct Allreduce {
auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

if constexpr (debug) {
std::string data(128, 0x0);
auto subV = std::vector<int32_t>{
val_.begin() + (r_index_[step_]),
val_.begin() + (r_index_[step_]) + r_count_[step_]};
for (auto val : subV) {
if constexpr (isdebug) {
std::string data(1024, 0x0);

for (auto val : val_) {
data.append(fmt::format("{} ", val));
}

fmt::print(
"[{}] Part3 Step {}: Sending to Node {} starting with idx = {} and "
"count "
"{} "
"data={} \n",
this_node_, step_, dest, r_index_[step_], r_count_[step_], data);
"[{}] Part3 Step {}: Sending to Node {} data={} \n", this_node_, step_,
dest, data);
}
proxy_[dest].template send<&Allreduce::partThreeHandler>(
DataT{
val_.begin() + (r_index_[step_]),
val_.begin() + (r_index_[step_]) + r_count_[step_]},
step_);
proxy_[dest].template send<&DistanceDoubling::partThreeHandler>(
val_, step_);

steps_sent_[step_] = true;
num_send_++;
Expand All @@ -321,9 +237,9 @@ struct Allreduce {
}
}

void partThreeHandler(AllreduceMsg<DataT>* msg) {
void partThreeHandler(AllreduceDblMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[s_index_[msg->step_] + i] = msg->val_[i];
val_[i] = msg->val_[i];
}

if (not startedPartThree_) {
Expand All @@ -337,15 +253,14 @@ struct Allreduce {
}

num_recv_++;
if constexpr (debug) {
if constexpr (isdebug) {
std::string data(128, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part3 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
msg->step_, data, s_index_[msg->step_],
theContext()->getFromNodeCurrentTask());
"[{}] Part3 Step {}: Received data ({}) from {}\n", this_node_,
msg->step_, data, theContext()->getFromNodeCurrentTask());
}

steps_recv_[msg->step_] = true;
Expand All @@ -362,19 +277,20 @@ struct Allreduce {

void partFour() {
if (is_part_of_adjustment_group_ and is_even_) {
if constexpr (debug) {
if constexpr (isdebug) {
fmt::print(
"[{}] Part4 : Sending to Node {} \n", this_node_, this_node_ + 1);
}
proxy_[this_node_ + 1].template send<&Allreduce::partFourHandler>(val_);
proxy_[this_node_ + 1].template send<&DistanceDoubling::partFourHandler>(
val_);
}
}

void partFourHandler(AllreduceMsg<DataT>* msg) { val_ = msg->val_; }
void partFourHandler(AllreduceDblMsg<DataT>* msg) { val_ = msg->val_; }

NodeType this_node_ = {};
bool is_even_ = false;
vt::objgroup::proxy::Proxy<Allreduce> proxy_ = {};
vt::objgroup::proxy::Proxy<DistanceDoubling> proxy_ = {};
DataT val_ = {};
NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
Expand All @@ -393,12 +309,8 @@ struct Allreduce {

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_sent_ = {};
std::vector<int32_t> r_index_ = {};
std::vector<int32_t> r_count_ = {};
std::vector<int32_t> s_index_ = {};
std::vector<int32_t> s_count_ = {};
};

} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_REDUCE_H*/
#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H*/
Loading

0 comments on commit ce017d9

Please sign in to comment.