Skip to content

Commit

Permalink
#2240: Update ObjGroup test to use custom DataHandler for Rabenseifne…
Browse files Browse the repository at this point in the history
…r allreduce test
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent d70050e commit 1f86e55
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 397 deletions.
4 changes: 4 additions & 0 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H

namespace vt::collective::reduce::allreduce {
#include <vector>

#ifdef VT_KOKKOS_ENABLED
Expand All @@ -65,6 +66,7 @@ class DataHandler {
template <typename T>
class DataHandler<std::vector<T>> {
public:
using UnderlyingType = std::vector<T>;
using Scalar = T;
static size_t size(const std::vector<T>& data) { return data.size(); }
static T at(const std::vector<T>& data, size_t idx) { return data[idx]; }
Expand Down Expand Up @@ -97,4 +99,6 @@ class DataHandler<Kokkos::View<T*, Props...>> {
};
#endif // VT_KOKKOS_ENABLED

} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_DATA_HANDLER_H*/
46 changes: 21 additions & 25 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ struct Rabenseifner {
nprocs_rem_(num_nodes_ - nprocs_pof2_),
finished_adjustment_part_(nprocs_rem_ == 0),
gather_step_(num_steps_ - 1),
gather_mask_(nprocs_pof2_ >> 1)
{
gather_mask_(nprocs_pof2_ >> 1) {
initialize(std::forward<Args>(args)...);
}

Expand Down Expand Up @@ -142,7 +141,7 @@ struct Rabenseifner {
s_count_.resize(num_steps_, 0);

int step = 0;
size_t wsize = val_.size();
size_t wsize = DataType::size(val_);
size_ = wsize;
for (int mask = 1; mask < nprocs_pof2_; mask <<= 1) {
auto vdest = vrt_node_ ^ mask;
Expand Down Expand Up @@ -170,7 +169,6 @@ struct Rabenseifner {
}

void executeFinalHan() {

// theCB()->makeSend<finalHandler>(parent_proxy_[this_node_]).sendTuple(std::make_tuple(val_));
parent_proxy_[this_node_].template invoke<finalHandler>(val_);
completed_ = true;
Expand Down Expand Up @@ -201,8 +199,7 @@ struct Rabenseifner {
}

void adjustForPowerOfTwoRightHalf(AllreduceRbnMsg<DataT>* msg) {

for (uint32_t i = 0; i < msg->val_.size(); i++) {
for (uint32_t i = 0; i < DataType::size(msg->val_); i++) {
Op<typename DataType::Scalar>()(
DataType::at(val_, (size_ / 2) + i), DataType::at(msg->val_, i));
}
Expand All @@ -211,19 +208,18 @@ struct Rabenseifner {
proxy_[theContext()->getNode() - 1]
.template send<&Rabenseifner::adjustForPowerOfTwoFinalPart>(
DataType::split(val_, size_ / 2, size_));
// DataT{val_.begin() + (val_.size() / 2), val_.end()});
}

void adjustForPowerOfTwoLeftHalf(AllreduceRbnMsg<DataT>* msg) {
for (uint32_t i = 0; i < msg->val_.size(); i++) {
Op<typename DataT::value_type>()(DataType::at(val_, i), DataType::at(msg->val_, i));
for (uint32_t i = 0; i < DataType::size(msg->val_); i++) {
Op<typename DataType::Scalar>()(
DataType::at(val_, i), DataType::at(msg->val_, i));
}
}

void adjustForPowerOfTwoFinalPart(AllreduceRbnMsg<DataT>* msg) {
for (uint32_t i = 0; i < msg->val_.size(); i++) {
DataType::at(val_, (val_.size() / 2) + i) = DataType::at(msg->val_, i);
// val_[(val_.size() / 2) + i] = msg->val_[i];
for (uint32_t i = 0; i < DataType::size(msg->val_); i++) {
DataType::at(val_, (size_ / 2) + i) = DataType::at(msg->val_, i);
}

finished_adjustment_part_ = true;
Expand Down Expand Up @@ -254,7 +250,7 @@ struct Rabenseifner {

bool scatterIsReady() {
return ((is_part_of_adjustment_group_ and finished_adjustment_part_) and
scatter_step_ == 0) or
scatter_step_ == 0) or
scatterAllMessagesReceived();
}

Expand All @@ -267,11 +263,11 @@ struct Rabenseifner {
[](const auto val) { return val; })) {
auto& in_msg = scatter_messages_.at(step);
auto& in_val = in_msg->val_;
for (uint32_t i = 0; i < in_val.size(); i++) {
Op<typename DataT::value_type>()(
for (uint32_t i = 0; i < DataType::size(in_val); i++) {
Op<typename DataType::Scalar>()(
DataType::at(val_, r_index_[in_msg->step_] + i),
DataType::at(in_val, i));
// val_[r_index_[in_msg->step_] + i], in_val[i]);
// val_[r_index_[in_msg->step_] + i], in_val[i]);
}

scatter_steps_reduced_[step] = true;
Expand All @@ -294,9 +290,9 @@ struct Rabenseifner {
s_count_[scatter_step_]);
}
proxy_[dest].template send<&Rabenseifner::scatterReduceIterHandler>(
DataT{
val_.begin() + (s_index_[scatter_step_]),
val_.begin() + (s_index_[scatter_step_]) + s_count_[scatter_step_]},
DataType::split(
val_, s_index_[scatter_step_],
s_index_[scatter_step_] + s_count_[scatter_step_]),
scatter_step_);

scatter_mask_ <<= 1;
Expand Down Expand Up @@ -365,9 +361,9 @@ struct Rabenseifner {
if (doRed) {
auto& in_msg = gather_messages_.at(step);
auto& in_val = in_msg->val_;
for (uint32_t i = 0; i < in_val.size(); i++) {
DataType::at(val_, s_index_[in_msg->step_] + i) = DataType::at(in_val, i);
//val_[s_index_[in_msg->step_] + i] = in_val[i];
for (uint32_t i = 0; i < DataType::size(in_val); i++) {
DataType::at(val_, s_index_[in_msg->step_] + i) =
DataType::at(in_val, i);
}

gather_steps_reduced_[step] = true;
Expand All @@ -391,9 +387,9 @@ struct Rabenseifner {
r_count_[gather_step_]);
}
proxy_[dest].template send<&Rabenseifner::gatherIterHandler>(
DataT{
val_.begin() + (r_index_[gather_step_]),
val_.begin() + (r_index_[gather_step_]) + r_count_[gather_step_]},
DataType::split(
val_, r_index_[gather_step_],
r_index_[gather_step_] + r_count_[gather_step_]),
gather_step_);

gather_mask_ >>= 1;
Expand Down
Loading

0 comments on commit 1f86e55

Please sign in to comment.