File indexing completed on 2025-01-18 09:54:50
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <algorithm>
0012 #include <type_traits>
0013
0014 #include "corecel/Config.hh"
0015
0016 #if CELERITAS_USE_MPI
0017 # include <mpi.h>
0018 #endif
0019
0020 #include "corecel/Assert.hh"
0021 #include "corecel/Macros.hh"
0022 #include "corecel/cont/Span.hh"
0023
0024 #include "MpiCommunicator.hh"
0025
0026 #include "detail/MpiType.hh"
0027
0028 namespace celeritas
0029 {
0030
0031
0032
0033
/*!
 * Reduction applied by the \c allreduce overloads.
 *
 * The enumerators are consecutive starting at zero, so \c size_ (which always
 * comes last) equals the number of real operations. \c size_ itself has no MPI
 * counterpart and must not be passed as a reduction.
 */
enum class Operation
{
    min,  //!< Translated to MPI_MIN
    max,  //!< Translated to MPI_MAX
    sum,  //!< Translated to MPI_SUM
    prod,  //!< Translated to MPI_PROD
    size_  //!< Sentinel marking the end of the list
};
0042
0043
0044
0045
0046
0047 inline void barrier(MpiCommunicator const& comm);
0048
0049
0050
0051 template<class T, std::size_t N>
0052 inline void allreduce(MpiCommunicator const& comm,
0053 Operation op,
0054 Span<T const, N> src,
0055 Span<T, N> dst);
0056
0057
0058
0059 template<class T, std::size_t N>
0060 inline void
0061 allreduce(MpiCommunicator const& comm, Operation op, Span<T, N> data);
0062
0063
0064
0065 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool> = true>
0066 inline T allreduce(MpiCommunicator const& comm, Operation op, T const src);
0067
0068
0069
0070
#if CELERITAS_USE_MPI
/*!
 * Translate a reduction \c Operation into the matching MPI operation handle.
 *
 * Implementation helper for the \c allreduce overloads; only available when
 * MPI support is compiled in.
 *
 * This is deliberately an \c inline function at the enclosing namespace scope
 * rather than a member of an unnamed namespace. An unnamed namespace in a
 * header gives every translation unit its own internal-linkage copy, so the
 * inline \c allreduce templates that call this helper would refer to a
 * different entity in each translation unit (a one-definition-rule
 * violation).
 *
 * \param op Reduction to translate; \c Operation::size_ is not valid
 * \return The corresponding predefined \c MPI_Op
 */
inline MPI_Op to_mpi(Operation op)
{
    // Every enumerator is listed (no `default`) so the compiler can warn if a
    // new Operation is added without a translation here.
    switch (op)
    {
        case Operation::min:
            return MPI_MIN;
        case Operation::max:
            return MPI_MAX;
        case Operation::sum:
            return MPI_SUM;
        case Operation::prod:
            return MPI_PROD;
        case Operation::size_:
            break;
    }
    // The sentinel and any out-of-range value are programming errors
    CELER_ASSERT_UNREACHABLE();
}
#endif
0089
0090
0091
0092
0093
0094 void barrier(MpiCommunicator const& comm)
0095 {
0096 if (!comm)
0097 return;
0098
0099 CELER_MPI_CALL(MPI_Barrier(comm.mpi_comm()));
0100 }
0101
0102
0103
0104
0105
0106 template<class T, std::size_t N>
0107 void allreduce(MpiCommunicator const& comm,
0108 [[maybe_unused]] Operation op,
0109 Span<T const, N> src,
0110 Span<T, N> dst)
0111 {
0112 CELER_EXPECT(src.size() == dst.size());
0113
0114 if (!comm)
0115 {
0116 std::copy(src.begin(), src.end(), dst.begin());
0117 return;
0118 }
0119
0120 CELER_MPI_CALL(MPI_Allreduce(src.data(),
0121 dst.data(),
0122 dst.size(),
0123 detail::MpiType<T>::value,
0124 to_mpi(op),
0125 comm.mpi_comm()));
0126 }
0127
0128
0129
0130
0131
0132 template<class T, std::size_t N>
0133 void allreduce(MpiCommunicator const& comm,
0134 [[maybe_unused]] Operation op,
0135 [[maybe_unused]] Span<T, N> data)
0136 {
0137 if (!comm)
0138 return;
0139
0140 CELER_MPI_CALL(MPI_Allreduce(MPI_IN_PLACE,
0141 data.data(),
0142 data.size(),
0143 detail::MpiType<T>::value,
0144 to_mpi(op),
0145 comm.mpi_comm()));
0146 }
0147
0148
0149
0150
0151
0152 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool>>
0153 T allreduce(MpiCommunicator const& comm, Operation op, T const src)
0154 {
0155 T dst{};
0156 allreduce(comm, op, Span<T const, 1>{&src, 1}, Span<T, 1>{&dst, 1});
0157 return dst;
0158 }
0159
0160
0161 }