Skip to content

Commit

Permalink
#2240: Initial work for new allreduce
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Apr 10, 2024
1 parent 0daa9c8 commit 186fb44
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 5 deletions.
130 changes: 130 additions & 0 deletions src/vt/collective/reduce/allreduce/allreduce.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
//@HEADER
// *****************************************************************************
//
// reduce.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_ALLREDUCE_H

#include "vt/config.h"
#include "vt/context/context.h"
#include "vt/messaging/message/message.h"

#include <tuple>
#include <cstdint>

namespace vt::collective::reduce::alleduce {

template <typename DataT>
struct AllreduceMsg
: SerializeIfNeeded<vt::Message, AllreduceMsg<DataT>, DataT> {
using MessageParentType =
SerializeIfNeeded<::vt::Message, AllreduceMsg<DataT>, DataT>;

AllreduceMsg() = default;
AllreduceMsg(AllreduceMsg const&) = default;
AllreduceMsg(AllreduceMsg&&) = default;

explicit AllreduceMsg(DataT&& in_val)
: MessageParentType(),
val_(std::forward<DataT>(in_val)) { }
explicit AllreduceMsg(DataT const& in_val)
: MessageParentType(),
val_(in_val) { }

template <typename SerializeT>
void serialize(SerializeT& s) {
MessageParentType::serialize(s);
s | val_;
}

DataT val_ = {};
};

template <typename DataT>
struct Allreduce {
void rightHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[(val_.size() / 2) + i] += msg->vec_[i];
}
}

void rightHalfComplete(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[(val_.size() / 2) + i] = msg->vec_[i];
}
}

void leftHalf(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[i] += msg->vec_[i];
}
}

void leftHalfComplete(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[i] = msg->vec_[i];
}
}

void sendHandler(AllreduceMsg<DataT>* msg) {
uint32_t start = is_even_ ? 0 : val_.size() / 2;
uint32_t end = is_even_ ? val_.size() / 2 : val_.size();
for (int i = 0; start < end; start++) {
val_[start] += msg->vec_[i++];
}
}

void reducedHan(AllreduceMsg<DataT>* msg) {
for (int i = 0; i < msg->vec_.size(); i++) {
val_[val_.size() / 2 + i] = msg->vec_[i];
}
}

Allreduce() { is_even_ = theContext()->getNode() % 2 == 0; }

bool is_even_ = false;
DataT val_ = {};
};

} // namespace vt::collective::reduce::alleduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_REDUCE_H*/
128 changes: 128 additions & 0 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@


#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H

#include "vt/messaging/message/shared_message.h"
#include "vt/objgroup/manager.h"
#include "vt/collective/reduce/allreduce/allreduce.h"

#include <utility>

namespace vt::collective::reduce::alleduce {

template <auto f, template <typename Arg> class Op, typename... Args>
void allreduce(Args&&... data) {

auto msg = vt::makeMessage<AllreduceMsg>(std::forward<Args>(data)...);
auto const this_node = vt::theContext()->getNode();
auto const num_nodes = theContext()->getNumNodes();

using Reducer = Allreduce<Args...>;

auto grp_proxy =
vt::theObjGroup()->makeCollective<Reducer>("allreduce_rabenseifner");

auto const lastNode = num_nodes - 1;
auto const num_steps = static_cast<int32_t>(log2(num_nodes));
auto const nprocs_pof2 = 1 << num_steps;
auto const nprocs_rem = num_nodes - nprocs_pof2;

////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////// STEP 1 ////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////

int vrt_node;
bool const is_part_of_adjustment_group = this_node < (2 * nprocs_rem);
bool const is_even = this_node % 2 == 0;
vt::runInEpochCollective([=, &vrt_node] {
vt::runInEpochCollective([=] {
if (is_part_of_adjustment_group) {
auto const partner = is_even ? this_node + 1 : this_node - 1;
grp_proxy[partner].send<&Reducer::sendHandler>(std::forward<Args...>(data...));
}
});

vt::runInEpochCollective([=] {
if (is_part_of_adjustment_group and not is_even) {
auto& vec = grp_proxy[this_node].get()->data_;
grp_proxy[this_node - 1].send<&Reducer::reducedHan>(
std::vector<int32_t>{vec.begin() + (vec.size() / 2), vec.end()});
}
});

if (is_part_of_adjustment_group) {
if (is_even) {
vrt_node = this_node / 2;
} else {
vrt_node = -1;
}

} else { /* rank >= 2 * nprocs_rem */
vrt_node = this_node - nprocs_rem;
}
});

////////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////// STEP 2 ////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////

// int step = 0;
// auto const wsize = data.size();

// auto& vec = grp_proxy[this_node].get()->data_;

// /*
// Scatter Reduce (distance doubling with vector halving)
// */
// for (int mask = 1; mask < (1 << num_steps); mask <<= 1) {
// int vdest = vrt_node ^ mask;
// int dest = (vdest < nprocs_rem) ? vdest * 2 : vdest + nprocs_rem;

// vt::runInEpochCollective([=] {
// if (vrt_node != -1) {
// if (this_node < dest) {
// grp_proxy[dest].send<&NodeObj::rightHalf>(
// std::vector<int32_t>{vec.begin() + (vec.size() / 2), vec.end()});
// } else {
// grp_proxy[dest].send<&NodeObj::leftHalf>(
// std::vector<int32_t>{vec.begin(), vec.end() - (vec.size() / 2)});
// }
// }
// });
// }

// step = num_steps - 1;

// /*
// AllGather (distance halving with vector halving)
// */
// for (int mask = (1 << num_steps) >> 1; mask > 0; mask >>= 1) {
// int vdest = vrt_node ^ mask;
// /* Translate vdest virtual rank to real rank */
// int dest = (vdest < nprocs_rem) ? vdest * 2 : vdest + nprocs_rem;
// vt::runInEpochCollective([=] {
// if (vrt_node != -1) {
// if (this_node < dest) {
// grp_proxy[dest].send<&NodeObj::leftHalfComplete>(
// std::vector<int32_t>{vec.begin(), vec.end() - (vec.size() / 2)});
// } else {
// grp_proxy[dest].send<&NodeObj::rightHalfComplete>(
// std::vector<int32_t>{vec.begin() + (vec.size() / 2), vec.end()});
// }
// }
// });
// }

/*
Send to excluded nodes (if needed)
*/

/*
Local invoke of the handler
*/
}

} // namespace vt::collective::reduce::alleduce

#endif // INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
111 changes: 111 additions & 0 deletions tests/perf/allreduce.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
//@HEADER
// *****************************************************************************
//
// reduce.cc
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/
#include "common/test_harness.h"
#include "vt/context/context.h"
#include <unordered_map>
#include <vt/collective/collective_ops.h>
#include <vt/objgroup/manager.h>
#include <vt/messaging/active.h>
#include <vt/collective/reduce/allreduce/allreduce.h>

#include <fmt-vt/core.h>

using namespace vt;
using namespace vt::tests::perf::common;

static constexpr int num_iters = 1;

struct MyTest : PerfTestHarness { };

struct NodeObj {
explicit NodeObj(MyTest* test_obj) : test_obj_(test_obj) { }

void initialize() { proxy_ = vt::theObjGroup()->getProxy<NodeObj>(this);
}
struct MyMsg : vt::Message {};

void reduceComplete(std::vector<int32_t> in) {
reduce_counter_++;
test_obj_->StopTimer(fmt::format("{} reduce", i));
test_obj_->GetMemoryUsage();
if (i < num_iters) {
i++;
auto this_node = theContext()->getNode();
proxy_[this_node].send<MyMsg, &NodeObj::perfReduce>();
} else if (theContext()->getNode() == 0) {
theTerm()->enableTD();
}
}

void perfReduce(MyMsg* in_msg) {
test_obj_->StartTimer(fmt::format("{} reduce", i));

proxy_.allreduce<&NodeObj::reduceComplete, collective::PlusOp>(data_);
}

private:
MyTest* test_obj_ = nullptr;
vt::objgroup::proxy::Proxy<NodeObj> proxy_ = {};
int reduce_counter_ = -1;
int i = 0;
std::vector<int32_t> data_ = {};
};

VT_PERF_TEST(MyTest, test_reduce) {
auto grp_proxy = vt::theObjGroup()->makeCollective<NodeObj>(
"test_reduce", this
);

if (theContext()->getNode() == 0) {
theTerm()->disableTD();
}

std::vector<int32_t> data(1024, theContext()->getNode());
grp_proxy.allreduce<&NodeObj::reduceComplete, collective::PlusOp>(data);

if (theContext()->getNode() == 0) {
theTerm()->enableTD();
}
}

VT_PERF_TEST_MAIN()
Loading

0 comments on commit 186fb44

Please sign in to comment.