File indexing completed on 2025-01-18 09:40:56
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef BOOST_MPI_ALL_REDUCE_HPP
0013 #define BOOST_MPI_ALL_REDUCE_HPP
0014
0015 #include <vector>
0016
0017 #include <boost/mpi/inplace.hpp>
0018
0019
0020 #include <boost/mpi/collectives/broadcast.hpp>
0021 #include <boost/mpi/collectives/reduce.hpp>
0022
0023 namespace boost { namespace mpi {
0024 namespace detail {
0025
0026
0027
0028
0029
0030 template<typename T, typename Op>
0031 void
0032 all_reduce_impl(const communicator& comm, const T* in_values, int n,
0033 T* out_values, Op , mpl::true_ ,
0034 mpl::true_ )
0035 {
0036 BOOST_MPI_CHECK_RESULT(MPI_Allreduce,
0037 (const_cast<T*>(in_values), out_values, n,
0038 boost::mpi::get_mpi_datatype<T>(*in_values),
0039 (is_mpi_op<Op, T>::op()), comm));
0040 }
0041
0042
0043
0044
0045
0046
0047
0048 template<typename T, typename Op>
0049 void
0050 all_reduce_impl(const communicator& comm, const T* in_values, int n,
0051 T* out_values, Op , mpl::false_ ,
0052 mpl::true_ )
0053 {
0054 user_op<Op, T> mpi_op;
0055 BOOST_MPI_CHECK_RESULT(MPI_Allreduce,
0056 (const_cast<T*>(in_values), out_values, n,
0057 boost::mpi::get_mpi_datatype<T>(*in_values),
0058 mpi_op.get_mpi_op(), comm));
0059 }
0060
0061
0062
0063
0064
0065
0066
0067 template<typename T, typename Op>
0068 void
0069 all_reduce_impl(const communicator& comm, const T* in_values, int n,
0070 T* out_values, Op op, mpl::false_ ,
0071 mpl::false_ )
0072 {
0073 if (in_values == MPI_IN_PLACE) {
0074
0075
0076
0077
0078
0079 std::vector<T> tmp_in( out_values, out_values + n);
0080 reduce(comm, detail::c_data(tmp_in), n, out_values, op, 0);
0081 } else {
0082 reduce(comm, in_values, n, out_values, op, 0);
0083 }
0084 broadcast(comm, out_values, n, 0);
0085 }
0086 }
0087
0088 template<typename T, typename Op>
0089 inline void
0090 all_reduce(const communicator& comm, const T* in_values, int n, T* out_values,
0091 Op op)
0092 {
0093 detail::all_reduce_impl(comm, in_values, n, out_values, op,
0094 is_mpi_op<Op, T>(), is_mpi_datatype<T>());
0095 }
0096
0097 template<typename T, typename Op>
0098 inline void
0099 all_reduce(const communicator& comm, inplace_t<T*> inout_values, int n, Op op)
0100 {
0101 all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), n, inout_values.buffer, op);
0102 }
0103
0104 template<typename T, typename Op>
0105 inline void
0106 all_reduce(const communicator& comm, inplace_t<T> inout_values, Op op)
0107 {
0108 all_reduce(comm, static_cast<const T*>(MPI_IN_PLACE), 1, &(inout_values.buffer), op);
0109 }
0110
0111 template<typename T, typename Op>
0112 inline void
0113 all_reduce(const communicator& comm, const T& in_value, T& out_value, Op op)
0114 {
0115 detail::all_reduce_impl(comm, &in_value, 1, &out_value, op,
0116 is_mpi_op<Op, T>(), is_mpi_datatype<T>());
0117 }
0118
0119 template<typename T, typename Op>
0120 T all_reduce(const communicator& comm, const T& in_value, Op op)
0121 {
0122 T result;
0123 ::boost::mpi::all_reduce(comm, in_value, result, op);
0124 return result;
0125 }
0126
0127 } }
0128
0129 #endif