File indexing completed on 2025-01-30 10:10:13
0001
0002
0003
0004
0005
0006
0007
0008
0009 #pragma once
0010
0011 #include <atomic>
0012 #include <chrono>
0013 #include <condition_variable>
0014 #include <deque>
0015 #include <exception>
0016 #include <functional>
0017 #include <list>
0018 #include <map>
0019 #include <mutex>
0020 #include <string>
0021 #include <tuple>
0022 #include <unordered_map>
0023 #include <vector>
0024
0025 #include <sys/socket.h>
0026 #include <sys/uio.h>
0027
0028 #include "gloo/common/error.h"
0029 #include "gloo/common/memory.h"
0030 #include "gloo/transport/pair.h"
0031 #include "gloo/transport/tcp/address.h"
0032 #include "gloo/transport/tcp/device.h"
0033 #include "gloo/transport/tcp/error.h"
0034 #include "gloo/transport/tcp/socket.h"
0035
0036 namespace gloo {
0037 namespace transport {
0038 namespace tcp {
0039
0040 // Forward declaration; no definition of Buffer is visible in this header.
0041 class Buffer;
0042
0043 // Forward declaration; no definition of Context is visible in this header.
0044 class Context;
0045
0046 // Forward declaration; no definition of UnboundBuffer is visible in this header.
0047 class UnboundBuffer;
0048
0049 // 100-hour duration; Pair::waitUntil uses it as the upper bound on the relative wait time it passes to wait_for.
0050 constexpr auto kLargeTimeDuration = std::chrono::hours(100);
0051 // Describes a single operation: an opcode-tagged preamble plus local bookkeeping fields.
0052 struct Op {
0053 enum Opcode {
0054 SEND_BUFFER = 0,
0055 SEND_UNBOUND_BUFFER = 1,
0056 NOTIFY_SEND_READY = 2,
0057 NOTIFY_RECV_READY = 3,
0058 };
0059 // Returns preamble.opcode converted to the Opcode enum (no range check is performed).
0060 inline enum Opcode getOpcode() {
0061 return static_cast<Opcode>(preamble.opcode);
0062 }
0063 // Header fields of the operation, all zero-initialized. NOTE(review): presumably exchanged with the peer ahead of the payload - confirm in the implementation.
0064 struct {
0065 size_t nbytes = 0;
0066 size_t opcode = 0;
0067 size_t slot = 0;
0068 size_t offset = 0;
0069 size_t length = 0;
0070 size_t roffset = 0;
0071 } preamble;
0072 // Local bookkeeping: pointer to a Buffer and a weak non-owning pointer to an UnboundBuffer.
0073 // NOTE(review): nread / nwritten presumably count bytes transferred so far - confirm in the implementation.
0074 Buffer* buf = nullptr;
0075 WeakNonOwningPtr<UnboundBuffer> ubuf;
0076 size_t nread = 0;
0077 size_t nwritten = 0;
0078
0079 // NOTE(review): presumably the byte offset and byte count within the unbound buffer - confirm against the users of these fields.
0080 size_t offset = 0;
0081 size_t nbytes = 0;
0082 };
0083 // TCP-namespace implementation of ::gloo::transport::Pair that also derives from Handler (see handleEvents); copy construction and copy assignment are deleted.
0084 class Pair : public ::gloo::transport::Pair, public Handler {
0085 protected:
0086 enum state { // Values held by state_ and passed to changeState(); the value 2 is not assigned.
0087 INITIALIZING = 1,
0088 CONNECTING = 3,
0089 CONNECTED = 4,
0090 CLOSED = 5,
0091 };
0092
0093 public:
0094 explicit Pair(
0095 Context* context,
0096 Device* device,
0097 int rank,
0098 std::chrono::milliseconds timeout);
0099
0100 virtual ~Pair();
0101 // Pairs are non-copyable: both copy operations are deleted.
0102 Pair(const Pair& that) = delete;
0103
0104 Pair& operator=(const Pair& that) = delete;
0105
0106 virtual const Address& address() const override;
0107
0108 virtual void connect(const std::vector<char>& bytes) override;
0109
0110 virtual void setSync(bool sync, bool busyPoll) override;
0111
0112 virtual std::unique_ptr<::gloo::transport::Buffer> createSendBuffer(
0113 int slot,
0114 void* ptr,
0115 size_t size) override;
0116
0117 virtual std::unique_ptr<::gloo::transport::Buffer> createRecvBuffer(
0118 int slot,
0119 void* ptr,
0120 size_t size) override;
0121
0122 // Send using the unbound buffer `tbuf`. NOTE(review): presumably `nbytes` bytes starting at `offset`, identified by `tag` - confirm in the implementation.
0123 virtual void send(
0124 transport::UnboundBuffer* tbuf,
0125 uint64_t tag,
0126 size_t offset,
0127 size_t nbytes) override;
0128
0129 // Receive using the unbound buffer `tbuf`. NOTE(review): presumably `nbytes` bytes written at `offset`, identified by `tag` - confirm in the implementation.
0130 virtual void recv(
0131 transport::UnboundBuffer* tbuf,
0132 uint64_t tag,
0133 size_t offset,
0134 size_t nbytes) override;
0135
0136 // Same parameters as recv() but non-virtual and returns a bool.
0137 // NOTE(review): presumably returns true only when the receive could be started
0138 // right away (e.g. a matching remote send is already pending) - confirm in the implementation.
0139 bool tryRecv(
0140 transport::UnboundBuffer* tbuf,
0141 uint64_t tag,
0142 size_t offset,
0143 size_t nbytes);
0144
0145 void handleEvents(int events) override;
0146
0147 void close() override;
0148
0149 protected:
0150
0151 // Raw pointer to the Context; the pointer itself is const, so it cannot be
0152 // reseated after construction.
0153 // NOTE(review): presumably initialized from the constructor's `context` argument and
0154 // not owned by this class - confirm lifetime guarantees in the implementation.
0155 Context* const context_;
0156
0157 // Raw pointer to the Device; the pointer itself is const.
0158 // NOTE(review): presumably initialized from the constructor's `device` argument and
0159 // not owned by this class - confirm lifetime guarantees in the implementation.
0160 Device* const device_;
0161
0162 const int rank_;
0163 state state_;
0164 std::atomic<bool> sync_;
0165 const std::chrono::milliseconds timeout_;
0166
0167 // NOTE(review): presumably set through setSync(sync, busyPoll) - confirm in the implementation.
0168 bool busyPoll_;
0169 int fd_;
0170 size_t sendBufferSize_;
0171
0172 Address self_;
0173 Address peer_;
0174
0175 std::mutex m_;
0176 std::condition_variable cv_;
0177 std::map<int, Buffer*> buffers_;
0178
0179 // Tuple of an unbound-buffer pointer and two size_t values. NOTE(review): presumably byte offset and byte count - confirm against the users of this alias.
0180 using UnboundBufferOp =
0181 std::tuple<WeakNonOwningPtr<UnboundBuffer>, size_t, size_t>;
0182 // Per-key queues of UnboundBufferOp entries. NOTE(review): the uint64_t key is presumably the slot/tag - confirm.
0183 std::unordered_map<uint64_t, std::deque<UnboundBufferOp>> localPendingSend_;
0184 std::unordered_map<uint64_t, std::deque<UnboundBufferOp>> localPendingRecv_;
0185
0186 void sendUnboundBuffer(
0187 WeakNonOwningPtr<UnboundBuffer> buf,
0188 uint64_t slot,
0189 size_t offset,
0190 size_t nbytes);
0191 void sendNotifyRecvReady(uint64_t slot, size_t nbytes);
0192 void sendNotifySendReady(uint64_t slot, size_t nbytes);
0193
0194 void connectCallback(std::shared_ptr<Socket> socket, Error error);
0195
0196 Buffer* getBuffer(int slot);
0197 void registerBuffer(Buffer* buf);
0198 void unregisterBuffer(Buffer* buf);
0199
0200 void sendSyncMode(Op& op);
0201 void sendAsyncMode(Op& op);
0202 void send(Op& op);
0203 void recv();
0204
0205 const Address& peer() const {
0206 return peer_;
0207 }
0208
0209 bool isSync() const {
0210 return sync_;
0211 }
0212
0213 std::chrono::milliseconds getTimeout() const {
0214 return timeout_;
0215 }
0216
0217 std::exception_ptr signalExceptionExternal(const std::string& msg);
0218
0219 friend class Buffer;
0220
0221 friend class Context;
0222
0223 protected:
0224
0225 // A single Op instance. NOTE(review): presumably the state of the operation currently being received (see read()) - confirm.
0226 Op rx_;
0227
0228
0229
0230
0231
0232
0233 // A queue of Ops. NOTE(review): presumably operations waiting to be transmitted (see write()) - confirm.
0234 std::deque<Op> tx_;
0235
0236 // NOTE(review): presumably fills `iov` / `ioc` for writing `op` and returns a byte count - confirm in the implementation.
0237 ssize_t prepareWrite(
0238 Op& op,
0239 const NonOwningPtr<UnboundBuffer>& buf,
0240 struct iovec* iov,
0241 int& ioc);
0242
0243
0244
0245
0246 // Virtual, so subclasses may override it. NOTE(review): presumably writes `op` to the socket and reports success - confirm.
0247 virtual bool write(Op& op);
0248
0249 void writeComplete(const Op &op, NonOwningPtr<UnboundBuffer> &buf,
0250 const Op::Opcode &opcode) const;
0251
0252 // NOTE(review): presumably fills `iov` for reading into `op` and returns a byte count - confirm in the implementation.
0253 ssize_t prepareRead(
0254 Op& op,
0255 NonOwningPtr<UnboundBuffer>& buf,
0256 struct iovec& iov);
0257
0258
0259
0260
0261 // Virtual, so subclasses may override it. NOTE(review): presumably reads an operation from the socket into rx_ and reports success - confirm.
0262 virtual bool read();
0263
0264 void readComplete(NonOwningPtr<UnboundBuffer> &buf);
0265
0266 // NOTE(review): presumably reacts to a peer notification that a send is pending for `op` - confirm in the implementation.
0267 void handleRemotePendingSend(const Op& op);
0268
0269 // NOTE(review): presumably reacts to a peer notification that a recv is pending for `op` - confirm in the implementation.
0270 void handleRemotePendingRecv(const Op& op);
0271
0272
0273
0274
0275
0276 // Virtual, so subclasses may override it. NOTE(review): presumably dispatches the read/write `events` delivered via handleEvents - confirm.
0277 virtual void handleReadWrite(int events);
0278
0279
0280
0281
0282
0283
0284 // Declared noexcept. NOTE(review): presumably updates state_ to `nextState` - confirm in the implementation.
0285 virtual void changeState(state nextState) noexcept;
0286 // Blocks on cv_ (using `lock`) until pred() holds; the wait is bounded only when `useTimeout` is true and timeout_ differs from kNoTimeout.
0287 template<typename pred_t>
0288 void waitUntil(pred_t pred, std::unique_lock<std::mutex>& lock,
0289 bool useTimeout) {
0290 auto timeoutSet = timeout_ != kNoTimeout;
0291 if (useTimeout && timeoutSet) {
0292
0293
0294 // Wait for at most five times timeout_, capped at kLargeTimeDuration converted to milliseconds.
0295 // If wait_for returns false (pred still false at expiry), signalAndThrowException is called with a "Connect timeout" message.
0296 auto relTime = std::min(
0297 timeout_ * 5,
0298 std::chrono::duration_cast<std::chrono::milliseconds>(kLargeTimeDuration));
0299 auto done = cv_.wait_for(lock, relTime, pred);
0300 if (!done) {
0301 signalAndThrowException(GLOO_ERROR_MSG("Connect timeout ", peer_.str()));
0302 }
0303 } else {
0304 cv_.wait(lock, pred);
0305 }
0306 }
0307
0308
0309 // Takes the same `lock` / `useTimeout` pair as waitUntil. NOTE(review): presumably waits for state_ to reach CONNECTED - confirm.
0310 virtual void waitUntilConnected(
0311 std::unique_lock<std::mutex>& lock, bool useTimeout);
0312
0313 // NOTE(review): presumably checks that state_ is CONNECTED - confirm in the implementation.
0314 virtual void verifyConnected();
0315
0316 // NOTE(review): presumably rethrows ex_ when it is set - confirm in the implementation.
0317 void throwIfException();
0318
0319
0320
0321 // Two overloads: from a message string or from an existing exception_ptr. NOTE(review): presumably record the error in ex_ and notify waiters - confirm.
0322 void signalException(const std::string& msg);
0323 void signalException(std::exception_ptr);
0324
0325 // Used by waitUntil on timeout. NOTE(review): presumably behaves like signalException and then throws - confirm.
0326 void signalAndThrowException(const std::string& msg);
0327 void signalAndThrowException(std::exception_ptr ex);
0328
0329
0330 // Stored exception pointer. NOTE(review): presumably the error recorded by signalException - confirm.
0331 std::exception_ptr ex_;
0332 };
0333
0334 }
0335 }
0336 }