From a6d694e784505a0e163e98f161c016646c45b591 Mon Sep 17 00:00:00 2001
From: Jacob Domagala
Date: Sun, 7 Apr 2024 22:22:37 +0200
Subject: [PATCH] #2240: Fix non power of 2 for new allreduce

---
 src/vt/collective/reduce/allreduce/allreduce.h | 12 ++++++++++++
 src/vt/objgroup/manager.impl.h                 |  8 ++++++++
 2 files changed, 20 insertions(+)

diff --git a/src/vt/collective/reduce/allreduce/allreduce.h b/src/vt/collective/reduce/allreduce/allreduce.h
index 3dabb10b5d..621d220df0 100644
--- a/src/vt/collective/reduce/allreduce/allreduce.h
+++ b/src/vt/collective/reduce/allreduce/allreduce.h
@@ -360,6 +360,18 @@ struct Allreduce {
     }
   }
 
+  void partFour() {
+    if (is_part_of_adjustment_group_ and is_even_) {
+      if constexpr (debug) {
+        fmt::print(
+          "[{}] Part4 : Sending to Node {} \n", this_node_, this_node_ + 1);
+      }
+      proxy_[this_node_ + 1].template send<&Allreduce::partFourHandler>(val_);
+    }
+  }
+
+  void partFourHandler(AllreduceMsg* msg) { val_ = msg->val_; }
+
   NodeType this_node_ = {};
   bool is_even_ = false;
   vt::objgroup::proxy::Proxy proxy_ = {};
diff --git a/src/vt/objgroup/manager.impl.h b/src/vt/objgroup/manager.impl.h
index 2aa3163613..c5ca868451 100644
--- a/src/vt/objgroup/manager.impl.h
+++ b/src/vt/objgroup/manager.impl.h
@@ -274,6 +274,10 @@ ObjGroupManager::allreduce_r(ProxyType proxy, const DataT& data) {
   auto const this_node = vt::theContext()->getNode();
   auto const num_nodes = theContext()->getNumNodes();
 
+  if(num_nodes < 2){
+    return PendingSendType{nullptr};
+  }
+
   using Reducer = collective::reduce::allreduce::Allreduce;
 
   auto grp_proxy =
@@ -294,6 +298,10 @@ ObjGroupManager::allreduce_r(ProxyType proxy, const DataT& data) {
     grp_proxy[this_node].template invoke<&Reducer::partThree>();
   });
 
+  vt::runInEpochCollective([=] {
+    grp_proxy[this_node].template invoke<&Reducer::partFour>();
+  });
+
   proxy[this_node].template invoke(grp_proxy.get()->val_);
 
   return PendingSendType{nullptr};