Skip to content

Commit

Permalink
#2240: Store Reducers by tuple(ProxyType, DataType, OperandType)
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 6ba48d3 commit 150b883
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 40 deletions.
8 changes: 8 additions & 0 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,18 @@

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H

#include "data_handler.h"
#include "rabenseifner_msg.h"
#include "vt/messaging/message/shared_message.h"

#include <vector>
#include <type_traits>

namespace vt {
/// Yields T with any reference removed and then top-level const/volatile
/// stripped (same result as C++20's std::remove_cvref_t).
template <typename T>
using remove_cvref =
  typename std::remove_cv<typename std::remove_reference<T>::type>::type;
} // namespace vt

namespace vt::collective::reduce::allreduce {

Expand Down
2 changes: 2 additions & 0 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,10 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT, auto finalHandler
>
struct Rabenseifner {
using Data = DataT;
using DataType = DataHandler<DataT>;
using Scalar = typename DataType::Scalar;
using ReduceOp = Op<Scalar>;
using DataHelperT = DataHelper<Scalar, DataT>;
using StateT = State<Scalar, DataT>;

Expand Down
2 changes: 2 additions & 0 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
struct RecursiveDoubling {
using Data = DataT;
using DataType = DataHandler<DataT>;
using Scalar = typename DataHandler<DataT>::Scalar;
using ReduceOp = Op<Scalar>;
/**
* \brief Constructor for RecursiveDoubling class.
*
Expand Down
9 changes: 0 additions & 9 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -354,15 +354,6 @@ void RecursiveDoubling<DataT, Op, ObjT, finalHandler>::finalPart(size_t id) {
parent_proxy_[this_node_].template invoke<finalHandler>(state.val_);

state.completed_ = true;

state.adjust_message_ = nullptr;
state.messages_.clear();

states_.erase(id);
// std::fill(state.messages_.begin(), state.messages_.end(), nullptr);

// state.steps_recv_.assign(num_steps_, false);
// state.steps_reduced_.assign(num_steps_, false);
}

} // namespace vt::collective::reduce::allreduce
Expand Down
14 changes: 11 additions & 3 deletions src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@
#include "vt/messaging/pending_send.h"
#include "vt/elm/elm_id.h"
#include "vt/utils/fntraits/fntraits.h"
#include "vt/utils/hash/hash_tuple.h"

#include <memory>
#include <functional>
#include <unordered_map>
#include <vector>
#include <typeindex>

namespace vt { namespace objgroup {

Expand Down Expand Up @@ -91,6 +93,11 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
using HolderBaseType = holder::HolderBase;
using HolderBasePtrType = std::unique_ptr<HolderBaseType>;
using PendingSendType = messaging::PendingSend;
/// Runtime type identity of the data type an allreduce reducer works on
using ReduceDataType = std::type_index;
/// Runtime type identity of the reduction operator used by a reducer
using ReduceOperandType = std::type_index;
/// Lookup table for already-created reducers. The key combines the proxy the
/// allreduce was requested on with the reducer's data and operator types; the
/// mapped value is the proxy of the reducer object group itself.
using ReducerMapType = std::unordered_map<
std::tuple<ObjGroupProxyType, ReduceDataType, ReduceOperandType>,
ObjGroupProxyType>;

public:
/**
Expand Down Expand Up @@ -507,9 +514,10 @@ ObjGroupManager::PendingSendType allreduce(ProxyType<ObjT> proxy, Args&&... data
std::unordered_map<ObjGroupProxyType, std::vector<ActionType>> pending_;
/// Map of object groups' labels
std::unordered_map<ObjGroupProxyType, std::string> labels_;

std::unordered_map<ObjGroupProxyType, ObjGroupProxyType> reducersRD_;
std::unordered_map<ObjGroupProxyType, ObjGroupProxyType> reducersR_;
/// Recursive Doubling reducers
ReducerMapType reducers_recursive_doubling_;
/// Rabenseifner reducers
ReducerMapType reducers_rabenseifner_;
};

}} /* end namespace vt::objgroup */
Expand Down
34 changes: 26 additions & 8 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/recursive_doubling.h"
#include "vt/collective/reduce/allreduce/type.h"
#include "vt/collective/reduce/allreduce/helpers.h"
#include <utility>
#include <array>

Expand Down Expand Up @@ -279,21 +280,36 @@ ObjGroupManager::PendingSendType ObjGroupManager::allreduce(

proxy::Proxy<Reducer> grp_proxy = {};

auto& reducers = Reducer::type_ == ReducerType::Rabenseifner ? reducersR_ : reducersRD_;
if (reducers.find(proxy.getProxy()) != reducers.end()) {
auto* obj = reinterpret_cast<Reducer*>(
objs_.at(reducers.at(proxy.getProxy()))->getPtr()
auto& reducers = Reducer::type_ == ReducerType::Rabenseifner ?
reducers_rabenseifner_ :
reducers_recursive_doubling_;
auto const key = std::make_tuple(
proxy.getProxy(), std::type_index(typeid(typename Reducer::Data)),
std::type_index(typeid(typename Reducer::ReduceOp))
);
if (reducers.find(key) != reducers.end()) {
vt_debug_print(
verbose, allreduce, "Found reducer (type: {}) for proxy {:x}",
TypeToString(Reducer::type_), proxy.getProxy()
);

auto* obj =
reinterpret_cast<Reducer*>(objs_.at(reducers.at(key))->getPtr());
id = obj->generateNewId();
obj->initialize(id, std::forward<Args>(data)...);
grp_proxy = obj->proxy_;
} else {
vt_debug_print(
verbose, allreduce, "Creating reducer (type: {}) for proxy {:x}",
TypeToString(Reducer::type_), proxy.getProxy()
);

grp_proxy = vt::theObjGroup()->makeCollective<Reducer>(
TypeToString(Reducer::type_), proxy,
num_nodes, std::forward<Args>(data)...
TypeToString(Reducer::type_), proxy, num_nodes,
std::forward<Args>(data)...
);
grp_proxy[this_node].get()->proxy_ = grp_proxy;
reducers[proxy.getProxy()] = grp_proxy.getProxy();
reducers[key] = grp_proxy.getProxy();
id = grp_proxy[this_node].get()->id_ - 1;
}

Expand All @@ -314,9 +330,10 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
}

auto const payload_size =
collective::reduce::allreduce::DataHandler<DataT>::size(
collective::reduce::allreduce::DataHandler<remove_cvref<DataT>>::size(
std::forward<Args>(data)...
);

if (payload_size < 2048) {
using Reducer =
vt::collective::reduce::allreduce::RecursiveDoubling<DataT, Op, ObjT, f>;
Expand All @@ -327,6 +344,7 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
return allreduce<Reducer>(proxy, std::forward<Args>(data)...);
}

// Silence nvcc warning
return PendingSendType{nullptr};
}

Expand Down
4 changes: 2 additions & 2 deletions src/vt/objgroup/proxy/proxy_objgroup.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "vt/messaging/param_msg.h"
#include "vt/objgroup/proxy/proxy_bits.h"
#include "vt/collective/reduce/get_reduce_stamp.h"
#include "vt/collective/reduce/allreduce/helpers.h"

namespace vt { namespace objgroup { namespace proxy {

Expand Down Expand Up @@ -215,8 +216,7 @@ Proxy<ObjT>::allreduce_h(
) const {
auto proxy = Proxy<ObjT>(*this);

// using DataT = std::tuple<std::decay_t<Args>...>;
return theObjGroup()->allreduce<f, ObjT, Op, std::decay_t<Args>...>(
return theObjGroup()->allreduce<f, ObjT, Op, remove_cvref<Args>...>(
proxy, std::forward<Args>(args)...);
}

Expand Down
42 changes: 35 additions & 7 deletions src/vt/utils/hash/hash_tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,44 @@
#include <tuple>

namespace std {
namespace {

template <typename A, typename B>
struct hash<std::tuple<A, B>> {
size_t operator()(std::tuple<A, B> const& in) const {
auto const& v1 = std::hash<A>()(std::get<0>(in));
auto const& v2 = std::hash<B>()(std::get<1>(in));
return v1 ^ v2;
// Code from boost
// Reciprocal of the golden ratio helps spread entropy
// and handles duplicates.
// See Mike Seymour in magic-numbers-in-boosthash-combine:
// http://stackoverflow.com/questions/4948780

/// Mixes std::hash<T>(v) into the running hash \c seed, in place.
template <class T>
inline void hash_combine(std::size_t& seed, T const& v) {
  std::size_t const value_hash = std::hash<T>{}(v);
  // Same mixing formula as before: hash + magic constant + shifted seed terms.
  std::size_t const mixed =
    value_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
  seed = seed ^ mixed;
}

// Recursive template code derived from Matthieu M.
/// Folds the hashes of tuple elements 0..Index into \c seed, lowest index
/// first. By default Index is the last element of the tuple.
template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
struct HashValueImpl {
  static void apply(size_t& seed, Tuple const& tuple) {
    // Handle every element before this one, then mix in element Index.
    using Preceding = HashValueImpl<Tuple, Index - 1>;
    Preceding::apply(seed, tuple);

    auto const& element = std::get<Index>(tuple);
    hash_combine(seed, element);
  }
};

}
/// Recursion terminator: only the first tuple element is mixed into \c seed.
template <class Tuple>
struct HashValueImpl<Tuple, 0> {
  static void apply(size_t& seed, Tuple const& tuple) {
    auto const& first = std::get<0>(tuple);
    hash_combine(seed, first);
  }
};
} // namespace

/// std::hash support for std::tuple of any arity.
///
/// The result is built by starting from a zero seed and combining the hash of
/// every element in index order via HashValueImpl. An empty tuple hashes to 0.
template <typename... TT>
struct hash<std::tuple<TT...>> {
  size_t operator()(std::tuple<TT...> const& tt) const {
    size_t seed = 0;
    // For std::tuple<> the default Index of HashValueImpl (tuple_size - 1)
    // would wrap around, so only recurse when there is at least one element.
    if constexpr (sizeof...(TT) > 0) {
      HashValueImpl<std::tuple<TT...>>::apply(seed, tt);
    }
    return seed;
  }
};
} // namespace std

#endif /*INCLUDED_VT_UTILS_HASH_HASH_TUPLE_H*/
18 changes: 11 additions & 7 deletions tests/unit/objgroup/test_objgroup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@

#include "test_objgroup_common.h"
#include "test_helpers.h"
#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/configs/types/types_type.h"
#include "vt/objgroup/manager.h"

#include <typeinfo>
Expand Down Expand Up @@ -266,7 +268,7 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
auto const my_node = vt::theContext()->getNode();

TestObjGroup::total_verify_expected_ = 0;
auto proxy = vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_reduce");
auto proxy = vt::theObjGroup()->makeCollective<MyObjA>("test_proxy_allreduce");

vt::theCollective()->barrier();

Expand All @@ -289,7 +291,7 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {
EXPECT_EQ(MyObjA::total_verify_expected_, 3);
runInEpochCollective([&] {
using Reducer = vt::collective::reduce::allreduce::RecursiveDoubling<
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec<int, 256>
>;
std::vector<int> payload(256, my_node);
theObjGroup()->allreduce<Reducer>(proxy, payload);
Expand All @@ -299,18 +301,20 @@ TEST_F(TestObjGroup, test_proxy_allreduce) {

runInEpochCollective([&] {
using Reducer = vt::collective::reduce::allreduce::Rabenseifner<
std::vector<int>, PlusOp, MyObjA, &MyObjA::verifyAllredVec
NodeType, PlusOp, MyObjA, &MyObjA::verifyAllred<1>
>;
std::vector<int> payload(256, my_node);
theObjGroup()->allreduce<Reducer>(proxy, payload);
theObjGroup()->allreduce<Reducer>(proxy, payload);
std::vector<int> payload(2048, my_node);
theObjGroup()->allreduce<Reducer>(proxy, my_node);

std::vector<short> payload_large(2048 * 2, my_node);
theObjGroup()->allreduce<Reducer>(proxy, my_node);
});

EXPECT_EQ(MyObjA::total_verify_expected_, 6);

runInEpochCollective([&] {
using Reducer = vt::collective::reduce::allreduce::Rabenseifner<
VectorPayload, PlusOp, MyObjA, &MyObjA::verifyAllredVecPayload>;
VectorPayload, PlusOp, MyObjA, &MyObjA::verifyAllredVecPayload<VectorPayload, 256>>;
std::vector<int> payload(256, my_node);
VectorPayload data{payload};
theObjGroup()->allreduce<Reducer>(proxy, data);
Expand Down
12 changes: 8 additions & 4 deletions tests/unit/objgroup/test_objgroup_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,12 @@ struct MyObjA {
total_verify_expected_++;
}

void verifyAllredVec(std::vector<int> vec) {
template<typename Scalar, int32_t size>
void verifyAllredVec(std::vector<Scalar> vec) {
auto final_size = vec.size();
EXPECT_EQ(final_size, 256);
EXPECT_EQ(final_size, size);

auto n = vt::theContext()->getNumNodes();
auto const n = theContext()->getNumNodes();
auto const total_sum = n * (n - 1)/2;
for(auto val : vec){
EXPECT_EQ(val, total_sum);
Expand All @@ -144,7 +145,10 @@ struct MyObjA {
total_verify_expected_++;
}

void verifyAllredVecPayload(VectorPayload vec) { verifyAllredVec(vec.vec_); }
// Allreduce result check for payloads that wrap a vector: forwards the
// payload's vec_ member to verifyAllredVec. DataT is used only to derive the
// element type (value_type of DataT::vec_); size is passed through unchanged.
template <typename DataT, int32_t size>
void verifyAllredVecPayload(VectorPayload vec) {
verifyAllredVec<typename decltype(DataT::vec_)::value_type, size>(vec.vec_);
}

#if MAGISTRATE_KOKKOS_ENABLED
void verifyAllredView(Kokkos::View<float*, Kokkos::HostSpace> view) {
Expand Down

0 comments on commit 150b883

Please sign in to comment.