diff --git a/src/vt/collective/reduce/allreduce/state_holder.cc b/src/vt/collective/reduce/allreduce/state_holder.cc
new file mode 100644
index 0000000000..85edbfa7be
--- /dev/null
+++ b/src/vt/collective/reduce/allreduce/state_holder.cc
@@ -0,0 +1,154 @@
+/*
+//@HEADER
+// *****************************************************************************
+//
+// state_holder.cc
+// DARMA/vt => Virtual Transport
+//
+// Copyright 2019-2024 National Technology & Engineering Solutions of Sandia, LLC
+// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
+// Government retains certain rights in this software.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+//
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// * Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from this
+// software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Questions? Contact darma@sandia.gov
+//
+// *****************************************************************************
+//@HEADER
+*/
+
+#include "vt/config.h"
+#include "state_holder.h"
+#include "vt/configs/error/hard_error.h"
+#include "vt/configs/error/config_assert.h"
+
+namespace vt::collective::reduce::allreduce {
+
+/**
+ * \brief Find the first reusable slot in \c states at or after \c idx
+ *
+ * A slot is reusable when it is empty or its state is not marked active. If
+ * the scan reaches the end without finding one, \c states.size() is returned.
+ *
+ * \param[in] states container of states to scan
+ * \param[in] idx position at which the scan starts
+ *
+ * \return index of the first reusable slot, or \c states.size() if none
+ */
+static size_t
+getNextIdImpl(StateHolder::StatesVec& states, size_t idx) {
+  size_t id = u64empty;
+
+  for (; idx < states.size(); ++idx) {
+    auto& state = states.at(idx);
+    if (not state or not state->active_) {
+      id = idx;
+      break;
+    }
+  }
+
+  if (id == u64empty) {
+    id = states.size();
+  }
+
+  return id;
+}
+
+// NOTE(review): collection_idx_, objgroup_idx_ and group_idx_ are single
+// counters shared by every proxy of their kind, so the scan for one proxy
+// starts at the value stored by the previous call (possibly made for another
+// proxy) -- confirm that this is intended.
+size_t StateHolder::getNextID(detail::StrongVrtProxy proxy) {
+  auto& states = active_coll_states_[proxy.get()];
+
+  collection_idx_ = getNextIdImpl(states, collection_idx_);
+  return collection_idx_;
+}
+
+size_t StateHolder::getNextID(detail::StrongObjGroup proxy) {
+  auto& states = active_obj_states_[proxy.get()];
+
+  objgroup_idx_ = getNextIdImpl(states, objgroup_idx_);
+  return objgroup_idx_;
+}
+
+size_t StateHolder::getNextID(detail::StrongGroup group) {
+  auto& states = active_grp_states_[group.get()];
+
+  group_idx_ = getNextIdImpl(states, group_idx_);
+  return group_idx_;
+}
+
+/**
+ * \brief Reset the entry at \c idx; the slot itself stays in \c states
+ *
+ * \param[in] states container of states
+ * \param[in] idx slot to reset, asserted to be smaller than \c states.size()
+ */
+static inline void
+clearSingleImpl(StateHolder::StatesVec& states, size_t idx) {
+  auto const num_states = states.size();
+  vtAssert(
+    num_states > idx,
+    fmt::format(
+      "Attempting to access state {} with total number of states {}!", idx,
+      num_states));
+
+  states.at(idx).reset();
+}
+
+void StateHolder::clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
+  auto& states = active_coll_states_[proxy.get()];
+
+  clearSingleImpl(states, idx);
+}
+
+void StateHolder::clearSingle(detail::StrongObjGroup proxy, size_t idx) {
+  auto& states = active_obj_states_[proxy.get()];
+
+  clearSingleImpl(states, idx);
+}
+
+void StateHolder::clearSingle(detail::StrongGroup group, size_t idx) {
+  auto& states = active_grp_states_[group.get()];
+
+  clearSingleImpl(states, idx);
+}
+
+void StateHolder::clearAll(detail::StrongVrtProxy proxy) {
+  active_coll_states_.erase(proxy.get());
+}
+
+void StateHolder::clearAll(detail::StrongObjGroup proxy) {
+  active_obj_states_.erase(proxy.get());
+}
+
+void StateHolder::clearAll(detail::StrongGroup group) {
+  active_grp_states_.erase(group.get());
+}
+
+} // namespace vt::collective::reduce::allreduce
diff --git a/src/vt/collective/reduce/allreduce/state_holder.h b/src/vt/collective/reduce/allreduce/state_holder.h
index 4c2b2f64b6..6983f8be30 100644
--- a/src/vt/collective/reduce/allreduce/state_holder.h
+++ b/src/vt/collective/reduce/allreduce/state_holder.h
@@ -45,205 +45,59 @@
 #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H
 
 #include "vt/collective/reduce/allreduce/data_handler.h"
-#include "vt/collective/reduce/allreduce/helpers.h"
 #include "vt/collective/reduce/allreduce/type.h"
 #include "vt/collective/reduce/scoping/strong_types.h"
 #include "vt/configs/types/types_type.h"
-#include "vt/configs/debug/debug_print.h"
 #include "vt/collective/reduce/allreduce/state.h"
 
 #include
-#include
 #include
 
 namespace vt::collective::reduce::allreduce {
 
 struct StateHolder {
+  using StatesVec = std::vector>;
+
   template <
     typename ReducerT, typename DataT,
     typename Scalar = typename DataHandler::Scalar>
-  static auto& getState(detail::StrongVrtProxy proxy, size_t idx) {
-    return getStateImpl(proxy, active_coll_states_, idx);
-  }
+  static decltype(auto) getState(detail::StrongVrtProxy proxy, size_t idx);
 
   template <
     typename ReducerT, typename DataT,
     typename Scalar = typename DataHandler::Scalar>
-  static auto& getState(detail::StrongObjGroup proxy, size_t idx) {
-    return getStateImpl(proxy, active_obj_states_, idx);
-  }
+  static decltype(auto) getState(detail::StrongObjGroup proxy, size_t idx);
 
   template <
     typename ReducerT, typename DataT,
     typename Scalar = typename DataHandler::Scalar>
-  static auto& getState(detail::StrongGroup proxy, size_t idx) {
-    return getStateImpl(proxy, active_grp_states_, idx);
-  }
-
-  template
-  static size_t getNextID(detail::StrongVrtProxy proxy) {
-    size_t id = 0;
-    auto& allreducers = active_coll_states_[proxy.get()];
-
-    if (not allreducers.empty()) {
-      // Last element is invalidated (allreduce completed) or not completed
-      // Generate new ID
-      if (not allreducers.back() or allreducers.back()->active_) {
-        id = allreducers.size();
-      }
-      // Most recent state is not active, don't generate new ID
-      else if (not allreducers.back()->active_) {
-        id = allreducers.size() - 1;
-      }
-    }
-
-    return id;
-  }
-
-  template
-  static size_t getNextID(detail::StrongObjGroup proxy) {
-    size_t id = 0;
-    auto& allreducers = active_obj_states_[proxy.get()];
-
-    if (not allreducers.empty()) {
-      // Last element is invalidated (allreduce completed) or not completed
-      // Generate new ID
-      if (not allreducers.back() or allreducers.back()->active_) {
-        id = allreducers.size();
-      }
-      // Most recent state is not active, don't generate new ID
-      else if (not allreducers.back()->active_) {
-        id = allreducers.size() - 1;
-      }
-    }
-
-    return id;
-  }
-
-  static size_t getNextID(detail::StrongGroup group) {
-    size_t id = 0;
-    auto& allreducers = active_grp_states_[group.get()];
-
-    if (not allreducers.empty()) {
-      // Last element is invalidated (allreduce completed) or not completed
-      // Generate new ID
-      if (not allreducers.back() or allreducers.back()->active_) {
-        id = allreducers.size();
-      }
-      // Most recent state is not active, don't generate new ID
-      else if (not allreducers.back()->active_) {
-        id = allreducers.size() - 1;
-      }
-    }
-
-    return id;
-  }
-
-  static void clearSingle(detail::StrongVrtProxy proxy, size_t idx) {
-    clearSingleImpl(proxy, active_coll_states_, idx);
-  }
-
-  static void clearSingle(detail::StrongObjGroup proxy, size_t idx) {
-    clearSingleImpl(proxy, active_obj_states_, idx);
-  }
+  static decltype(auto) getState(detail::StrongGroup proxy, size_t idx);
 
-  static void clearSingle(detail::StrongGroup group, size_t idx) {
-    clearSingleImpl(group, active_grp_states_, idx);
-  }
+  static size_t getNextID(detail::StrongVrtProxy proxy);
+  static size_t getNextID(detail::StrongObjGroup proxy);
+  static size_t getNextID(detail::StrongGroup group);
 
-  static void clearAll(detail::StrongVrtProxy proxy) {
-    // fmt::print("Clearing all states for VrtProxy={:x}\n", proxy.get());
-    clearAllImpl(proxy, active_coll_states_);
-  }
-
-  static void clearAll(detail::StrongObjGroup proxy) {
-    // fmt::print("Clearing all states for Objgroup={:x}\n", proxy.get());
-    clearAllImpl(proxy, active_obj_states_);
-  }
+  static void clearSingle(detail::StrongVrtProxy proxy, size_t idx);
+  static void clearSingle(detail::StrongObjGroup proxy, size_t idx);
+  static void clearSingle(detail::StrongGroup group, size_t idx);
 
-  static void clearAll(detail::StrongGroup group) {
-    // fmt::print("Clearing all states for group={:x}\n", group.get());
-    clearAllImpl(group, active_grp_states_);
-  }
+  static void clearAll(detail::StrongVrtProxy proxy);
+  static void clearAll(detail::StrongObjGroup proxy);
+  static void clearAll(detail::StrongGroup group);
 
 private:
-  template
-  static void clearSingleImpl(ProxyT proxy, MapT& states_map, size_t idx) {
-    auto& states = states_map[proxy.get()];
-
-    auto const num_states = states.size();
-    vtAssert(
-      num_states > idx,
-      fmt::format(
-        "Attempting to access state {} with total numer of states {}!", idx,
-        num_states));
+  static inline size_t collection_idx_ = 0;
+  static inline size_t objgroup_idx_ = 0;
+  static inline size_t group_idx_ = 0;
 
-    states.at(idx).reset();
-  }
-
-  template
-  static void clearAllImpl(ProxyT proxy, MapT& states_map) {
-    states_map.erase(proxy.get());
-  }
-
-  template <
-    typename ReduceT, typename DataT,
-    typename Scalar = typename DataHandler::Scalar, typename ProxyT,
-    typename MapT>
-  static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) {
-    auto& states = states_map[proxy.get()];
-    auto const num_states = states.size();
-
-    vtAssert(
-      num_states >= idx,
-      fmt::format(
-        "Attempting to access state {} with total number of states {}!", idx,
-        num_states));
-
-    if (idx >= num_states || num_states == 0) {
-      if constexpr (std::is_same_v) {
-        states.push_back(std::make_unique>());
-      } else {
-        states.push_back(std::make_unique>());
-      }
-    }
-
-    vtAssert(
-      states.at(idx),
-      fmt::format("Attempting to access invalidated state at idx={}!", idx));
-
-    if constexpr (std::is_same_v) {
-      auto* ptr =
-        dynamic_cast*>(states.at(idx).get());
-      vtAssert(
-        ptr,
-        fmt::format(
-          "Invalid Rabenseifner cast at idx={} with size={}!", idx,
-          states.size()));
-      return *ptr;
-    } else {
-      auto* ptr =
-        dynamic_cast*>(states.at(idx).get());
-      vtAssert(
-        ptr,
-        fmt::format(
-          "Invalid RecursiveDoubling cast at idx={} with size={}!", idx,
-          states.size()));
-      return *ptr;
-    }
-  }
-
-  static inline std::unordered_map<
-    VirtualProxyType, std::vector>>
+  static inline std::unordered_map
     active_coll_states_ = {};
-  static inline std::unordered_map<
-    ObjGroupProxyType, std::vector>>
+  static inline std::unordered_map
     active_obj_states_ = {};
-  static inline std::unordered_map<
-    GroupType, std::vector>>
-  active_grp_states_ = {};
+  static inline std::unordered_map active_grp_states_ =
+    {};
 };
 
 template
@@ -272,4 +126,6 @@ static inline void cleanupState(ComponentInfo info, size_t id) {
 
 } // namespace vt::collective::reduce::allreduce
 
+#include "state_holder.impl.h"
+
 #endif
/*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_H*/
diff --git a/src/vt/collective/reduce/allreduce/state_holder.impl.h b/src/vt/collective/reduce/allreduce/state_holder.impl.h
new file mode 100644
index 0000000000..16a51af8cf
--- /dev/null
+++ b/src/vt/collective/reduce/allreduce/state_holder.impl.h
@@ -0,0 +1,134 @@
+/*
+//@HEADER
+// *****************************************************************************
+//
+// state_holder.impl.h
+// DARMA/vt => Virtual Transport
+//
+// Copyright 2019-2024 National Technology & Engineering Solutions of Sandia, LLC
+// (NTESS). Under the terms of Contract DE-NA0003525 with NTESS, the U.S.
+// Government retains certain rights in this software.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are met:
+//
+// * Redistributions of source code must retain the above copyright notice,
+// this list of conditions and the following disclaimer.
+//
+// * Redistributions in binary form must reproduce the above copyright notice,
+// this list of conditions and the following disclaimer in the documentation
+// and/or other materials provided with the distribution.
+//
+// * Neither the name of the copyright holder nor the names of its
+// contributors may be used to endorse or promote products derived from this
+// software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+//
+// Questions? Contact darma@sandia.gov
+//
+// *****************************************************************************
+//@HEADER
+*/
+
+#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H
+#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H
+
+#include "state_holder.h"
+#include "vt/collective/reduce/allreduce/state.h"
+#include "vt/configs/debug/debug_print.h"
+
+#include
+
+namespace vt::collective::reduce::allreduce {
+
+/**
+ * \brief Fetch the state stored at \c idx, creating it if the slot is empty
+ *
+ * \param[in] states container of states; \c idx must already be in range
+ * \param[in] idx slot to fetch
+ * \param[in] id value of the owning proxy/group, used only in diagnostics
+ *
+ * \return reference to the state at \c idx after a checked downcast
+ */
+template <
+  typename StateT, typename DataT, typename Scalar, typename StateContainerT>
+static inline auto& getState(StateContainerT& states, size_t idx, uint64_t id) {
+  if (!states.at(idx)) {
+    vt_debug_print(
+      verbose, allreduce,
+      "Creating new allreduce state for id={:x} for idx={} "
+      "Scalar_typeid={} DataT_typeid={}\n",
+      id, idx, typeid(Scalar).name(), typeid(DataT).name());
+    states.at(idx) = std::make_unique();
+  }
+
+  auto* ptr =
+    dynamic_cast(states.at(idx).get());
+  vtAssert(
+    ptr,
+    fmt::format(
+      "Invalid allreduce state cast for id={:x} idx={} Scalar_typeid={} "
+      "DataT_typeid={}\n",
+      id, idx, typeid(Scalar).name(), typeid(DataT).name()));
+  return *ptr;
+}
+
+/**
+ * \brief Fetch state \c idx of \c proxy, creating slot and state on demand
+ *
+ * \param[in] proxy proxy or group whose states are accessed
+ * \param[in] states_map map keyed by \c proxy.get() holding each container
+ * \param[in] idx slot to fetch
+ *
+ * \return reference to the state stored at \c idx
+ */
+template <
+  typename ReduceT, typename DataT,
+  typename Scalar = typename DataHandler::Scalar, typename ProxyT,
+  typename MapT>
+static auto& getStateImpl(ProxyT proxy, MapT& states_map, size_t idx) {
+  auto& states = states_map[proxy.get()];
+
+  // Grow the container so that slot idx exists
+  if (idx >= states.size()) {
+    states.resize(idx + 1);
+  }
+
+  if constexpr (std::is_same_v) {
+    return getState, DataT, Scalar>(
+      states, idx, proxy.get());
+  } else {
+    return getState, DataT, Scalar>(
+      states, idx, proxy.get());
+  }
+}
+
+template
+decltype(auto) StateHolder::getState(detail::StrongVrtProxy proxy, size_t idx) {
+  return getStateImpl(proxy, active_coll_states_, idx);
+}
+
+template
+decltype(auto) StateHolder::getState(detail::StrongObjGroup proxy, size_t idx) {
+  return getStateImpl(proxy, active_obj_states_, idx);
+}
+
+template
+decltype(auto) StateHolder::getState(detail::StrongGroup proxy, size_t idx) {
+  return getStateImpl(proxy, active_grp_states_, idx);
+}
+
+} // namespace vt::collective::reduce::allreduce
+
+#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_STATE_HOLDER_IMPL_H*/