Skip to content

Commit

Permalink
#2240: Working Rabenseifner (non-commutative ops)
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent f90e5c7 commit 0715c52
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 72 deletions.
218 changes: 150 additions & 68 deletions src/vt/collective/reduce/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@

namespace vt::collective::reduce::allreduce {

constexpr bool debug = false;

template <typename DataT>
struct AllreduceMsg
: SerializeIfNeeded<vt::Message, AllreduceMsg<DataT>, DataT> {
Expand Down Expand Up @@ -137,37 +139,37 @@ struct Allreduce {
}
}

// std::string str(1024, 0x0);
// for (int i = 0; i < num_steps_; ++i) {
// str.append(fmt::format(
// "Step{}: send_idx = {} send_count = {} recieve_idx = {} recieve_count "
// "= {}\n",
// i, s_index_[i], s_count_[i], r_index_[i], r_count_[i]));
// }
// fmt::print(
// "[{}] Initialize with size = {} num_steps {} \n {}", this_node_, w_size_,
// num_steps_, str);
expected_send_ = num_steps_;
expected_recv_ = num_steps_;
steps_sent_.resize(num_steps_, false);
steps_recv_.resize(num_steps_, false);

if constexpr (debug) {
std::string str(1024, 0x0);
for (int i = 0; i < num_steps_; ++i) {
str.append(fmt::format(
"Step{}: send_idx = {} send_count = {} recieve_idx = {} "
"recieve_count "
"= {}\n",
i, s_index_[i], s_count_[i], r_index_[i], r_count_[i]));
}
fmt::print(
"[{}] Initialize with size = {} num_steps {} \n {}", this_node_,
w_size_, num_steps_, str);
}
}

void partOneCollective() {
void partOne() {
if (is_part_of_adjustment_group_) {
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1;

if (is_even_) {
proxy_[partner].template send<&Allreduce::partOneRightHalf>(
std::vector<int32_t>{val_.begin() + (val_.size() / 2), val_.end()});
vrt_node_ = this_node_ / 2;
DataT{val_.begin() + (val_.size() / 2), val_.end()});
} else {
proxy_[partner].template send<&Allreduce::partOneLeftHalf>(
std::vector<int32_t>{val_.begin(), val_.end() - (val_.size() / 2)});
vrt_node_ = -1;
DataT{val_.begin(), val_.end() - (val_.size() / 2)});
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}

if (nprocs_rem_ == 0) {
partTwo();
}
}

Expand All @@ -179,7 +181,7 @@ struct Allreduce {
// Send to left node
proxy_[theContext()->getNode() - 1]
.template send<&Allreduce::partOneFinalPart>(
std::vector<int32_t>{val_.begin() + (val_.size() / 2), val_.end()});
DataT{val_.begin() + (val_.size() / 2), val_.end()});
}

void partOneLeftHalf(AllreduceMsg<DataT>* msg) {
Expand All @@ -193,95 +195,167 @@ struct Allreduce {
val_[(val_.size() / 2) + i] = msg->val_[i];
}

partTwo();
// partTwo();
}

void partTwo() {
if (
vrt_node_ == -1 or (step_ >= num_steps_) or
(not std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; }))) {
return;
}

auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

// fmt::print(
// "[{}] Part2 Step {}: Sending to Node {} starting with idx = {} and count "
// "{} \n",
// this_node_, step_, dest, s_index_[step_], s_count_[step_]);
if constexpr (debug) {
fmt::print(
"[{}] Part2 Step {}: Sending to Node {} starting with idx = {} and "
"count "
"{} \n",
this_node_, step_, dest, s_index_[step_], s_count_[step_]);
}
proxy_[dest].template send<&Allreduce::partTwoHandler>(
std::vector<int32_t>{
DataT{
val_.begin() + (s_index_[step_]),
val_.begin() + (s_index_[step_]) + s_count_[step_]},
step_);

mask_ <<= 1;
if (step_ + 1 < num_steps_) {
step_++;
num_send_++;
steps_sent_[step_] = true;
step_++;

if (std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; })) {
partTwo();
}
}

void partTwoHandler(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[r_index_[msg->step_] + i] += msg->val_[i];
}

// std::string data(128, 0x0);
// for (auto val : msg->val_) {
// data.append(fmt::format("{} ", val));
// }
// fmt::print(
// "[{}] Part2 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
// msg->step_, data, r_index_[msg->step_],
// theContext()->getFromNodeCurrentTask());

if constexpr (debug) {
std::string data(128, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part2 Step {} mask_= {} nprocs_pof2_ = {}: Received data ({}) "
"idx = {} from {}\n",
this_node_, msg->step_, mask_, nprocs_pof2_, data, r_index_[msg->step_],
theContext()->getFromNodeCurrentTask());
}
steps_recv_[msg->step_] = true;
num_recv_++;
if (mask_ < nprocs_pof2_) {
partTwo();
if (std::all_of(
steps_recv_.cbegin(), steps_recv_.cbegin() + step_,
[](const auto val) { return val; })) {
partTwo();
}
} else {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
partThree();
// step_ = num_steps_ - 1;
// mask_ = nprocs_pof2_ >> 1;
// partThree();
}
}

void partThree() {
if (
vrt_node_ == -1 or
(not std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; }))) {
return;
}

if (not startedPartThree_) {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
num_send_ = 0;
num_recv_ = 0;
startedPartThree_ = true;
std::fill(steps_sent_.begin(), steps_sent_.end(), false);
std::fill(steps_recv_.begin(), steps_recv_.end(), false);
}

auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

// std::string data(128, 0x0);
// auto subV = std::vector<int32_t>{
// val_.begin() + (r_index_[step_]),
// val_.begin() + (r_index_[step_]) + r_count_[step_]};
// for (auto val : subV) {
// data.append(fmt::format("{} ", val));
// }

// fmt::print(
// "[{}] Part3 Step {}: Sending to Node {} starting with idx = {} and count "
// "{} "
// "data={} \n",
// this_node_, step_, dest, r_index_[step_], r_count_[step_], data);
if constexpr (debug) {
std::string data(128, 0x0);
auto subV = std::vector<int32_t>{
val_.begin() + (r_index_[step_]),
val_.begin() + (r_index_[step_]) + r_count_[step_]};
for (auto val : subV) {
data.append(fmt::format("{} ", val));
}

fmt::print(
"[{}] Part3 Step {}: Sending to Node {} starting with idx = {} and "
"count "
"{} "
"data={} \n",
this_node_, step_, dest, r_index_[step_], r_count_[step_], data);
}
proxy_[dest].template send<&Allreduce::partThreeHandler>(
std::vector<int32_t>{
DataT{
val_.begin() + (r_index_[step_]),
val_.begin() + (r_index_[step_]) + r_count_[step_]},
step_);

steps_sent_[step_] = true;
num_send_++;
mask_ >>= 1;
step_--;
if (
step_ >= 0 and
std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; })) {
partThree();
}
}

void partThreeHandler(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[s_index_[msg->step_] + i] = msg->val_[i];
}

// std::string data(128, 0x0);
// for (auto val : msg->val_) {
// data.append(fmt::format("{} ", val));
// }
// fmt::print(
// "[{}] Part3 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
// msg->step_, data, s_index_[msg->step_],
// theContext()->getFromNodeCurrentTask());
if (not startedPartThree_) {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
num_send_ = 0;
num_recv_ = 0;
startedPartThree_ = true;
std::fill(steps_sent_.begin(), steps_sent_.end(), false);
std::fill(steps_recv_.begin(), steps_recv_.end(), false);
}

if (mask_ > 0) {
num_recv_++;
if constexpr (debug) {
std::string data(128, 0x0);
for (auto val : msg->val_) {
data.append(fmt::format("{} ", val));
}
fmt::print(
"[{}] Part3 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
msg->step_, data, s_index_[msg->step_],
theContext()->getFromNodeCurrentTask());
}

steps_recv_[msg->step_] = true;

if (
mask_ > 0 and
((step_ == num_steps_ - 1) or
std::all_of(
steps_recv_.cbegin() + step_ + 1, steps_recv_.cend(),
[](const auto val) { return val; }))) {
partThree();
}
}
Expand All @@ -296,9 +370,17 @@ struct Allreduce {
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};
int32_t mask_ = 1;
bool startedPartThree_ = false;

size_t w_size_ = {};
int32_t step_ = 0;
int32_t num_send_ = 0;
int32_t expected_send_ = 0;
int32_t num_recv_ = 0;
int32_t expected_recv_ = 0;

std::vector<bool> steps_recv_ = {};
std::vector<bool> steps_sent_ = {};
std::vector<int32_t> r_index_ = {};
std::vector<int32_t> r_count_ = {};
std::vector<int32_t> s_index_ = {};
Expand Down
10 changes: 9 additions & 1 deletion src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,15 @@ ObjGroupManager::allreduce_r(ProxyType<ObjT> proxy, const DataT& data) {
data, grp_proxy, num_nodes);

vt::runInEpochCollective([=] {
grp_proxy[this_node].template invoke<&Reducer::partOneCollective>();
grp_proxy[this_node].template invoke<&Reducer::partOne>();
});

vt::runInEpochCollective([=] {
grp_proxy[this_node].template invoke<&Reducer::partTwo>();
});

vt::runInEpochCollective([=] {
grp_proxy[this_node].template invoke<&Reducer::partThree>();
});

proxy[this_node].template invoke<f>(grp_proxy.get()->val_);
Expand Down
7 changes: 6 additions & 1 deletion tests/perf/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
*/
#include "common/test_harness.h"
#include "vt/collective/collective_alg.h"
#include "vt/configs/error/config_assert.h"
#include "vt/context/context.h"
#include <unordered_map>
#include <vt/collective/collective_ops.h>
Expand Down Expand Up @@ -73,7 +74,11 @@ struct NodeObj {
// fmt::print(
// "\n[{}]: allreduce_h done! (Size == {}) Results are ...\n",
// theContext()->getNode(), in.size());

// const auto p = theContext()->getNumNodes();
// const auto expected = (p * (p + 1)) / 2;
// for (auto val : in) {
// vtAssert(val == expected, "FAILURE!");
// }
// for (int node = 0; node < theContext()->getNumNodes(); ++node) {
// if (node == theContext()->getNode()) {
// std::string printer(128, 0x0);
Expand Down
4 changes: 2 additions & 2 deletions tests/perf/send_cost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,9 @@ VT_PERF_TEST(SendTest, test_objgroup_send) {
auto const thisNode = vt::theContext()->getNode();
auto const lastNode = theContext()->getNumNodes() - 1;

int nsteps = 2;
int nsteps = static_cast<int32_t>(log2(theContext()->getNumNodes()));
auto nprocs_rem = 0;
size_t count = 32; //1 << 6;
size_t count = 16; //1 << 6;
auto* buf = (int32_t*)malloc(sizeof(int32_t) * count);
auto nprocs_pof2 = 1 << nsteps;
auto rank = theContext()->getNode();
Expand Down

0 comments on commit 0715c52

Please sign in to comment.