Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 09:54:50

0001 //----------------------------------*-C++-*----------------------------------//
0002 // Copyright 2020-2024 UT-Battelle, LLC, and other Celeritas developers.
0003 // See the top-level COPYRIGHT file for details.
0004 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
0005 //---------------------------------------------------------------------------//
0006 //! \file corecel/sys/MpiOperations.hh
0007 //! \brief MPI parallel functionality
0008 //---------------------------------------------------------------------------//
0009 #pragma once
0010 
0011 #include <algorithm>
0012 #include <type_traits>
0013 
0014 #include "corecel/Config.hh"
0015 
0016 #if CELERITAS_USE_MPI
0017 #    include <mpi.h>
0018 #endif
0019 
0020 #include "corecel/Assert.hh"
0021 #include "corecel/Macros.hh"
0022 #include "corecel/cont/Span.hh"
0023 
0024 #include "MpiCommunicator.hh"
0025 
0026 #include "detail/MpiType.hh"
0027 
0028 namespace celeritas
0029 {
0030 //---------------------------------------------------------------------------//
0031 // TYPES
0032 //---------------------------------------------------------------------------//
0033 //! MPI reduction operation to perform on the data
0034 enum class Operation
0035 {
0036     min,   //!< Elementwise minimum across ranks (maps to MPI_MIN)
0037     max,   //!< Elementwise maximum across ranks (maps to MPI_MAX)
0038     sum,   //!< Elementwise sum across ranks (maps to MPI_SUM)
0039     prod,  //!< Elementwise product across ranks (maps to MPI_PROD)
0040     size_  //!< Sentinel: number of operations, not a valid reduction
0041 };
0042 
0043 //---------------------------------------------------------------------------//
0044 // FREE FUNCTIONS
0045 //---------------------------------------------------------------------------//
0046 // Wait for all processes in this communicator to reach the barrier (no-op
0047 inline void barrier(MpiCommunicator const& comm);
0048 
0049 //---------------------------------------------------------------------------//
0050 // All-to-all reduction on the data from src to dst (src is copied to dst
0051 template<class T, std::size_t N>
0052 inline void allreduce(MpiCommunicator const& comm,
0053                       Operation op,
0054                       Span<T const, N> src,
0055                       Span<T, N> dst);
0056 
0057 //---------------------------------------------------------------------------//
0058 // All-to-all reduction on the data, in place (data is left unchanged when
0059 template<class T, std::size_t N>
0060 inline void
0061 allreduce(MpiCommunicator const& comm, Operation op, Span<T, N> data);
0062 
0063 //---------------------------------------------------------------------------//
0064 // Perform reduction on a fundamental scalar and return the result
0065 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool> = true>
0066 inline T allreduce(MpiCommunicator const& comm, Operation op, T const src);
0067 
0068 //---------------------------------------------------------------------------//
0069 // INLINE DEFINITIONS
0070 //---------------------------------------------------------------------------//
// NOTE(review): an anonymous namespace in a header gives every translation
// unit its own copy of to_mpi; harmless here since it is only called from the
// inline functions below, but a named detail namespace would be cleaner.
namespace
{
#if CELERITAS_USE_MPI
/*!
 * Map a Celeritas reduction \c Operation onto the corresponding \c MPI_Op.
 *
 * Any value outside the four valid reductions (including \c Operation::size_)
 * trips the unreachable assertion.
 */
inline MPI_Op to_mpi(Operation op)
{
    if (op == Operation::min)
        return MPI_MIN;
    if (op == Operation::max)
        return MPI_MAX;
    if (op == Operation::sum)
        return MPI_SUM;
    if (op == Operation::prod)
        return MPI_PROD;
    CELER_ASSERT_UNREACHABLE();
}
#endif
}  // namespace
0089 
0090 //---------------------------------------------------------------------------//
0091 /*!
0092  * Wait for all processes in this communicator to reach the barrier.
0093  */
0094 void barrier(MpiCommunicator const& comm)
0095 {
0096     if (!comm)
0097         return;
0098 
0099     CELER_MPI_CALL(MPI_Barrier(comm.mpi_comm()));
0100 }
0101 
0102 //---------------------------------------------------------------------------//
0103 /*!
0104  * All-to-all reduction on the data from src to dst.
0105  */
0106 template<class T, std::size_t N>
0107 void allreduce(MpiCommunicator const& comm,
0108                [[maybe_unused]] Operation op,
0109                Span<T const, N> src,
0110                Span<T, N> dst)
0111 {
0112     CELER_EXPECT(src.size() == dst.size());
0113 
0114     if (!comm)
0115     {
0116         std::copy(src.begin(), src.end(), dst.begin());
0117         return;
0118     }
0119 
0120     CELER_MPI_CALL(MPI_Allreduce(src.data(),
0121                                  dst.data(),
0122                                  dst.size(),
0123                                  detail::MpiType<T>::value,
0124                                  to_mpi(op),
0125                                  comm.mpi_comm()));
0126 }
0127 
0128 //---------------------------------------------------------------------------//
0129 /*!
0130  * All-to-all reduction on the data, in place.
0131  */
0132 template<class T, std::size_t N>
0133 void allreduce(MpiCommunicator const& comm,
0134                [[maybe_unused]] Operation op,
0135                [[maybe_unused]] Span<T, N> data)
0136 {
0137     if (!comm)
0138         return;
0139 
0140     CELER_MPI_CALL(MPI_Allreduce(MPI_IN_PLACE,
0141                                  data.data(),
0142                                  data.size(),
0143                                  detail::MpiType<T>::value,
0144                                  to_mpi(op),
0145                                  comm.mpi_comm()));
0146 }
0147 
0148 //---------------------------------------------------------------------------//
0149 /*!
0150  * Perform reduction on a fundamental scalar and return the result.
0151  */
0152 template<class T, std::enable_if_t<std::is_fundamental<T>::value, bool>>
0153 T allreduce(MpiCommunicator const& comm, Operation op, T const src)
0154 {
0155     T dst{};
0156     allreduce(comm, op, Span<T const, 1>{&src, 1}, Span<T, 1>{&dst, 1});
0157     return dst;
0158 }
0159 
0160 //---------------------------------------------------------------------------//
0161 }  // namespace celeritas