Skip to content

Commit

Permalink
#2281: RecursiveDoubling allreduce - add Collection support
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Sep 16, 2024
1 parent 90c5668 commit 1568c4a
Show file tree
Hide file tree
Showing 7 changed files with 178 additions and 69 deletions.
36 changes: 36 additions & 0 deletions src/vt/collective/reduce/allreduce/recursive_doubling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,42 @@

namespace vt::collective::reduce::allreduce {

/**
 * \brief Construct a recursive-doubling allreduce helper for a collection.
 *
 * Records the collection proxy and local element count, fetches the node
 * list of \c group, and derives the quantities used by the algorithm:
 *  - \c num_steps_   : floor(log2(num_nodes_))
 *  - \c nprocs_pof2_ : 2^num_steps_ (largest power of two that does not
 *                      exceed num_nodes_, assuming num_nodes_ >= 1)
 *  - \c nprocs_rem_  : nodes left over beyond that power of two
 *
 * NOTE(review): the initializers read members set earlier in the list
 * (nodes_ -> num_nodes_ -> num_steps_ -> nprocs_pof2_ -> nprocs_rem_), so
 * the declaration order in the header must match -- confirm there.
 *
 * \param[in] proxy     strongly-typed virtual (collection) proxy
 * \param[in] group     strongly-typed group whose nodes take part
 * \param[in] num_elems number of local collection elements
 */
RecursiveDoubling::RecursiveDoubling(
  detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems)
  : collection_proxy_(proxy.get()),
    local_num_elems_(num_elems),
    nodes_(theGroup()->GetGroupNodes(group.get())),
    num_nodes_(nodes_.size()),
    this_node_(theContext()->getNode()),
    num_steps_(static_cast<uint32_t>(std::log2(num_nodes_))),
    nprocs_pof2_(1 << num_steps_),
    nprocs_rem_(num_nodes_ - nprocs_pof2_) {
  // For a non-default group, replace the global node id with this node's
  // position inside the group's node list.
  if (not theGroup()->isGroupDefault(group.get())) {
    auto it = std::find(nodes_.begin(), nodes_.end(), theContext()->getNode());
    vtAssert(it != nodes_.end(), "This node was not found in group nodes!");

    this_node_ = it - nodes_.begin();
  }

  is_even_ = (this_node_ % 2) == 0;

  // The first 2 * nprocs_rem_ ranks form the "adjustment group".
  is_part_of_adjustment_group_ = this_node_ < (2 * nprocs_rem_);

  if (not is_part_of_adjustment_group_) {
    // Ranks past the adjustment group are shifted down by the remainder.
    vrt_node_ = this_node_ - nprocs_rem_;
  } else if (is_even_) {
    // Even ranks inside the adjustment group get half their rank.
    vrt_node_ = this_node_ / 2;
  } else {
    // Odd ranks inside the adjustment group get the sentinel -1
    // (presumably "not part of the power-of-two phase" -- TODO confirm).
    vrt_node_ = -1;
  }

  vt_debug_print(
    terse, allreduce,
    "RecursiveDoubling (this={}): proxy={:x} proxy_={} local_num_elems={}\n",
    print_ptr(this), proxy.get(), proxy_.getProxy(), local_num_elems_);
}

RecursiveDoubling::RecursiveDoubling(
detail::StrongObjGroup objgroup)
: objgroup_proxy_(objgroup.get()),
Expand Down
1 change: 1 addition & 0 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ namespace vt::collective::reduce::allreduce {
*/

struct RecursiveDoubling {
RecursiveDoubling(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems);
/**
* \brief Constructor for RecursiveDoubling class.
*
Expand Down
3 changes: 2 additions & 1 deletion src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ struct CollectionManager
);


template <auto f, typename ColT, template <typename Arg> class Op, typename ...Args>
template <typename ReducerT, auto f, typename ColT, template <typename Arg> class Op, typename ...Args>
messaging::PendingSend reduceLocal(
CollectionProxyWrapType<ColT> const& proxy, Args &&... args
);
Expand Down Expand Up @@ -1796,6 +1796,7 @@ struct CollectionManager

// Allreduce reducer caches, keyed by collection proxy (TODO: consider moving these elsewhere)
std::unordered_map<VirtualProxyType, ObjGroupProxyType> rabenseifner_reducers_;
std::unordered_map<VirtualProxyType, ObjGroupProxyType> recursive_doubling_reducers_;
};

}}} /* end namespace vt::vrt::collection */
Expand Down
101 changes: 71 additions & 30 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
//@HEADER
*/

#include "vt/collective/reduce/allreduce/recursive_doubling.h"
#include "vt/collective/reduce/allreduce/type.h"
#include <type_traits>
#if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H
#define INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H

Expand Down Expand Up @@ -894,7 +897,7 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler(
}

template <
auto f, typename ColT, template <typename Arg> class Op, typename... Args>
typename ReducerT, auto f, typename ColT, template <typename Arg> class Op, typename... Args>
messaging::PendingSend CollectionManager::reduceLocal(
CollectionProxyWrapType<ColT> const& proxy, Args&&... args) {
using namespace collective::reduce::allreduce;
Expand All @@ -913,44 +916,82 @@ messaging::PendingSend CollectionManager::reduceLocal(
auto const group = elm_holder->group();
bool const use_group = group_ready && send_group;

using Reducer = collective::reduce::allreduce::Rabenseifner;
auto stamp = proxy(idx).tryGetLocalPtr()->getNextAllreduceStamp();
auto const id = std::get<collective::reduce::detail::StrongSeq>(stamp).get();

auto cb = vt::theCB()->makeCallbackBcastCollectiveProxy<f>(proxy);

// FIXME: reducers are cached per collection proxy only, so the same reducer is reused for different Op / payload size / final handler combinations.
if (auto reducer = rabenseifner_reducers_.find(col_proxy);
reducer == rabenseifner_reducers_.end()) {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
if constexpr (std::is_same_v<ReducerT, RabenseifnerT>) {
using Reducer = collective::reduce::allreduce::Rabenseifner;
if (auto reducer = rabenseifner_reducers_.find(col_proxy);
reducer == rabenseifner_reducers_.end()) {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
vt_debug_print(
terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy = theObjGroup()->makeCollective<Reducer>(
"reducer", collective::reduce::detail::StrongVrtProxy{col_proxy},
collective::reduce::detail::StrongGroup{group}, num_elms);

rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy();
auto* obj = obj_proxy[theContext()->getNode()].get();
obj->proxy_ = obj_proxy;

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}
} else {

vt_debug_print(terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy = theObjGroup()->makeCollective<Reducer>(
"reducer", collective::reduce::detail::StrongVrtProxy{col_proxy},
collective::reduce::detail::StrongGroup{group}, num_elms
);

rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy();
auto* obj = obj_proxy[theContext()->getNode()].get();
obj->proxy_ = obj_proxy;

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
vt_debug_print(
terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy =
reducer->second; // rabenseifner_reducers_.at(col_proxy);
auto typed_proxy =
static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
auto* obj = typed_proxy[theContext()->getNode()].get();

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}
}
} else {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
using Reducer = collective::reduce::allreduce::RecursiveDoubling;
if (auto reducer = recursive_doubling_reducers_.find(col_proxy);
reducer == recursive_doubling_reducers_.end()) {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
vt_debug_print(
terse, allreduce, "Creating Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy = theObjGroup()->makeCollective<Reducer>(
"reducer", collective::reduce::detail::StrongVrtProxy{col_proxy},
collective::reduce::detail::StrongGroup{group}, num_elms);

recursive_doubling_reducers_[col_proxy] = obj_proxy.getProxy();
auto* obj = obj_proxy[theContext()->getNode()].get();
obj->proxy_ = obj_proxy;

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}
} else {
vt_debug_print(terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy = reducer->second; // rabenseifner_reducers_.at(col_proxy);
auto typed_proxy =
static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
auto* obj = typed_proxy[theContext()->getNode()].get();

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
vt_debug_print(
terse, allreduce, "Reusing Reducer on idx={} with id={}\n", idx, id);
auto obj_proxy =
reducer->second; // rabenseifner_reducers_.at(col_proxy);
auto typed_proxy =
static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
auto* obj = typed_proxy[theContext()->getNode()].get();

obj->template setFinalHandler<DataT>(cb, id);
obj->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/vt/vrt/collection/reducable/reducable.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ struct Reducable : BaseProxyT {
) const;

template <
typename ReducerT,
auto f,
template <typename Arg> class Op = collective::NoneOp,
typename... Args
>
messaging::PendingSend allreduce_h(
messaging::PendingSend allreduce(
Args&&... args
) const;

Expand Down
6 changes: 3 additions & 3 deletions src/vt/vrt/collection/reducable/reducable.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::allreduce(
}

template <typename ColT, typename IndexT, typename BaseProxyT>
template <auto f, template <typename Arg> class Op, typename... Args>
messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::allreduce_h(
template <typename ReducerT, auto f, template <typename Arg> class Op, typename... Args>
messaging::PendingSend Reducable<ColT,IndexT,BaseProxyT>::allreduce(
Args&&... args
) const {
auto const proxy = this->getProxy();
return theCollection()->reduceLocal<f, ColT, Op>(
return theCollection()->reduceLocal<ReducerT, f, ColT, Op>(
proxy, std::forward<Args>(args)...);
}

Expand Down
97 changes: 63 additions & 34 deletions tests/perf/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,60 +291,92 @@ VT_PERF_TEST(MyTest, test_allreduce_group_rabenseifner) {
}
}

struct Hello : vt::Collection<Hello, vt::Index1D> {
Hello() {
struct RabensifnerColl : vt::Collection<RabensifnerColl, vt::Index1D> {
RabensifnerColl() {
for (auto const payload_size : payloadSizes) {
timer_names_[payload_size] = fmt::format("Collection {}", payload_size);
timer_names_[payload_size] = fmt::format("Collection Rabenseifner {}", payload_size);
}
}

void finalMaxHan(std::vector<int32_t> result) {
std::string result_s = "";
for (auto val : result) {
result_s.append(fmt::format("{} ", val));
}
fmt::print(
"[{}]: Allreduce finalMaxHan (Values=[{}]), idx={}\n",
theContext()->getNode(), result_s, getIndex().x());
void allreduceHan(std::vector<int32_t> result) {
col_send_done_ = true;
parent_->StopTimer(timer_names_.at(result.size()));
}

void executeAllreduce(size_t payload_size) {
auto proxy = this->getCollectionProxy();

std::vector<int32_t> payload(payload_size, getIndex().x());
parent_->StartTimer(timer_names_.at(payload_size));
proxy.allreduce<
collective::reduce::allreduce::RabenseifnerT, &RabensifnerColl::allreduceHan,
collective::PlusOp
>(payload);
}

bool col_send_done_ = false;
std::unordered_map<size_t, std::string> timer_names_ = {};
MyTest* parent_ = {};
};

VT_PERF_TEST(MyTest, test_allreduce_collection_rabenseifner) {
auto const num_elms_per_node = 1;
auto range = vt::Index1D(int32_t{num_nodes_ * num_elms_per_node});
auto proxy = vt::makeCollection<RabensifnerColl>("test_collection_allreduce")
.bounds(range)
.bulkInsert()
.wait();

// col_send_done_ = true;
// parent_->StopTimer(timer_names_.at(result.size()));
auto const thisNode = vt::theContext()->getNode();
auto const nextNode = (thisNode + 1) % num_nodes_;

theCollective()->barrier();

auto const elm = thisNode * num_elms_per_node;
proxy[elm].tryGetLocalPtr()->parent_ = this;

for (auto payload_size : payloadSizes) {
proxy.broadcastCollective<&RabensifnerColl::executeAllreduce>(payload_size);

// Each node hosts exactly one collection element (num_elms_per_node == 1), so polling that local element's flag is sufficient
theSched()->runSchedulerWhile(
[&] { return !proxy[elm].tryGetLocalPtr()->col_send_done_; });
proxy[elm].tryGetLocalPtr()->col_send_done_ = false;
}
}

void finalHan(std::vector<int32_t> result) {
// std::string result_s = "";
// for(auto val : result){
// result_s.append(fmt::format("{} ", val));
// }
// fmt::print(
// "[{}]: Allreduce handler (Values=[{}]), idx={}\n",
// theContext()->getNode(), result_s, getIndex().x()
// );
struct RecursiveDoublingColl : vt::Collection<RecursiveDoublingColl, vt::Index1D> {
RecursiveDoublingColl() {
for (auto const payload_size : payloadSizes) {
timer_names_[payload_size] = fmt::format("Collection RecursiveDoubling {}", payload_size);
}
}

void allreduceHan(std::vector<int32_t> result) {
col_send_done_ = true;
parent_->StopTimer(timer_names_.at(result.size()));
}

void handler(size_t payload_size) {
void executeAllreduce(size_t payload_size) {
auto proxy = this->getCollectionProxy();

std::vector<int32_t> payload(payload_size, getIndex().x());
parent_->StartTimer(timer_names_.at(payload_size));
proxy.allreduce_h<&Hello::finalHan, collective::PlusOp>(payload);

// proxy.allreduce_h<&Hello::finalMaxHan, collective::MaxOp>(
// std::move(payload));
proxy.allreduce<
collective::reduce::allreduce::RecursiveDoublingT, &RecursiveDoublingColl::allreduceHan,
collective::PlusOp
>(payload);
}

bool col_send_done_ = false;
std::unordered_map<size_t, std::string> timer_names_ = {};
MyTest* parent_ = {};
};

VT_PERF_TEST(MyTest, test_allreduce_collection_rabenseifner) {
VT_PERF_TEST(MyTest, test_allreduce_collection_racursive_doubling) {
auto const num_elms_per_node = 1;
auto range = vt::Index1D(int32_t{num_nodes_ * num_elms_per_node});
auto proxy = vt::makeCollection<Hello>("test_collection_send")
auto proxy = vt::makeCollection<RecursiveDoublingColl>("test_collection_allreduce")
.bounds(range)
.bulkInsert()
.wait();
Expand All @@ -355,13 +387,10 @@ VT_PERF_TEST(MyTest, test_allreduce_collection_rabenseifner) {
theCollective()->barrier();

auto const elm = thisNode * num_elms_per_node;

proxy[elm].tryGetLocalPtr()->parent_ = this;
proxy.broadcastCollective<&Hello::handler>(payloadSizes.front());
theSched()->runSchedulerWhile(
[&] { return !proxy[elm].tryGetLocalPtr()->col_send_done_; });

for (auto payload_size : payloadSizes) {
proxy.broadcastCollective<&Hello::handler>(payload_size);
proxy.broadcastCollective<&RecursiveDoublingColl::executeAllreduce>(payload_size);

// Each node hosts exactly one collection element (num_elms_per_node == 1), so polling that local element's flag is sufficient
theSched()->runSchedulerWhile(
Expand Down

0 comments on commit 1568c4a

Please sign in to comment.