Skip to content

Commit

Permalink
#2240: Semi working Rabenseifner
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent e7e1a2e commit f90e5c7
Show file tree
Hide file tree
Showing 8 changed files with 679 additions and 58 deletions.
230 changes: 205 additions & 25 deletions src/vt/collective/reduce/allreduce/allreduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@
#include "vt/config.h"
#include "vt/context/context.h"
#include "vt/messaging/message/message.h"
#include "vt/objgroup/proxy/proxy_objgroup.h"

#include <tuple>
#include <cstdint>

namespace vt::collective::reduce::alleduce {
namespace vt::collective::reduce::allreduce {

template <typename DataT>
struct AllreduceMsg
Expand All @@ -66,65 +67,244 @@ struct AllreduceMsg
explicit AllreduceMsg(DataT&& in_val)
: MessageParentType(),
val_(std::forward<DataT>(in_val)) { }
explicit AllreduceMsg(DataT const& in_val)
explicit AllreduceMsg(DataT const& in_val, int step = 0)
: MessageParentType(),
val_(in_val) { }
val_(in_val),
step_(step) { }

template <typename SerializeT>
void serialize(SerializeT& s) {
MessageParentType::serialize(s);
s | val_;
s | step_;
}

DataT val_ = {};
int32_t step_ = {};
};

template <typename DataT>
struct Allreduce {
void rightHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[(val_.size() / 2) + i] += msg->vec_[i];
void initialize(
const DataT& data, vt::objgroup::proxy::Proxy<Allreduce> proxy,
uint32_t num_nodes) {
this_node_ = vt::theContext()->getNode();
is_even_ = this_node_ % 2 == 0;
val_ = data;
proxy_ = proxy;
num_steps_ = static_cast<int32_t>(log2(num_nodes));
nprocs_pof2_ = 1 << num_steps_;
nprocs_rem_ = num_nodes - nprocs_pof2_;
is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);
if (is_part_of_adjustment_group_) {
if (is_even_) {
vrt_node_ = this_node_ / 2;
} else {
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}

r_index_.resize(num_steps_, 0);
r_count_.resize(num_steps_, 0);
s_index_.resize(num_steps_, 0);
s_count_.resize(num_steps_, 0);

w_size_ = data.size();

int step = 0;
size_t wsize = data.size();
for (int mask = 1; mask < nprocs_pof2_; mask <<= 1) {
auto vdest = vrt_node_ ^ mask;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

if (this_node_ < dest) {
r_count_[step] = wsize / 2;
s_count_[step] = wsize - r_count_[step];
s_index_[step] = r_index_[step] + r_count_[step];
} else {
s_count_[step] = wsize / 2;
r_count_[step] = wsize - s_count_[step];
r_index_[step] = s_index_[step] + s_count_[step];
}

if (step + 1 < num_steps_) {
r_index_[step + 1] = r_index_[step];
s_index_[step + 1] = r_index_[step];
wsize = r_count_[step];
step++;
}
}

// std::string str(1024, 0x0);
// for (int i = 0; i < num_steps_; ++i) {
// str.append(fmt::format(
// "Step{}: send_idx = {} send_count = {} recieve_idx = {} recieve_count "
// "= {}\n",
// i, s_index_[i], s_count_[i], r_index_[i], r_count_[i]));
// }
// fmt::print(
// "[{}] Initialize with size = {} num_steps {} \n {}", this_node_, w_size_,
// num_steps_, str);
}

void partOneCollective() {
if (is_part_of_adjustment_group_) {
auto const partner = is_even_ ? this_node_ + 1 : this_node_ - 1;

if (is_even_) {
proxy_[partner].template send<&Allreduce::partOneRightHalf>(
std::vector<int32_t>{val_.begin() + (val_.size() / 2), val_.end()});
vrt_node_ = this_node_ / 2;
} else {
proxy_[partner].template send<&Allreduce::partOneLeftHalf>(
std::vector<int32_t>{val_.begin(), val_.end() - (val_.size() / 2)});
vrt_node_ = -1;
}
} else {
vrt_node_ = this_node_ - nprocs_rem_;
}

if (nprocs_rem_ == 0) {
partTwo();
}
}

void rightHalfComplete(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[(val_.size() / 2) + i] = msg->vec_[i];
void partOneRightHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] += msg->val_[i];
}

// Send to left node
proxy_[theContext()->getNode() - 1]
.template send<&Allreduce::partOneFinalPart>(
std::vector<int32_t>{val_.begin() + (val_.size() / 2), val_.end()});
}

void leftHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[i] += msg->vec_[i];
void partOneLeftHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[i] += msg->val_[i];
}
}

void leftHalfComplete(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[i] = msg->vec_[i];
void partOneFinalPart(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] = msg->val_[i];
}

partTwo();
}

void sendHandler(AllreduceMsg<DataT>* msg) {
uint32_t start = is_even_ ? 0 : val_.size() / 2;
uint32_t end = is_even_ ? val_.size() / 2 : val_.size();
for (int i = 0; start < end; start++) {
val_[start] += msg->vec_[i++];
void partTwo() {
auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

// fmt::print(
// "[{}] Part2 Step {}: Sending to Node {} starting with idx = {} and count "
// "{} \n",
// this_node_, step_, dest, s_index_[step_], s_count_[step_]);
proxy_[dest].template send<&Allreduce::partTwoHandler>(
std::vector<int32_t>{
val_.begin() + (s_index_[step_]),
val_.begin() + (s_index_[step_]) + s_count_[step_]},
step_);

mask_ <<= 1;
if (step_ + 1 < num_steps_) {
step_++;
}
}

void reducedHan(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[val_.size() / 2 + i] = msg->vec_[i];
void partTwoHandler(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[r_index_[msg->step_] + i] += msg->val_[i];
}

// std::string data(128, 0x0);
// for (auto val : msg->val_) {
// data.append(fmt::format("{} ", val));
// }
// fmt::print(
// "[{}] Part2 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
// msg->step_, data, r_index_[msg->step_],
// theContext()->getFromNodeCurrentTask());

if (mask_ < nprocs_pof2_) {
partTwo();
} else {
step_ = num_steps_ - 1;
mask_ = nprocs_pof2_ >> 1;
partThree();
}
}

Allreduce() { is_even_ = theContext()->getNode() % 2 == 0; }
void partThree() {
auto vdest = vrt_node_ ^ mask_;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;

// std::string data(128, 0x0);
// auto subV = std::vector<int32_t>{
// val_.begin() + (r_index_[step_]),
// val_.begin() + (r_index_[step_]) + r_count_[step_]};
// for (auto val : subV) {
// data.append(fmt::format("{} ", val));
// }

// fmt::print(
// "[{}] Part3 Step {}: Sending to Node {} starting with idx = {} and count "
// "{} "
// "data={} \n",
// this_node_, step_, dest, r_index_[step_], r_count_[step_], data);

proxy_[dest].template send<&Allreduce::partThreeHandler>(
std::vector<int32_t>{
val_.begin() + (r_index_[step_]),
val_.begin() + (r_index_[step_]) + r_count_[step_]},
step_);

mask_ >>= 1;
step_--;
}

void partThreeHandler(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->val_.size(); i++) {
val_[s_index_[msg->step_] + i] = msg->val_[i];
}

// std::string data(128, 0x0);
// for (auto val : msg->val_) {
// data.append(fmt::format("{} ", val));
// }
// fmt::print(
// "[{}] Part3 Step {}: Received data ({}) idx = {} from {}\n", this_node_,
// msg->step_, data, s_index_[msg->step_],
// theContext()->getFromNodeCurrentTask());

if (mask_ > 0) {
partThree();
}
}

NodeType this_node_ = {};
bool is_even_ = false;
vt::objgroup::proxy::Proxy<Allreduce> proxy_ = {};
DataT val_ = {};
NodeType vrt_node_ = {};
bool is_part_of_adjustment_group_ = false;
int32_t num_steps_ = {};
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};
int32_t mask_ = 1;

size_t w_size_ = {};
int32_t step_ = 0;
std::vector<int32_t> r_index_ = {};
std::vector<int32_t> r_count_ = {};
std::vector<int32_t> s_index_ = {};
std::vector<int32_t> s_count_ = {};
};

} // namespace vt::collective::reduce::alleduce
} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_REDUCE_H*/
12 changes: 6 additions & 6 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@

#include <utility>

namespace vt::collective::reduce::alleduce {
namespace vt::collective::reduce::allreduce {

template <auto f, template <typename Arg> class Op, typename... Args>
void allreduce(Args&&... data) {

void allreduce_r(Args&&... data) {
auto msg = vt::makeMessage<AllreduceMsg>(std::forward<Args>(data)...);
auto const this_node = vt::theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();
Expand All @@ -39,7 +38,8 @@ void allreduce(Args&&... data) {
vt::runInEpochCollective([=] {
if (is_part_of_adjustment_group) {
auto const partner = is_even ? this_node + 1 : this_node - 1;
grp_proxy[partner].send<&Reducer::sendHandler>(std::forward<Args...>(data...));
grp_proxy[partner].send<&Reducer::sendHandler>(
std::forward<Args...>(data...));
}
});

Expand Down Expand Up @@ -123,6 +123,6 @@ void allreduce(Args&&... data) {
*/
}

} // namespace vt::collective::reduce::alleduce
} // namespace vt::collective::reduce::allreduce

#endif // INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
#endif // INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
3 changes: 3 additions & 0 deletions src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
ProxyType<ObjT> proxy, std::string const& name, std::string const& parent = ""
);

template <auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
ObjGroupManager::PendingSendType allreduce_r(ProxyType<ObjT> proxy, const DataT& data);

/**
* \brief Perform a reduction over an objgroup
*
Expand Down
29 changes: 29 additions & 0 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
//@HEADER
*/

#include "vt/messaging/message/smart_ptr.h"
#include <utility>
#if !defined INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H
#define INCLUDED_VT_OBJGROUP_MANAGER_IMPL_H

Expand All @@ -57,6 +59,7 @@
#include "vt/collective/collective_alg.h"
#include "vt/messaging/active.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/collective/reduce/allreduce/allreduce.h"

#include <memory>

Expand Down Expand Up @@ -262,6 +265,32 @@ ObjGroupManager::PendingSendType ObjGroupManager::broadcast(MsgSharedPtr<MsgT> m
return objgroup::broadcast(msg,han);
}

template <
auto f, typename ObjT, template <typename Arg> class Op, typename DataT>
ObjGroupManager::PendingSendType
ObjGroupManager::allreduce_r(ProxyType<ObjT> proxy, const DataT& data) {
// check payload size and choose appropriate algorithm

auto const this_node = vt::theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();

using Reducer = collective::reduce::allreduce::Allreduce<DataT>;

auto grp_proxy =
vt::theObjGroup()->makeCollective<Reducer>("allreduce_rabenseifner");

grp_proxy[this_node].template invoke<&Reducer::initialize>(
data, grp_proxy, num_nodes);

vt::runInEpochCollective([=] {
grp_proxy[this_node].template invoke<&Reducer::partOneCollective>();
});

proxy[this_node].template invoke<f>(grp_proxy.get()->val_);

return PendingSendType{nullptr};
}

template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>
ObjGroupManager::PendingSendType ObjGroupManager::reduce(
ProxyType<ObjT> proxy, MsgSharedPtr<MsgT> msg,
Expand Down
9 changes: 9 additions & 0 deletions src/vt/objgroup/proxy/proxy_objgroup.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ struct Proxy {
Args&&... args
) const;

template <
auto f,
template <typename Arg> class Op = collective::NoneOp,
typename... Args
>
PendingSendType allreduce_h(
Args&&... args
) const;

/**
* \brief Reduce back to a point target. Performs a reduction using operator
* `Op` followed by a send to `f` with the result.
Expand Down
Loading

0 comments on commit f90e5c7

Please sign in to comment.