File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

#include "gloo/common/logging.h"
#include "gloo/context.h"
#include "gloo/transport/unbound_buffer.h"
0014
0015 namespace gloo {
0016
0017 class AlltoallOptions {
0018 public:
0019 explicit AlltoallOptions(const std::shared_ptr<Context>& context)
0020 : context(context), timeout(context->getTimeout()) {}
0021
0022 template <typename T>
0023 void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
0024 elementSize = sizeof(T);
0025 in = std::move(buf);
0026 }
0027
0028 template <typename T>
0029 void setInput(T* ptr, size_t elements) {
0030 elementSize = sizeof(T);
0031 in = context->createUnboundBuffer(ptr, elements * sizeof(T));
0032 }
0033
0034 template <typename T>
0035 void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
0036 elementSize = sizeof(T);
0037 out = std::move(buf);
0038 }
0039
0040 template <typename T>
0041 void setOutput(T* ptr, size_t elements) {
0042 elementSize = sizeof(T);
0043 out = context->createUnboundBuffer(ptr, elements * sizeof(T));
0044 }
0045
0046 void setTag(uint32_t tag) {
0047 this->tag = tag;
0048 }
0049
0050 void setTimeout(std::chrono::milliseconds timeout) {
0051 GLOO_ENFORCE(timeout.count() > 0);
0052 this->timeout = timeout;
0053 }
0054
0055 protected:
0056 std::shared_ptr<Context> context;
0057 std::unique_ptr<transport::UnboundBuffer> in;
0058 std::unique_ptr<transport::UnboundBuffer> out;
0059
0060
0061 size_t elementSize = 0;
0062
0063
0064
0065 uint32_t tag = 0;
0066
0067
0068 std::chrono::milliseconds timeout;
0069
0070 friend void alltoall(AlltoallOptions&);
0071 };
0072
0073 void alltoall(AlltoallOptions& opts);
0074
0075 }