cpp to MPI type mapping improvements (#3495)
* Switch to gtype trait based mpi type dispatching

* Add mpi type mapping for std::int64_t

* Add mpi type mapping for std::int32_t

* Missing double cases

* Add documentation

* Tidy and fix

* Some more

* More fixes

* Fix order

* Simplify

* Revert

* Add default NULL type

* Sanity check static assert

* Fancy void_t fix?

* Wrong position

* Switch to non width types

* Revert type trait trickery and document odd behavior

* Add char types

* Doc

* Doc for macros

* Enable preprocessing for doxygen

* Reactivate fixed width types

* one more

* Type size dependent overloading

* Another

* Try wordsize check

* combine checks

* Give up, make MPI type explicit for KaHIP and remove general support of unsigned long long

* typo

* Add KaHIP type comment

* typos

* Remove maps for char and bool

* Revert to non-type-trait usage when the type is not templated
schnellerhase authored Nov 29, 2024
1 parent 11f485e commit fb4fd29
Showing 14 changed files with 101 additions and 97 deletions.
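
The core of the change replaces the `mpi_type<T>()` function template with an `mpi_t<T>` variable template backed by a `mpi_type_mapping` type trait. A minimal call-site sketch of the before/after usage, assuming MPI has been initialised, `comm` is a valid communicator, and the include path follows the repository layout; the wrapper function `reduce_sum` is illustrative and not part of the diff:

```cpp
#include <mpi.h>

#include <dolfinx/common/MPI.h>

// Minimal sketch of the call-site change, not taken verbatim from the diff.
double reduce_sum(MPI_Comm comm, double local)
{
  double result = 0.0;
  // Before: the datatype came from a constexpr function call,
  //   dolfinx::MPI::mpi_type<double>()
  // After: it is read from a variable template resolved through the
  // mpi_type_mapping<T> trait:
  MPI_Allreduce(&local, &result, 1, dolfinx::MPI::mpi_t<double>, MPI_SUM,
                comm);
  return result;
}
```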
4 changes: 2 additions & 2 deletions cpp/doc/Doxyfile
@@ -2268,7 +2268,7 @@ PERLMOD_MAKEVAR_PREFIX =
# C-preprocessor directives found in the sources and include files.
# The default value is: YES.

ENABLE_PREPROCESSING = NO
ENABLE_PREPROCESSING = YES

# If the MACRO_EXPANSION tag is set to YES, doxygen will expand all macro names
# in the source code. If set to NO, only conditional compilation will be
@@ -2285,7 +2285,7 @@ MACRO_EXPANSION = YES
# The default value is: NO.
# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.

EXPAND_ONLY_PREDEF = YES
EXPAND_ONLY_PREDEF = NO

# If the SEARCH_INCLUDES tag is set to YES, the include files in the
# INCLUDE_PATH will be searched if a #include is found.
77 changes: 40 additions & 37 deletions cpp/dolfinx/common/MPI.h
@@ -1,4 +1,4 @@
// Copyright (C) 2007-2023 Magnus Vikstrøm and Garth N. Wells
// Copyright (C) 2007-2023 Magnus Vikstrøm, Garth N. Wells and Paul T. Kühner
//
// This file is part of DOLFINx (https://www.fenicsproject.org)
//
@@ -271,39 +271,42 @@ struct dependent_false : std::false_type
};

/// MPI Type

/// @brief Type trait for MPI type conversions.
template <typename T>
constexpr MPI_Datatype mpi_type()
{
if constexpr (std::is_same_v<T, float>)
return MPI_FLOAT;
else if constexpr (std::is_same_v<T, double>)
return MPI_DOUBLE;
else if constexpr (std::is_same_v<T, std::complex<double>>)
return MPI_C_DOUBLE_COMPLEX;
else if constexpr (std::is_same_v<T, std::complex<float>>)
return MPI_C_FLOAT_COMPLEX;
else if constexpr (std::is_same_v<T, short int>)
return MPI_SHORT;
else if constexpr (std::is_same_v<T, int>)
return MPI_INT;
else if constexpr (std::is_same_v<T, unsigned int>)
return MPI_UNSIGNED;
else if constexpr (std::is_same_v<T, long int>)
return MPI_LONG;
else if constexpr (std::is_same_v<T, unsigned long>)
return MPI_UNSIGNED_LONG;
else if constexpr (std::is_same_v<T, long long>)
return MPI_LONG_LONG;
else if constexpr (std::is_same_v<T, unsigned long long>)
return MPI_UNSIGNED_LONG_LONG;
else if constexpr (std::is_same_v<T, bool>)
return MPI_C_BOOL;
else if constexpr (std::is_same_v<T, std::int8_t>)
return MPI_INT8_T;
else
// Issue compile time error
static_assert(!std::is_same_v<T, T>);
}
struct mpi_type_mapping;

/// @brief Retrieves the MPI data type associated with the provided type.
/// @tparam T C++ type to map
template <typename T>
MPI_Datatype mpi_t = mpi_type_mapping<T>::type;

/// @brief Registers for cpp_t the corresponding mpi_t, which can then be
/// retrieved with mpi_t<cpp_t> from here on.
#define MAP_TO_MPI_TYPE(cpp_t, mpi_t) \
template <> \
struct mpi_type_mapping<cpp_t> \
{ \
static inline MPI_Datatype type = mpi_t; \
};

/// @defgroup MPI type mappings
/// @{
/// @cond
MAP_TO_MPI_TYPE(float, MPI_FLOAT)
MAP_TO_MPI_TYPE(double, MPI_DOUBLE)
MAP_TO_MPI_TYPE(std::complex<float>, MPI_C_FLOAT_COMPLEX)
MAP_TO_MPI_TYPE(std::complex<double>, MPI_C_DOUBLE_COMPLEX)
MAP_TO_MPI_TYPE(std::int8_t, MPI_INT8_T)
MAP_TO_MPI_TYPE(std::int16_t, MPI_INT16_T)
MAP_TO_MPI_TYPE(std::int32_t, MPI_INT32_T)
MAP_TO_MPI_TYPE(std::int64_t, MPI_INT64_T)
MAP_TO_MPI_TYPE(std::uint8_t, MPI_UINT8_T)
MAP_TO_MPI_TYPE(std::uint16_t, MPI_UINT16_T)
MAP_TO_MPI_TYPE(std::uint32_t, MPI_UINT32_T)
MAP_TO_MPI_TYPE(std::uint64_t, MPI_UINT64_T)
/// @endcond
/// @}

//---------------------------------------------------------------------------
template <typename U>
@@ -434,7 +437,7 @@ distribute_to_postoffice(MPI_Comm comm, const U& x,

// Send/receive data (x)
MPI_Datatype compound_type;
MPI_Type_contiguous(shape[1], dolfinx::MPI::mpi_type<T>(), &compound_type);
MPI_Type_contiguous(shape[1], dolfinx::MPI::mpi_t<T>, &compound_type);
MPI_Type_commit(&compound_type);
std::vector<T> recv_buffer_data(shape[1] * recv_disp.back());
err = MPI_Neighbor_alltoallv(
@@ -616,7 +619,7 @@ distribute_from_postoffice(MPI_Comm comm, std::span<const std::int64_t> indices,
dolfinx::MPI::check_error(comm, err);

MPI_Datatype compound_type0;
MPI_Type_contiguous(shape[1], dolfinx::MPI::mpi_type<T>(), &compound_type0);
MPI_Type_contiguous(shape[1], dolfinx::MPI::mpi_t<T>, &compound_type0);
MPI_Type_commit(&compound_type0);

std::vector<T> recv_buffer_data(shape[1] * send_disp.back());
@@ -691,8 +694,8 @@ distribute_data(MPI_Comm comm0, std::span<const std::int64_t> indices,
if (comm1 != MPI_COMM_NULL)
{
rank_offset = 0;
err = MPI_Exscan(&shape0_local, &rank_offset, 1, MPI_INT64_T, MPI_SUM,
comm1);
err = MPI_Exscan(&shape0_local, &rank_offset, 1,
dolfinx::MPI::mpi_t<std::int64_t>, MPI_SUM, comm1);
dolfinx::MPI::check_error(comm1, err);
}
else
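
For types not covered by the built-in specialisations above, the new `MAP_TO_MPI_TYPE` macro can register further mappings. A hedged sketch, assuming the macro is invoked inside the `dolfinx::MPI` namespace where `mpi_type_mapping` is declared; the `long double`/`MPI_LONG_DOUBLE` pairing is illustrative and not one registered by this commit:

```cpp
#include <mpi.h>

#include <dolfinx/common/MPI.h>

namespace dolfinx::MPI
{
// Hypothetical extra registration: map the C++ type 'long double' to the
// standard MPI datatype MPI_LONG_DOUBLE so that mpi_t<long double> resolves.
MAP_TO_MPI_TYPE(long double, MPI_LONG_DOUBLE)
} // namespace dolfinx::MPI

// At a call site the mapping is then used like any built-in one:
//   MPI_Allreduce(&local, &global, 1, dolfinx::MPI::mpi_t<long double>,
//                 MPI_SUM, comm);
```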
34 changes: 16 additions & 18 deletions cpp/dolfinx/common/Scatterer.h
@@ -145,8 +145,7 @@ class Scatterer

// Scale sizes and displacements by block size
{
auto rescale = [](auto& x, int bs)
{
auto rescale = [](auto& x, int bs) {
std::ranges::transform(x, x.begin(), [bs](auto e) { return e *= bs; });
};
rescale(_sizes_local, bs);
@@ -207,11 +206,11 @@ class Scatterer
case type::neighbor:
{
assert(requests.size() == std::size_t(1));
MPI_Ineighbor_alltoallv(
send_buffer.data(), _sizes_local.data(), _displs_local.data(),
dolfinx::MPI::mpi_type<T>(), recv_buffer.data(), _sizes_remote.data(),
_displs_remote.data(), dolfinx::MPI::mpi_type<T>(), _comm0.comm(),
requests.data());
MPI_Ineighbor_alltoallv(send_buffer.data(), _sizes_local.data(),
_displs_local.data(), dolfinx::MPI::mpi_t<T>,
recv_buffer.data(), _sizes_remote.data(),
_displs_remote.data(), dolfinx::MPI::mpi_t<T>,
_comm0.comm(), requests.data());
break;
}
case type::p2p:
@@ -220,14 +219,14 @@
for (std::size_t i = 0; i < _src.size(); i++)
{
MPI_Irecv(recv_buffer.data() + _displs_remote[i], _sizes_remote[i],
dolfinx::MPI::mpi_type<T>(), _src[i], MPI_ANY_TAG,
_comm0.comm(), &requests[i]);
dolfinx::MPI::mpi_t<T>, _src[i], MPI_ANY_TAG, _comm0.comm(),
&requests[i]);
}

for (std::size_t i = 0; i < _dest.size(); i++)
{
MPI_Isend(send_buffer.data() + _displs_local[i], _sizes_local[i],
dolfinx::MPI::mpi_type<T>(), _dest[i], 0, _comm0.comm(),
dolfinx::MPI::mpi_t<T>, _dest[i], 0, _comm0.comm(),
&requests[i + _src.size()]);
}
break;
@@ -404,11 +403,10 @@ class Scatterer
case type::neighbor:
{
assert(requests.size() == 1);
MPI_Ineighbor_alltoallv(send_buffer.data(), _sizes_remote.data(),
_displs_remote.data(), MPI::mpi_type<T>(),
recv_buffer.data(), _sizes_local.data(),
_displs_local.data(), MPI::mpi_type<T>(),
_comm1.comm(), &requests[0]);
MPI_Ineighbor_alltoallv(
send_buffer.data(), _sizes_remote.data(), _displs_remote.data(),
MPI::mpi_t<T>, recv_buffer.data(), _sizes_local.data(),
_displs_local.data(), MPI::mpi_t<T>, _comm1.comm(), &requests[0]);
break;
}
case type::p2p:
@@ -418,16 +416,16 @@
for (std::size_t i = 0; i < _dest.size(); i++)
{
MPI_Irecv(recv_buffer.data() + _displs_local[i], _sizes_local[i],
dolfinx::MPI::mpi_type<T>(), _dest[i], MPI_ANY_TAG,
_comm0.comm(), &requests[i]);
dolfinx::MPI::mpi_t<T>, _dest[i], MPI_ANY_TAG, _comm0.comm(),
&requests[i]);
}

// Start non-blocking receive from neighbor process for which an owned
// index is a ghost.
for (std::size_t i = 0; i < _src.size(); i++)
{
MPI_Isend(send_buffer.data() + _displs_remote[i], _sizes_remote[i],
dolfinx::MPI::mpi_type<T>(), _src[i], 0, _comm0.comm(),
dolfinx::MPI::mpi_t<T>, _src[i], 0, _comm0.comm(),
&requests[i + _dest.size()]);
}
break;
5 changes: 3 additions & 2 deletions cpp/dolfinx/common/Table.cpp
@@ -144,8 +144,9 @@ Table Table::reduce(MPI_Comm comm, Table::Reduction reduction) const
std::partial_sum(pcounts.begin(), pcounts.end(), offsets.begin() + 1);

std::vector<double> values_all(offsets.back());
err = MPI_Gatherv(values.data(), values.size(), MPI_DOUBLE, values_all.data(),
pcounts.data(), offsets.data(), MPI_DOUBLE, 0, comm);
err = MPI_Gatherv(values.data(), values.size(), dolfinx::MPI::mpi_t<double>,
values_all.data(), pcounts.data(), offsets.data(),
dolfinx::MPI::mpi_t<double>, 0, comm);
dolfinx::MPI::check_error(comm, err);

// Return empty table on rank > 0
9 changes: 4 additions & 5 deletions cpp/dolfinx/common/utils.h
@@ -88,18 +88,17 @@ std::size_t hash_global(MPI_Comm comm, const T& x)

// Gather hash keys on root process
std::vector<std::size_t> all_hashes(dolfinx::MPI::size(comm));
int err = MPI_Gather(&local_hash, 1, dolfinx::MPI::mpi_type<std::size_t>(),
all_hashes.data(), 1,
dolfinx::MPI::mpi_type<std::size_t>(), 0, comm);
int err = MPI_Gather(&local_hash, 1, dolfinx::MPI::mpi_t<std::size_t>,
all_hashes.data(), 1, dolfinx::MPI::mpi_t<std::size_t>,
0, comm);
dolfinx::MPI::check_error(comm, err);

// Hash the received hash keys
boost::hash<std::vector<std::size_t>> hash;
std::size_t global_hash = hash(all_hashes);

// Broadcast hash key to all processes
err = MPI_Bcast(&global_hash, 1, dolfinx::MPI::mpi_type<std::size_t>(), 0,
comm);
err = MPI_Bcast(&global_hash, 1, dolfinx::MPI::mpi_t<std::size_t>, 0, comm);
dolfinx::MPI::check_error(comm, err);

return global_hash;
4 changes: 2 additions & 2 deletions cpp/dolfinx/fem/interpolate.h
@@ -274,9 +274,9 @@ void scatter_values(MPI_Comm comm, std::span<const std::int32_t> src_ranks,
std::vector<T> values(recv_offsets.back());
values.reserve(1);
MPI_Neighbor_alltoallv(send_values.data_handle(), send_sizes.data(),
send_offsets.data(), dolfinx::MPI::mpi_type<T>(),
send_offsets.data(), dolfinx::MPI::mpi_t<T>,
values.data(), recv_sizes.data(), recv_offsets.data(),
dolfinx::MPI::mpi_type<T>(), reverse_comm);
dolfinx::MPI::mpi_t<T>, reverse_comm);
MPI_Comm_free(&reverse_comm);

// Insert values received from neighborhood communicator in output
4 changes: 2 additions & 2 deletions cpp/dolfinx/geometry/BoundingBoxTree.h
@@ -335,8 +335,8 @@ class BoundingBoxTree
if (num_bboxes() > 0)
std::copy_n(std::prev(_bbox_coordinates.end(), 6), 6, send_bbox.begin());
std::vector<T> recv_bbox(mpi_size * 6);
MPI_Allgather(send_bbox.data(), 6, dolfinx::MPI::mpi_type<T>(),
recv_bbox.data(), 6, dolfinx::MPI::mpi_type<T>(), comm);
MPI_Allgather(send_bbox.data(), 6, dolfinx::MPI::mpi_t<T>, recv_bbox.data(),
6, dolfinx::MPI::mpi_t<T>, comm);

std::vector<std::pair<std::array<T, 6>, std::int32_t>> _recv_bbox(mpi_size);
for (std::size_t i = 0; i < _recv_bbox.size(); ++i)
8 changes: 4 additions & 4 deletions cpp/dolfinx/geometry/utils.h
@@ -771,8 +771,8 @@ PointOwnershipData<T> determine_point_ownership(const mesh::Mesh<T>& mesh,
std::vector<T> received_points((std::size_t)recv_offsets.back());
MPI_Neighbor_alltoallv(
send_data.data(), send_sizes.data(), send_offsets.data(),
dolfinx::MPI::mpi_type<T>(), received_points.data(), recv_sizes.data(),
recv_offsets.data(), dolfinx::MPI::mpi_type<T>(), forward_comm);
dolfinx::MPI::mpi_t<T>, received_points.data(), recv_sizes.data(),
recv_offsets.data(), dolfinx::MPI::mpi_t<T>, forward_comm);

// Get mesh geometry for closest entity
const mesh::Geometry<T>& geometry = mesh.geometry();
@@ -905,8 +905,8 @@ PointOwnershipData<T> determine_point_ownership(const mesh::Mesh<T>& mesh,
std::vector<T> recv_distances(recv_offsets.back());
MPI_Neighbor_alltoallv(
squared_distances.data(), send_sizes.data(), send_offsets.data(),
dolfinx::MPI::mpi_type<T>(), recv_distances.data(), recv_sizes.data(),
recv_offsets.data(), dolfinx::MPI::mpi_type<T>(), reverse_comm);
dolfinx::MPI::mpi_t<T>, recv_distances.data(), recv_sizes.data(),
recv_offsets.data(), dolfinx::MPI::mpi_t<T>, reverse_comm);

// Update point ownership with extrapolation information
std::vector<T> closest_distance(point_owners.size(),
16 changes: 10 additions & 6 deletions cpp/dolfinx/graph/partitioners.cpp
@@ -444,7 +444,7 @@ graph::partition_fn graph::scotch::partitioner(graph::scotch::strategy strategy,
// Exchange halo with node_partition data for ghosts
common::Timer timer3("SCOTCH: call SCOTCH_dgraphHalo");
err = SCOTCH_dgraphHalo(&dgrafdat, node_partition.data(),
dolfinx::MPI::mpi_type<SCOTCH_Num>());
dolfinx::MPI::mpi_t<SCOTCH_Num>);
if (err != 0)
throw std::runtime_error("Error during SCOTCH halo exchange");
timer3.stop();
@@ -554,9 +554,8 @@ graph::partition_fn graph::parmetis::partitioner(double imbalance,
const int psize = dolfinx::MPI::size(pcomm);
const idx_t num_local_nodes = graph.num_nodes();
node_disp = std::vector<idx_t>(psize + 1, 0);
MPI_Allgather(&num_local_nodes, 1, dolfinx::MPI::mpi_type<idx_t>(),
node_disp.data() + 1, 1, dolfinx::MPI::mpi_type<idx_t>(),
pcomm);
MPI_Allgather(&num_local_nodes, 1, dolfinx::MPI::mpi_t<idx_t>,
node_disp.data() + 1, 1, dolfinx::MPI::mpi_t<idx_t>, pcomm);
std::partial_sum(node_disp.begin(), node_disp.end(), node_disp.begin());
std::vector<idx_t> array(graph.array().begin(), graph.array().end());
std::vector<idx_t> offsets(graph.offsets().begin(),
@@ -631,8 +630,13 @@ graph::partition_fn graph::kahip::partitioner(int mode, int seed,
common::Timer timer1("KaHIP: build adjacency data");
std::vector<T> node_disp(dolfinx::MPI::size(comm) + 1, 0);
const T num_local_nodes = graph.num_nodes();
MPI_Allgather(&num_local_nodes, 1, dolfinx::MPI::mpi_type<T>(),
node_disp.data() + 1, 1, dolfinx::MPI::mpi_type<T>(), comm);

// KaHIP internally relies on an unsigned long long int type, which is not
// easily mapped to a general MPI type due to platform-specific differences,
// so we cannot rely on the general mpi_t<> mapping and set the datatype by
// hand in this sole occurrence.
MPI_Allgather(&num_local_nodes, 1, MPI_UNSIGNED_LONG_LONG,
node_disp.data() + 1, 1, MPI_UNSIGNED_LONG_LONG, comm);
std::partial_sum(node_disp.begin(), node_disp.end(), node_disp.begin());
std::vector<T> array(graph.array().begin(), graph.array().end());
std::vector<T> offsets(graph.offsets().begin(), graph.offsets().end());
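
The KaHIP special case exists because `unsigned long long` no longer has a registered mapping (the commit removes general support for it). A sketch of the distinction, under the common LP64 assumption that `std::uint64_t` aliases `unsigned long` rather than `unsigned long long`:

```cpp
#include <cstdint>

#include <dolfinx/common/MPI.h>

// mpi_t<std::uint64_t> is well defined (registered as MPI_UINT64_T above),
// but on many platforms 'unsigned long long' is a distinct type of the same
// width, so the following would fail to compile after this change:
//
//   MPI_Datatype t = dolfinx::MPI::mpi_t<unsigned long long>; // no mapping
//
// The KaHIP partitioner's node-count type is 'unsigned long long', hence the
// explicit datatype at its single call site:
//
//   MPI_Allgather(&num_local_nodes, 1, MPI_UNSIGNED_LONG_LONG,
//                 node_disp.data() + 1, 1, MPI_UNSIGNED_LONG_LONG, comm);
```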
13 changes: 5 additions & 8 deletions cpp/dolfinx/io/xdmf_utils.cpp
@@ -378,9 +378,8 @@
std::vector<T> recv_values_buffer(recv_disp.back());
err = MPI_Neighbor_alltoallv(
send_values_buffer.data(), num_items_send.data(), send_disp.data(),
dolfinx::MPI::mpi_type<T>(), recv_values_buffer.data(),
num_items_recv.data(), recv_disp.data(), dolfinx::MPI::mpi_type<T>(),
comm0);
dolfinx::MPI::mpi_t<T>, recv_values_buffer.data(),
num_items_recv.data(), recv_disp.data(), dolfinx::MPI::mpi_t<T>, comm0);
dolfinx::MPI::check_error(comm, err);
err = MPI_Comm_free(&comm0);
dolfinx::MPI::check_error(comm, err);
@@ -403,8 +402,7 @@
std::vector<std::pair<int, std::int64_t>> dest_to_index;
std::ranges::transform(
indices, std::back_inserter(dest_to_index),
[size, num_nodes](auto n)
{
[size, num_nodes](auto n) {
return std::pair(dolfinx::MPI::index_owner(size, n, num_nodes), n);
});
std::ranges::sort(dest_to_index);
@@ -552,9 +550,8 @@
std::vector<T> recv_values_buffer(recv_disp.back());
err = MPI_Neighbor_alltoallv(
send_values_buffer.data(), num_items_send.data(), send_disp.data(),
dolfinx::MPI::mpi_type<T>(), recv_values_buffer.data(),
num_items_recv.data(), recv_disp.data(), dolfinx::MPI::mpi_type<T>(),
comm0);
dolfinx::MPI::mpi_t<T>, recv_values_buffer.data(),
num_items_recv.data(), recv_disp.data(), dolfinx::MPI::mpi_t<T>, comm0);

dolfinx::MPI::check_error(comm, err);

4 changes: 2 additions & 2 deletions cpp/dolfinx/la/MatrixCSR.h
@@ -688,9 +688,9 @@ void MatrixCSR<U, V, W, X>::scatter_rev_begin()

int status = MPI_Ineighbor_alltoallv(
_ghost_value_data.data(), val_send_count.data(), _val_send_disp.data(),
dolfinx::MPI::mpi_type<value_type>(), _ghost_value_data_in.data(),
dolfinx::MPI::mpi_t<value_type>, _ghost_value_data_in.data(),
val_recv_count.data(), _val_recv_disp.data(),
dolfinx::MPI::mpi_type<value_type>(), _comm.comm(), &_request);
dolfinx::MPI::mpi_t<value_type>, _comm.comm(), &_request);
assert(status == MPI_SUCCESS);
}
//-----------------------------------------------------------------------------
6 changes: 3 additions & 3 deletions cpp/dolfinx/la/Vector.h
@@ -245,7 +245,7 @@ auto inner_product(const V& a, const V& b)
});

T result;
MPI_Allreduce(&local, &result, 1, dolfinx::MPI::mpi_type<T>(), MPI_SUM,
MPI_Allreduce(&local, &result, 1, dolfinx::MPI::mpi_t<T>, MPI_SUM,
a.index_map()->comm());
return result;
}
@@ -279,7 +279,7 @@ auto norm(const V& x, Norm type = Norm::l2)
= std::accumulate(data.begin(), data.end(), U(0),
[](auto norm, auto x) { return norm + std::abs(x); });
U l1(0);
MPI_Allreduce(&local_l1, &l1, 1, MPI::mpi_type<U>(), MPI_SUM,
MPI_Allreduce(&local_l1, &l1, 1, MPI::mpi_t<U>, MPI_SUM,
x.index_map()->comm());
return l1;
}
@@ -293,7 +293,7 @@ auto norm(const V& x, Norm type = Norm::l2)
data, [](T a, T b) { return std::norm(a) < std::norm(b); });
auto local_linf = std::abs(*max_pos);
decltype(local_linf) linf = 0;
MPI_Allreduce(&local_linf, &linf, 1, MPI::mpi_type<decltype(linf)>(),
MPI_Allreduce(&local_linf, &linf, 1, MPI::mpi_t<decltype(linf)>,
MPI_MAX, x.index_map()->comm());
return linf;
}
