File indexing completed on 2025-01-18 10:00:11
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include "gloo/common/logging.h"
0012 #include "gloo/context.h"
0013 #include "gloo/transport/unbound_buffer.h"
0014
0015 namespace gloo {
0016
0017 class AlltoallvOptions {
0018 public:
0019 explicit AlltoallvOptions(const std::shared_ptr<Context>& context)
0020 : context(context), timeout(context->getTimeout()) {}
0021
0022 template <typename T>
0023 void setInput(
0024 std::unique_ptr<transport::UnboundBuffer> buf,
0025 std::vector<int64_t> elementsPerRank) {
0026 setInput(std::move(buf), std::move(elementsPerRank), sizeof(T));
0027 }
0028
0029 template <typename T>
0030 void setInput(T* ptr, std::vector<int64_t> elementsPerRank) {
0031 setInput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T));
0032 }
0033
0034 template <typename T>
0035 void setOutput(
0036 std::unique_ptr<transport::UnboundBuffer> buf,
0037 std::vector<int64_t> elementsPerRank) {
0038 setOutput(std::move(buf), std::move(elementsPerRank), sizeof(T));
0039 }
0040
0041 template <typename T>
0042 void setOutput(T* ptr, std::vector<int64_t> elementsPerRank) {
0043 setOutput(static_cast<void*>(ptr), std::move(elementsPerRank), sizeof(T));
0044 }
0045
0046 void setTag(uint32_t tag) {
0047 this->tag = tag;
0048 }
0049
0050 void setTimeout(std::chrono::milliseconds timeout) {
0051 GLOO_ENFORCE(timeout.count() > 0);
0052 this->timeout = timeout;
0053 }
0054
0055 protected:
0056 std::shared_ptr<Context> context;
0057 std::unique_ptr<transport::UnboundBuffer> in;
0058 std::unique_ptr<transport::UnboundBuffer> out;
0059 std::vector<size_t> inOffsetPerRank;
0060 std::vector<size_t> inLengthPerRank;
0061 std::vector<size_t> outOffsetPerRank;
0062 std::vector<size_t> outLengthPerRank;
0063
0064
0065 size_t elementSize = 0;
0066
0067
0068
0069 uint32_t tag = 0;
0070
0071
0072 void setElementSize(size_t elementSize);
0073
0074
0075 void setInput(
0076 std::unique_ptr<transport::UnboundBuffer> buf,
0077 std::vector<int64_t> elementsPerRank,
0078 size_t elementSize);
0079
0080
0081 void
0082 setInput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize);
0083
0084
0085 void setOutput(
0086 std::unique_ptr<transport::UnboundBuffer> buf,
0087 std::vector<int64_t> elementsPerRank,
0088 size_t elementSize);
0089
0090
0091 void
0092 setOutput(void* ptr, std::vector<int64_t> elementsPerRank, size_t elementSize);
0093
0094
0095 std::chrono::milliseconds timeout;
0096
0097 friend void alltoallv(AlltoallvOptions&);
0098 };
0099
0100 void alltoallv(AlltoallvOptions& opts);
0101
0102 }