Warning, file /include/corecel/sys/MpiOperations.hh was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008 #pragma once
0009
0010 #include <algorithm>
0011 #include <type_traits>
0012
0013 #include "corecel/Config.hh"
0014
0015 #if CELERITAS_USE_MPI
0016 # include <mpi.h>
0017 #endif
0018
0019 #include "corecel/Assert.hh"
0020 #include "corecel/Macros.hh"
0021 #include "corecel/cont/Span.hh"
0022
0023 #include "MpiCommunicator.hh"
0024
0025 #include "detail/MpiType.hh"
0026
0027 namespace celeritas
0028 {
0029
0030
0031
0032
/*!
 * Reduction operation selectable for the collective \c allreduce calls.
 *
 * The enumerators keep their implicit, consecutive values starting at zero;
 * \c size_ is a trailing sentinel and is not itself a valid reduction.
 */
enum class Operation
{
    min,  //!< Element-wise minimum
    max,  //!< Element-wise maximum
    sum,  //!< Element-wise sum
    prod,  //!< Element-wise product
    size_  //!< Sentinel: number of real operations
};
0041
0042
0043
0044
0045
0046 inline void barrier(MpiCommunicator const& comm);
0047
0048
0049
0050 template<class T, std::size_t N>
0051 inline void allreduce(MpiCommunicator const& comm,
0052 Operation op,
0053 Span<T const, N> src,
0054 Span<T, N> dst);
0055
0056
0057
0058 template<class T, std::size_t N>
0059 inline void
0060 allreduce(MpiCommunicator const& comm, Operation op, Span<T, N> data);
0061
0062
0063
0064 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool> = true>
0065 inline T allreduce(MpiCommunicator const& comm, Operation op, T const src);
0066
0067
0068
0069
#if CELERITAS_USE_MPI
/*!
 * Translate a reduction \c Operation into the matching MPI operation handle.
 *
 * This is an implementation helper for the \c allreduce overloads in this
 * header. It is deliberately an ordinary \c inline function rather than a
 * member of an unnamed namespace: an unnamed namespace in a header gives every
 * translation unit its own internal-linkage copy, so the external-linkage
 * \c allreduce templates that call it would refer to a different entity in
 * each translation unit (a one-definition-rule hazard).
 *
 * Any value without a dedicated case, including the \c size_ sentinel, takes
 * the default branch, which is asserted to be unreachable.
 */
inline MPI_Op to_mpi(Operation op)
{
    switch (op)
    {
        case Operation::min:
            return MPI_MIN;
        case Operation::max:
            return MPI_MAX;
        case Operation::sum:
            return MPI_SUM;
        case Operation::prod:
            return MPI_PROD;
        default:
            CELER_ASSERT_UNREACHABLE();
    }
}
#endif
0088
0089
0090
0091
0092
0093 void barrier(MpiCommunicator const& comm)
0094 {
0095 if (!comm)
0096 return;
0097
0098 CELER_MPI_CALL(MPI_Barrier(comm.mpi_comm()));
0099 }
0100
0101
0102
0103
0104
0105 template<class T, std::size_t N>
0106 void allreduce(MpiCommunicator const& comm,
0107 [[maybe_unused]] Operation op,
0108 Span<T const, N> src,
0109 Span<T, N> dst)
0110 {
0111 CELER_EXPECT(src.size() == dst.size());
0112
0113 if (!comm)
0114 {
0115 std::copy(src.begin(), src.end(), dst.begin());
0116 return;
0117 }
0118
0119 CELER_MPI_CALL(MPI_Allreduce(src.data(),
0120 dst.data(),
0121 dst.size(),
0122 detail::MpiType<T>::value,
0123 to_mpi(op),
0124 comm.mpi_comm()));
0125 }
0126
0127
0128
0129
0130
0131 template<class T, std::size_t N>
0132 void allreduce(MpiCommunicator const& comm,
0133 [[maybe_unused]] Operation op,
0134 [[maybe_unused]] Span<T, N> data)
0135 {
0136 if (!comm)
0137 return;
0138
0139 CELER_MPI_CALL(MPI_Allreduce(MPI_IN_PLACE,
0140 data.data(),
0141 data.size(),
0142 detail::MpiType<T>::value,
0143 to_mpi(op),
0144 comm.mpi_comm()));
0145 }
0146
0147
0148
0149
0150
0151 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool>>
0152 T allreduce(MpiCommunicator const& comm, Operation op, T const src)
0153 {
0154 T dst{};
0155 allreduce(comm, op, Span<T const, 1>{&src, 1}, Span<T, 1>{&dst, 1});
0156 return dst;
0157 }
0158
0159
0160 }