File indexing completed on 2025-01-18 09:40:57
0001
0002
0003
0004
0005
0006
0007
0008 #ifndef BOOST_MPI_SCATTER_HPP
0009 #define BOOST_MPI_SCATTER_HPP
0010
0011 #include <boost/mpi/exception.hpp>
0012 #include <boost/mpi/datatype.hpp>
0013 #include <vector>
0014 #include <boost/mpi/packed_oarchive.hpp>
0015 #include <boost/mpi/packed_iarchive.hpp>
0016 #include <boost/mpi/detail/point_to_point.hpp>
0017 #include <boost/mpi/communicator.hpp>
0018 #include <boost/mpi/environment.hpp>
0019 #include <boost/mpi/detail/offsets.hpp>
0020 #include <boost/mpi/detail/antiques.hpp>
0021 #include <boost/assert.hpp>
0022
0023 namespace boost { namespace mpi {
0024
0025 namespace detail {
0026
0027
0028 template<typename T>
0029 void
0030 scatter_impl(const communicator& comm, const T* in_values, T* out_values,
0031 int n, int root, mpl::true_)
0032 {
0033 MPI_Datatype type = get_mpi_datatype<T>(*in_values);
0034 BOOST_MPI_CHECK_RESULT(MPI_Scatter,
0035 (const_cast<T*>(in_values), n, type,
0036 out_values, n, type, root, comm));
0037 }
0038
0039
0040
0041 template<typename T>
0042 void
0043 scatter_impl(const communicator& comm, T* out_values, int n, int root,
0044 mpl::true_)
0045 {
0046 MPI_Datatype type = get_mpi_datatype<T>(*out_values);
0047 BOOST_MPI_CHECK_RESULT(MPI_Scatter,
0048 (0, n, type,
0049 out_values, n, type,
0050 root, comm));
0051 }
0052
0053
0054
0055
0056
0057
0058
0059
0060 template<typename T>
0061 void
0062 fill_scatter_sendbuf(const communicator& comm, T const* values,
0063 int const* nslots, int const* skipped_slots,
0064 packed_oarchive::buffer_type& sendbuf, std::vector<int>& archsizes) {
0065 int nproc = comm.size();
0066 archsizes.resize(nproc);
0067
0068 for (int dest = 0; dest < nproc; ++dest) {
0069 if (skipped_slots) {
0070 for(int k= 0; k < skipped_slots[dest]; ++k) ++values;
0071 }
0072 packed_oarchive procarchive(comm);
0073 for (int i = 0; i < nslots[dest]; ++i) {
0074 procarchive << *values++;
0075 }
0076 int archsize = procarchive.size();
0077 sendbuf.resize(sendbuf.size() + archsize);
0078 archsizes[dest] = archsize;
0079 char const* aptr = static_cast<char const*>(procarchive.address());
0080 std::copy(aptr, aptr+archsize, sendbuf.end()-archsize);
0081 }
0082 }
0083
0084 template<typename T, class A>
0085 T*
0086 non_const_data(std::vector<T,A> const& v) {
0087 using detail::c_data;
0088 return const_cast<T*>(c_data(v));
0089 }
0090
0091
0092
0093
0094 template<typename T>
0095 void
0096 dispatch_scatter_sendbuf(const communicator& comm,
0097 packed_oarchive::buffer_type const& sendbuf, std::vector<int> const& archsizes,
0098 T const* in_values,
0099 T* out_values, int n, int root) {
0100
0101 int myarchsize;
0102 BOOST_MPI_CHECK_RESULT(MPI_Scatter,
0103 (non_const_data(archsizes), 1, MPI_INT,
0104 &myarchsize, 1, MPI_INT, root, comm));
0105 std::vector<int> offsets;
0106 if (root == comm.rank()) {
0107 sizes2offsets(archsizes, offsets);
0108 }
0109
0110 packed_iarchive::buffer_type recvbuf;
0111 recvbuf.resize(myarchsize);
0112 BOOST_MPI_CHECK_RESULT(MPI_Scatterv,
0113 (non_const_data(sendbuf), non_const_data(archsizes), c_data(offsets), MPI_BYTE,
0114 c_data(recvbuf), recvbuf.size(), MPI_BYTE,
0115 root, MPI_Comm(comm)));
0116
0117 if ( in_values != 0 && root == comm.rank()) {
0118
0119 std::copy(in_values + root * n, in_values + (root + 1) * n, out_values);
0120 } else {
0121
0122 packed_iarchive iarchv(comm, recvbuf);
0123 for (int i = 0; i < n; ++i) {
0124 iarchv >> out_values[i];
0125 }
0126 }
0127 }
0128
0129
0130
0131 template<typename T>
0132 void
0133 scatter_impl(const communicator& comm, const T* in_values, T* out_values,
0134 int n, int root, mpl::false_)
0135 {
0136 packed_oarchive::buffer_type sendbuf;
0137 std::vector<int> archsizes;
0138
0139 if (root == comm.rank()) {
0140 std::vector<int> nslots(comm.size(), n);
0141 fill_scatter_sendbuf(comm, in_values, c_data(nslots), (int const*)0, sendbuf, archsizes);
0142 }
0143 dispatch_scatter_sendbuf(comm, sendbuf, archsizes, in_values, out_values, n, root);
0144 }
0145
0146 template<typename T>
0147 void
0148 scatter_impl(const communicator& comm, T* out_values, int n, int root,
0149 mpl::false_ is_mpi_type)
0150 {
0151 scatter_impl(comm, (T const*)0, out_values, n, root, is_mpi_type);
0152 }
0153 }
0154
0155 template<typename T>
0156 void
0157 scatter(const communicator& comm, const T* in_values, T& out_value, int root)
0158 {
0159 detail::scatter_impl(comm, in_values, &out_value, 1, root, is_mpi_datatype<T>());
0160 }
0161
0162 template<typename T>
0163 void
0164 scatter(const communicator& comm, const std::vector<T>& in_values, T& out_value,
0165 int root)
0166 {
0167 using detail::c_data;
0168 ::boost::mpi::scatter<T>(comm, c_data(in_values), out_value, root);
0169 }
0170
0171 template<typename T>
0172 void scatter(const communicator& comm, T& out_value, int root)
0173 {
0174 BOOST_ASSERT(comm.rank() != root);
0175 detail::scatter_impl(comm, &out_value, 1, root, is_mpi_datatype<T>());
0176 }
0177
0178 template<typename T>
0179 void
0180 scatter(const communicator& comm, const T* in_values, T* out_values, int n,
0181 int root)
0182 {
0183 detail::scatter_impl(comm, in_values, out_values, n, root, is_mpi_datatype<T>());
0184 }
0185
0186 template<typename T>
0187 void
0188 scatter(const communicator& comm, const std::vector<T>& in_values,
0189 T* out_values, int n, int root)
0190 {
0191 ::boost::mpi::scatter(comm, detail::c_data(in_values), out_values, n, root);
0192 }
0193
0194 template<typename T>
0195 void scatter(const communicator& comm, T* out_values, int n, int root)
0196 {
0197 BOOST_ASSERT(comm.rank() != root);
0198 detail::scatter_impl(comm, out_values, n, root, is_mpi_datatype<T>());
0199 }
0200
0201 } }
0202
0203 #endif