Skip to content

Commit

Permalink
#2240: DataHandler for Rabenseifner allreduce that provides common AP…
Browse files Browse the repository at this point in the history
…I for various data types
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 42cf608 commit 0ee7848
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 9 deletions.
100 changes: 100 additions & 0 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

/*
//@HEADER
// *****************************************************************************
//
// data_handler.h
// DARMA/vt => Virtual Transport
//
// Copyright 2019-2021 National Technology & Engineering Solutions of Sandia, LLC
// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
// Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * Neither the name of the copyright holder nor the names of its
// contributors may be used to endorse or promote products derived from this
// software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact [email protected]
//
// *****************************************************************************
//@HEADER
*/

#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H

#include <vector>

#ifdef VT_KOKKOS_ENABLED
#include <Kokkos_Core.hpp>
#endif

template <typename Container>
class DataHandler {
public:
using Scalar = float;

static size_t size(const Container& data);
static Scalar& at(Container& data, size_t idx);
static void set(Container& data, size_t idx, const Scalar& value);
static Container split(Container& data, size_t start, size_t end);
};

template <typename T>
class DataHandler<std::vector<T>> {
public:
using Scalar = T;
static size_t size(const std::vector<T>& data) { return data.size(); }
static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
static T& at(std::vector<T>& data, size_t idx) { return data[idx]; }
static void set(std::vector<T>& data, size_t idx, const T& value) {
data[idx] = value;
}
static std::vector<T> split(std::vector<T>& data, size_t start, size_t end) {
return std::vector<T>{data.begin() + start, data.begin() + end};
}
};

#ifdef VT_KOKKOS_ENABLED
template <typename T, typename... Props>
class DataHandler<Kokkos::View<T*, Props...>> {
public:
static size_t size(const Kokkos::View<T*, Props...>& data) {
return data.extent(0);
}
static T at(const Kokkos::View<T*, Props...>& data, size_t idx) {
return data(idx);
}
static T& at(Kokkos::View<T*, Props...>& data, size_t idx) {
return data(idx);
}
static void
set(Kokkos::View<T*, Props...>& data, size_t idx, const T& value) {
data(idx) = value;
}
};
#endif // VT_KOKKOS_ENABLED

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H*/
29 changes: 20 additions & 9 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
//@HEADER
*/


#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H

Expand All @@ -51,6 +50,7 @@
#include "vt/objgroup/proxy/proxy_objgroup.h"
#include "vt/registry/auto/auto_registry.h"
#include "vt/pipe/pipe_manager.h"
#include "data_handler.h"

#include <tuple>
#include <cstdint>
Expand Down Expand Up @@ -93,6 +93,8 @@ template <
typename DataT, template <typename Arg> class Op, typename ObjT,
auto finalHandler>
struct Rabenseifner {
using DataType = DataHandler<DataT>;

template <typename... Args>
Rabenseifner(
vt::objgroup::proxy::Proxy<ObjT> parentProxy, NodeType num_nodes,
Expand Down Expand Up @@ -141,6 +143,7 @@ struct Rabenseifner {

int step = 0;
size_t wsize = val_.size();
size_ = wsize;
for (int mask = 1; mask < nprocs_pof2_; mask <<= 1) {
auto vdest = vrt_node_ ^ mask;
auto dest = (vdest < nprocs_rem_) ? vdest * 2 : vdest + nprocs_rem_;
Expand Down Expand Up @@ -188,35 +191,39 @@ struct Rabenseifner {
if (is_even_) {
proxy_[partner]
.template send<&Rabenseifner::adjustForPowerOfTwoRightHalf>(
DataT{val_.begin() + (val_.size() / 2), val_.end()});
DataType::split(val_, size_ / 2, size_));
} else {
proxy_[partner]
.template send<&Rabenseifner::adjustForPowerOfTwoLeftHalf>(
DataT{val_.begin(), val_.end() - (val_.size() / 2)});
DataType::split(val_, 0, size_ / 2));
}
}
}

void adjustForPowerOfTwoRightHalf(AllreduceRbnMsg<DataT>* msg) {

for (uint32_t i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] += msg->val_[i];
Op<typename DataType::Scalar>()(
DataType::at(val_, (size_ / 2) + i), DataType::at(msg->val_, i));
}

// Send to left node
proxy_[theContext()->getNode() - 1]
.template send<&Rabenseifner::adjustForPowerOfTwoFinalPart>(
DataT{val_.begin() + (val_.size() / 2), val_.end()});
DataType::split(val_, size_ / 2, size_));
// DataT{val_.begin() + (val_.size() / 2), val_.end()});
}

void adjustForPowerOfTwoLeftHalf(AllreduceRbnMsg<DataT>* msg) {
for (uint32_t i = 0; i < msg->val_.size(); i++) {
val_[i] += msg->val_[i];
Op<typename DataT::value_type>()(DataType::at(val_, i), DataType::at(msg->val_, i));
}
}

void adjustForPowerOfTwoFinalPart(AllreduceRbnMsg<DataT>* msg) {
for (uint32_t i = 0; i < msg->val_.size(); i++) {
val_[(val_.size() / 2) + i] = msg->val_[i];
DataType::at(val_, (val_.size() / 2) + i) = DataType::at(msg->val_, i);
// val_[(val_.size() / 2) + i] = msg->val_[i];
}

finished_adjustment_part_ = true;
Expand Down Expand Up @@ -262,7 +269,9 @@ struct Rabenseifner {
auto& in_val = in_msg->val_;
for (uint32_t i = 0; i < in_val.size(); i++) {
Op<typename DataT::value_type>()(
val_[r_index_[in_msg->step_] + i], in_val[i]);
DataType::at(val_, r_index_[in_msg->step_] + i),
DataType::at(in_val, i));
// val_[r_index_[in_msg->step_] + i], in_val[i]);
}

scatter_steps_reduced_[step] = true;
Expand Down Expand Up @@ -357,7 +366,8 @@ struct Rabenseifner {
auto& in_msg = gather_messages_.at(step);
auto& in_val = in_msg->val_;
for (uint32_t i = 0; i < in_val.size(); i++) {
val_[s_index_[in_msg->step_] + i] = in_val[i];
DataType::at(val_, s_index_[in_msg->step_] + i) = DataType::at(in_val, i);
//val_[s_index_[in_msg->step_] + i] = in_val[i];
}

gather_steps_reduced_[step] = true;
Expand Down Expand Up @@ -458,6 +468,7 @@ struct Rabenseifner {
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

DataT val_ = {};
size_t size_ = {};
NodeType this_node_ = {};
NodeType num_nodes_ = {};
bool is_even_ = false;
Expand Down

0 comments on commit 0ee7848

Please sign in to comment.