File indexing completed on 2025-01-18 10:00:10
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <functional>
0012 #include <memory>
0013 #include <vector>
0014
0015 #include "gloo/context.h"
0016 #include "gloo/transport/unbound_buffer.h"
0017
0018 namespace gloo {
0019
0020 namespace detail {
0021
0022 struct AllreduceOptionsImpl {
0023
0024
0025
0026
0027
0028
0029
0030
0031
0032
0033
0034
0035
0036 using Func = std::function<void(void*, const void*, const void*, size_t)>;
0037
0038 enum Algorithm {
0039 UNSPECIFIED = 0,
0040 RING = 1,
0041 BCUBE = 2,
0042 };
0043
0044 explicit AllreduceOptionsImpl(const std::shared_ptr<Context>& context)
0045 : context(context),
0046 timeout(context->getTimeout()),
0047 algorithm(UNSPECIFIED) {}
0048
0049 std::shared_ptr<Context> context;
0050
0051
0052 std::chrono::milliseconds timeout;
0053
0054
0055 Algorithm algorithm;
0056
0057
0058
0059 std::vector<std::unique_ptr<transport::UnboundBuffer>> in;
0060 std::vector<std::unique_ptr<transport::UnboundBuffer>> out;
0061
0062
0063 size_t elements = 0;
0064
0065
0066 size_t elementSize = 0;
0067
0068
0069 Func reduce;
0070
0071
0072
0073 uint32_t tag = 0;
0074
0075
0076
0077
0078 static constexpr size_t kMaxSegmentSize = 1024 * 1024;
0079
0080
0081
0082
0083
0084 size_t maxSegmentSize = kMaxSegmentSize;
0085 };
0086
0087 }
0088
0089 class AllreduceOptions {
0090 public:
0091 using Func = detail::AllreduceOptionsImpl::Func;
0092 using Algorithm = detail::AllreduceOptionsImpl::Algorithm;
0093
0094 explicit AllreduceOptions(const std::shared_ptr<Context>& context)
0095 : impl_(context) {}
0096
0097 void setAlgorithm(Algorithm algorithm) {
0098 impl_.algorithm = algorithm;
0099 }
0100
0101 template <typename T>
0102 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
0103 std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
0104 bufs[0] = std::move(buf);
0105 setInputs<T>(std::move(bufs));
0106 }
0107
0108 template <typename T>
0109 void setInputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
0110 impl_.elements = bufs[0]->size / sizeof(T);
0111 impl_.elementSize = sizeof(T);
0112 impl_.in = std::move(bufs);
0113 }
0114
0115 template <typename T>
0116 void setInput(T* ptr, size_t elements) {
0117 setInputs(&ptr, 1, elements);
0118 }
0119
0120 template <typename T>
0121 void setInputs(std::vector<T*> ptrs, size_t elements) {
0122 setInputs(ptrs.data(), ptrs.size(), elements);
0123 }
0124
0125 template <typename T>
0126 void setInputs(T** ptrs, size_t len, size_t elements) {
0127 impl_.elements = elements;
0128 impl_.elementSize = sizeof(T);
0129 impl_.in.reserve(len);
0130 for (size_t i = 0; i < len; i++) {
0131 impl_.in.push_back(
0132 impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
0133 }
0134 }
0135
0136 template <typename T>
0137 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
0138 std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs(1);
0139 bufs[0] = std::move(buf);
0140 setOutputs<T>(std::move(bufs));
0141 }
0142
0143 template <typename T>
0144 void setOutputs(std::vector<std::unique_ptr<transport::UnboundBuffer>> bufs) {
0145 impl_.elements = bufs[0]->size / sizeof(T);
0146 impl_.elementSize = sizeof(T);
0147 impl_.out = std::move(bufs);
0148 }
0149
0150 template <typename T>
0151 void setOutput(T* ptr, size_t elements) {
0152 setOutputs(&ptr, 1, elements);
0153 }
0154
0155 template <typename T>
0156 void setOutputs(std::vector<T*> ptrs, size_t elements) {
0157 setOutputs(ptrs.data(), ptrs.size(), elements);
0158 }
0159
0160 template <typename T>
0161 void setOutputs(T** ptrs, size_t len, size_t elements) {
0162 impl_.elements = elements;
0163 impl_.elementSize = sizeof(T);
0164 impl_.out.reserve(len);
0165 for (size_t i = 0; i < len; i++) {
0166 impl_.out.push_back(
0167 impl_.context->createUnboundBuffer(ptrs[i], elements * sizeof(T)));
0168 }
0169 }
0170
0171 void setReduceFunction(Func fn) {
0172 impl_.reduce = fn;
0173 }
0174
0175 void setTag(uint32_t tag) {
0176 impl_.tag = tag;
0177 }
0178
0179 void setMaxSegmentSize(size_t maxSegmentSize) {
0180 impl_.maxSegmentSize = maxSegmentSize;
0181 }
0182
0183 void setTimeout(std::chrono::milliseconds timeout) {
0184 impl_.timeout = timeout;
0185 }
0186
0187 protected:
0188 detail::AllreduceOptionsImpl impl_;
0189
0190 friend void allreduce(const AllreduceOptions&);
0191 };
0192
// Entry point that consumes a configured AllreduceOptions. Only the
// declaration lives in this header; the definition is elsewhere.
// AllreduceOptions names this function as a friend, so the implementation
// can read the options' protected impl_ member.
// NOTE(review): presumably performs the allreduce collective across the
// participants of the options' context -- confirm against the definition.
0193 void allreduce(const AllreduceOptions& opts);
0194
0195 }