Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:00:10

0001 /**
0002  * Copyright (c) 2018-present, Facebook, Inc.
0003  * All rights reserved.
0004  *
0005  * This source code is licensed under the BSD-style license found in the
0006  * LICENSE file in the root directory of this source tree.
0007  */
0008 
0009 #pragma once
0010 
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
0017 
0018 namespace gloo {
0019 
0020 namespace detail {
0021 
0022 struct AllreduceOptionsImpl {
0023   // This type describes the function to use for element wise reduction.
0024   //
0025   // Its arguments are:
0026   //   1. non-const output pointer
0027   //   2. const input pointer 1 (may be equal to 1)
0028   //   3. const input pointer 2 (may be equal to 1)
0029   //   4. number of elements to reduce.
0030   //
0031   // Note that this function is not strictly typed and takes void pointers.
0032   // This is specifically done to avoid the need for a templated options class
0033   // and templated algorithm implementations. We found this adds very little
0034   // value for the increase in compilation time and code size.
0035   //
0036   using Func = std::function<void(void*, const void*, const void*, size_t)>;
0037 
0038   enum Algorithm {
0039     UNSPECIFIED = 0,
0040     RING = 1,
0041     BCUBE = 2,
0042   };
0043 
0044   explicit AllreduceOptionsImpl(const std::shared_ptr<Context>& context)
0045       : context(context),
0046         timeout(context->getTimeout()),
0047         algorithm(UNSPECIFIED) {}
0048 
0049   std::shared_ptr<Context> context;
0050 
0051   // End-to-end timeout for this operation.
0052   std::chrono::milliseconds timeout;
0053 
0054   // Algorithm selection.
0055   Algorithm algorithm;
0056 
0057   // Input and output buffers.
0058   // The output is used as input if input is not specified.
0059   std::vector<std::unique_ptr<transport::UnboundBuffer>> in;
0060   std::vector<std::unique_ptr<transport::UnboundBuffer>> out;
0061 
0062   // Number of elements.
0063   size_t elements = 0;
0064 
0065   // Number of bytes per element.
0066   size_t elementSize = 0;
0067 
0068   // Reduction function.
0069   Func reduce;
0070 
0071   // Tag for this operation.
0072   // Must be unique across operations executing in parallel.
0073   uint32_t tag = 0;
0074 
0075   // This is the maximum size of each I/O operation (send/recv) of which
0076   // two are in flight at all times. A smaller value leads to more
0077   // overhead and a larger value leads to poor cache behavior.
0078   static constexpr size_t kMaxSegmentSize = 1024 * 1024;
0079 
0080   // Internal use only. This is used to exercise code paths where we
0081   // have more than 2 segments per rank without making the tests slow
0082   // (because they would require millions of elements if the default
0083   // were not configurable).
0084   size_t maxSegmentSize = kMaxSegmentSize;
0085 };
0086 
0087 } // namespace detail
0088 
0089 class AllreduceOptions {
0090  public:
0091   using Func = detail::AllreduceOptionsImpl::Func;
0092   using Algorithm = detail::AllreduceOptionsImpl::Algorithm;
0093 
0094   explicit AllreduceOptions(const std::shared_ptr<Context>& context)
0095       : impl_(context) {}
0096 
0097   void setAlgorithm(Algorithm algorithm) {
0098     impl_.algorithm = algorithm;
0099   }
0100 
0101   template <typename T>
0102   void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
0103     std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
0104     bufs[0] = std::move(buf);
0105     setInputs<T>(std::move(bufs));
0106   }
0107 
0108   template <typename T>
0109   void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
0110     impl_.elements = bufs[0]->size / sizeof(T);
0111     impl_.elementSize = sizeof(T);
0112     impl_.in = std::move(bufs);
0113   }
0114 
0115   template <typename T>
0116   void setInput(T* ptr, size_t elements) {
0117     setInputs(&ptr, 1, elements);
0118   }
0119 
0120   template <typename T>
0121   void setInputs(std::vector<T*> ptrs, size_t elements) {
0122     setInputs(ptrs.data(), ptrs.size(), elements);
0123   }
0124 
0125   template <typename T>
0126   void setInputs(T** ptrs, size_t len, size_t elements) {
0127     impl_.elements = elements;
0128     impl_.elementSize = sizeof(T);
0129     impl_.in.reserve(len);
0130     for (size_t i = 0; i < len; i++) {
0131       impl_.in.push_back(
0132           impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
0133     }
0134   }
0135 
0136   template <typename T>
0137   void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
0138     std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
0139     bufs[0] = std::move(buf);
0140     setOutputs<T>(std::move(bufs));
0141   }
0142 
0143   template <typename T>
0144   void setOutputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
0145     impl_.elements = bufs[0]->size / sizeof(T);
0146     impl_.elementSize = sizeof(T);
0147     impl_.out = std::move(bufs);
0148   }
0149 
0150   template <typename T>
0151   void setOutput(T* ptr, size_t elements) {
0152     setOutputs(&ptr, 1, elements);
0153   }
0154 
0155   template <typename T>
0156   void setOutputs(std::vector<T*> ptrs, size_t elements) {
0157     setOutputs(ptrs.data(), ptrs.size(), elements);
0158   }
0159 
0160   template <typename T>
0161   void setOutputs(T** ptrs, size_t len, size_t elements) {
0162     impl_.elements = elements;
0163     impl_.elementSize = sizeof(T);
0164     impl_.out.reserve(len);
0165     for (size_t i = 0; i < len; i++) {
0166       impl_.out.push_back(
0167           impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
0168     }
0169   }
0170 
0171   void setReduceFunction(Func fn) {
0172     impl_.reduce = fn;
0173   }
0174 
0175   void setTag(uint32_t tag) {
0176     impl_.tag = tag;
0177   }
0178 
0179   void setMaxSegmentSize(size_t maxSegmentSize) {
0180     impl_.maxSegmentSize = maxSegmentSize;
0181   }
0182 
0183   void setTimeout(std::chrono::milliseconds timeout) {
0184     impl_.timeout = timeout;
0185   }
0186 
0187  protected:
0188   detail::AllreduceOptionsImpl impl_;
0189 
0190   friend void allreduce(const AllreduceOptions&);
0191 };
0192 
0193 void allreduce(const AllreduceOptions& opts);
0194 
0195 } // namespace gloo