File indexing completed on 2025-10-31 08:59:09
0001 
0002 
0003 
0004 
0005 
0006 
0007 
0008 #pragma once
0009 
#include <algorithm>
#include <cstddef>
#include <limits>
#include <type_traits>

#include "corecel/Config.hh"

#if CELERITAS_USE_MPI
#    include <mpi.h>
#endif

#include "corecel/Assert.hh"
#include "corecel/Macros.hh"
#include "corecel/cont/Span.hh"

#include "MpiCommunicator.hh"

#include "detail/MpiType.hh"
0026 
0027 namespace celeritas
0028 {
0029 
0030 
0031 
0032 
//! Reduction operation applied element-by-element by \c allreduce
enum class Operation
{
    min,  //!< Minimum over all ranks
    max,  //!< Maximum over all ranks
    sum,  //!< Sum over all ranks
    prod,  //!< Product over all ranks
    size_  //!< Number of operations; sentinel, not a valid operation
};
0041 
0042 
0043 
0044 
0045 
0046 inline void barrier(MpiCommunicator const& comm);
0047 
0048 
0049 
0050 template<class T, std::size_t N>
0051 inline void allreduce(MpiCommunicator const& comm,
0052                       Operation op,
0053                       Span<T const, N> src,
0054                       Span<T, N> dst);
0055 
0056 
0057 
0058 template<class T, std::size_t N>
0059 inline void
0060 allreduce(MpiCommunicator const& comm, Operation op, Span<T, N> data);
0061 
0062 
0063 
0064 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool> = true>
0065 inline T allreduce(MpiCommunicator const& comm, Operation op, T const src);
0066 
0067 
0068 
0069 
#if CELERITAS_USE_MPI
/*!
 * Convert a reduction operation to the matching MPI operation handle.
 *
 * Defined as an ordinary \c inline function (external linkage) rather than
 * inside an unnamed namespace: it is called from the \c allreduce templates,
 * and an internal-linkage helper in a header would give each translation
 * unit its own distinct \c to_mpi, which violates the one-definition rule
 * for those templates.
 *
 * \c Operation::size_ is a sentinel and reaching it is a logic error.
 */
inline MPI_Op to_mpi(Operation op)
{
    switch (op)
    {
            // clang-format off
        case Operation::min:  return MPI_MIN;
        case Operation::max:  return MPI_MAX;
        case Operation::sum:  return MPI_SUM;
        case Operation::prod: return MPI_PROD;
        default: CELER_ASSERT_UNREACHABLE();
            // clang-format on
    }
}
#endif
0088 
0089 
0090 
0091 
0092 
0093 void barrier(MpiCommunicator const& comm)
0094 {
0095     if (!comm)
0096         return;
0097 
0098     CELER_MPI_CALL(MPI_Barrier(comm.mpi_comm()));
0099 }
0100 
0101 
0102 
0103 
0104 
0105 template<class T, std::size_t N>
0106 void allreduce(MpiCommunicator const& comm,
0107                [[maybe_unused]] Operation op,
0108                Span<T const, N> src,
0109                Span<T, N> dst)
0110 {
0111     CELER_EXPECT(src.size() == dst.size());
0112 
0113     if (!comm)
0114     {
0115         std::copy(src.begin(), src.end(), dst.begin());
0116         return;
0117     }
0118 
0119     CELER_MPI_CALL(MPI_Allreduce(src.data(),
0120                                  dst.data(),
0121                                  dst.size(),
0122                                  detail::MpiType<T>::value,
0123                                  to_mpi(op),
0124                                  comm.mpi_comm()));
0125 }
0126 
0127 
0128 
0129 
0130 
0131 template<class T, std::size_t N>
0132 void allreduce(MpiCommunicator const& comm,
0133                [[maybe_unused]] Operation op,
0134                [[maybe_unused]] Span<T, N> data)
0135 {
0136     if (!comm)
0137         return;
0138 
0139     CELER_MPI_CALL(MPI_Allreduce(MPI_IN_PLACE,
0140                                  data.data(),
0141                                  data.size(),
0142                                  detail::MpiType<T>::value,
0143                                  to_mpi(op),
0144                                  comm.mpi_comm()));
0145 }
0146 
0147 
0148 
0149 
0150 
0151 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool>>
0152 T allreduce(MpiCommunicator const& comm, Operation op, T const src)
0153 {
0154     T dst{};
0155     allreduce(comm, op, Span<T const, 1>{&src, 1}, Span<T, 1>{&dst, 1});
0156     return dst;
0157 }
0158 
0159 
0160 }