File indexing completed on 2026-04-17 08:35:04
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020 #ifndef _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_
0021 #define _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 1
0022
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include <openssl/sha.h>

#include <thrift/config.h>
#include <thrift/protocol/TProtocol.h>
#include <thrift/transport/TSocket.h>
#include <thrift/transport/THttpServer.h>
0033 #if defined(_MSC_VER) || defined(__MINGW32__)
0034 #include <Shlwapi.h>
0035 #define THRIFT_strncasecmp(str1, str2, len) _strnicmp(str1, str2, len)
0036 #define THRIFT_strcasestr(haystack, needle) StrStrIA(haystack, needle)
0037 #else
0038 #define THRIFT_strncasecmp(str1, str2, len) strncasecmp(str1, str2, len)
0039 #define THRIFT_strcasestr(haystack, needle) strcasestr(haystack, needle)
0040 #endif
0041 #if defined(__CYGWIN__)
0042 #include <alloca.h>
0043 #endif
0044
0045 using std::string;
0046
0047 namespace apache {
0048 namespace thrift {
0049 namespace transport {
0050
/**
 * Encodes `length` bytes starting at `data` as text; only declared here, the
 * definition lives in another translation unit. TWebSocketServer uses it to
 * turn the SHA-1 digest of the client key into the Sec-WebSocket-Accept value.
 * NOTE(review): presumably standard RFC 4648 Base64 with padding, as the
 * WebSocket handshake requires; confirm against the definition.
 */
std::string base64Encode(unsigned char* data, int length);
0052
0053 template <bool binary>
0054 class TWebSocketServer : public THttpServer {
0055 public:
0056 TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
0057 : THttpServer(transport, config) {
0058 resetHandshake();
0059 }
0060
0061 ~TWebSocketServer() override = default;
0062
0063 uint32_t readAll_virt(uint8_t* buf, uint32_t len) override {
0064
0065 if (!handshakeComplete()) {
0066 resetHandshake();
0067 THttpServer::read(buf, len);
0068
0069
0070 if (!handshakeComplete()) {
0071 sendBadRequest();
0072 return 0;
0073 }
0074
0075 THttpServer::flush();
0076 }
0077
0078 uint32_t want = len;
0079 auto have = readBuffer_.available_read();
0080
0081
0082
0083
0084
0085 if (have > 0 && have >= want) {
0086 return readBuffer_.read(buf, want);
0087 }
0088
0089
0090 if (!readFrame()) {
0091
0092 return 0;
0093 }
0094
0095
0096 uint32_t give = (std::min)(want, readBuffer_.available_read());
0097 return readBuffer_.read(buf, give);
0098 }
0099
0100 void flush() override {
0101 resetConsumedMessageSize();
0102 writeFrameHeader();
0103 uint8_t* buffer;
0104 uint32_t length;
0105 writeBuffer_.getBuffer(&buffer, &length);
0106 transport_->write(buffer, length);
0107 transport_->flush();
0108 writeBuffer_.resetBuffer();
0109 }
0110
0111 protected:
0112 std::string getHeader(uint32_t len) override {
0113 THRIFT_UNUSED_VARIABLE(len);
0114 std::ostringstream h;
0115 h << "HTTP/1.1 101 Switching Protocols" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF
0116 << "Upgrade: websocket" << CRLF << "Connection: Upgrade" << CRLF
0117 << "Sec-WebSocket-Accept: " << acceptKey_ << CRLF << CRLF;
0118 return h.str();
0119 }
0120
0121 void parseHeader(char* header) override {
0122 char* colon = strchr(header, ':');
0123 if (colon == nullptr) {
0124 return;
0125 }
0126 size_t sz = colon - header;
0127 char* value = colon + 1;
0128
0129 if (THRIFT_strncasecmp(header, "Upgrade", sz) == 0) {
0130 if (THRIFT_strcasestr(value, "websocket") != nullptr) {
0131 upgrade_ = true;
0132 }
0133 } else if (THRIFT_strncasecmp(header, "Connection", sz) == 0) {
0134 if (THRIFT_strcasestr(value, "Upgrade") != nullptr) {
0135 connection_ = true;
0136 }
0137 } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Key", sz) == 0) {
0138 std::string toHash = value + 1;
0139 toHash += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
0140 unsigned char hash[20];
0141 SHA1((const unsigned char*)toHash.c_str(), toHash.length(), hash);
0142 acceptKey_ = base64Encode(hash, 20);
0143 secWebSocketKey_ = true;
0144 } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Version", sz) == 0) {
0145 if (THRIFT_strcasestr(value, "13") != nullptr) {
0146 secWebSocketVersion_ = true;
0147 }
0148 }
0149 }
0150
0151 bool parseStatusLine(char* status) override {
0152 char* method = status;
0153
0154 char* path = strchr(method, ' ');
0155 if (path == nullptr) {
0156 throw TTransportException(string("Bad Status: ") + status);
0157 }
0158
0159 *path = '\0';
0160 while (*(++path) == ' ') {
0161 };
0162
0163 char* http = strchr(path, ' ');
0164 if (http == nullptr) {
0165 throw TTransportException(string("Bad Status: ") + status);
0166 }
0167 *http = '\0';
0168
0169 if (strcmp(method, "GET") == 0) {
0170
0171 return true;
0172 }
0173 throw TTransportException(string("Bad Status (unsupported method): ") + status);
0174 }
0175
0176 private:
0177 enum class CloseCode : uint16_t {
0178 NormalClosure = 1000,
0179 GoingAway = 1001,
0180 ProtocolError = 1002,
0181 UnsupportedDataType = 1003,
0182 NoStatusCode = 1005,
0183 AbnormalClosure = 1006,
0184 InvalidData = 1007,
0185 PolicyViolation = 1008,
0186 MessageTooBig = 1009,
0187 ExtensionExpected = 1010,
0188 UnexpectedError = 1011,
0189 NotSecure = 1015
0190 };
0191
0192 enum class Opcode : uint8_t {
0193 Continuation = 0x0,
0194 Text = 0x1,
0195 Binary = 0x2,
0196 Close = 0x8,
0197 Ping = 0x9,
0198 Pong = 0xA
0199 };
0200
0201 void failConnection(CloseCode reason) {
0202 writeFrameHeader(Opcode::Close);
0203 auto buffer = htons(static_cast<uint16_t>(reason));
0204 transport_->write(reinterpret_cast<const uint8_t*>(&buffer), 2);
0205 transport_->flush();
0206 transport_->close();
0207 }
0208
0209 bool handshakeComplete() {
0210 return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
0211 }
0212
0213 void pong() {
0214 writeFrameHeader(Opcode::Pong);
0215 uint8_t* buffer;
0216 uint32_t size;
0217 readBuffer_.getBuffer(&buffer, &size);
0218 transport_->write(buffer, size);
0219 transport_->flush();
0220 }
0221
0222 bool readFrame() {
0223 uint8_t headerBuffer[8];
0224
0225 auto read = transport_->read(headerBuffer, 2);
0226 if (read < 2) {
0227 return false;
0228 }
0229
0230
0231
0232 auto fin = (headerBuffer[0] & 0x80) != 0;
0233 THRIFT_UNUSED_VARIABLE(fin);
0234
0235
0236 if ((headerBuffer[0] & 0x70) != 0) {
0237 failConnection(CloseCode::ProtocolError);
0238 throw TTransportException(TTransportException::CORRUPTED_DATA,
0239 "Reserved bits must be zeroes");
0240 }
0241
0242 auto opcode = (Opcode)(headerBuffer[0] & 0x0F);
0243
0244
0245 if ((headerBuffer[1] & 0x80) == 0) {
0246 failConnection(CloseCode::ProtocolError);
0247 throw TTransportException(TTransportException::CORRUPTED_DATA,
0248 "Messages from the client must be masked");
0249 }
0250
0251
0252 uint64_t payloadLength = headerBuffer[1] & 0x7F;
0253 if (payloadLength == 126) {
0254 read = transport_->read(headerBuffer, 2);
0255 if (read < 2) {
0256 return false;
0257 }
0258 payloadLength = ntohs(*reinterpret_cast<uint16_t*>(headerBuffer));
0259 } else if (payloadLength == 127) {
0260 read = transport_->read(headerBuffer, 8);
0261 if (read < 8) {
0262 return false;
0263 }
0264 payloadLength = THRIFT_ntohll(*reinterpret_cast<uint64_t*>(headerBuffer));
0265 if ((payloadLength & 0x8000000000000000) != 0) {
0266 failConnection(CloseCode::ProtocolError);
0267 throw TTransportException(
0268 TTransportException::CORRUPTED_DATA,
0269 "The most significant bit of the payload length must be zero");
0270 }
0271 }
0272
0273
0274 if (payloadLength > UINT32_MAX) {
0275 failConnection(CloseCode::MessageTooBig);
0276 return false;
0277 }
0278
0279 auto length = static_cast<uint32_t>(payloadLength);
0280
0281 if (length > 0) {
0282
0283 read = transport_->read(headerBuffer, 4);
0284 if (read < 4) {
0285 return false;
0286 }
0287
0288 readBuffer_.resetBuffer(length);
0289 uint8_t* buffer = readBuffer_.getWritePtr(length);
0290 read = transport_->read(buffer, length);
0291 readBuffer_.wroteBytes(read);
0292 if (read < length) {
0293 return false;
0294 }
0295
0296
0297 for (size_t i = 0; i < length; i++) {
0298 buffer[i] ^= headerBuffer[i % 4];
0299 }
0300
0301 T_DEBUG("FIN=%d, Opcode=%X, length=%d, payload=%s", fin, opcode, length,
0302 binary ? readBuffer_.toHexString() : cast(string) readBuffer_);
0303 }
0304
0305 switch (opcode) {
0306 case Opcode::Close:
0307 if (length >= 2) {
0308 uint8_t buffer[2];
0309 readBuffer_.read(buffer, 2);
0310 CloseCode closeCode = static_cast<CloseCode>(ntohs(*reinterpret_cast<uint16_t*>(buffer)));
0311 THRIFT_UNUSED_VARIABLE(closeCode);
0312 string closeReason = readBuffer_.readAsString(length - 2);
0313 T_DEBUG("Connection closed: %d %s", closeCode, closeReason);
0314 }
0315 transport_->close();
0316 return false;
0317 case Opcode::Ping:
0318 pong();
0319 return readFrame();
0320 default:
0321 return true;
0322 }
0323 }
0324
0325 void resetHandshake() {
0326 connection_ = false;
0327 secWebSocketKey_ = false;
0328 secWebSocketVersion_ = false;
0329 upgrade_ = false;
0330 }
0331
0332 void sendBadRequest() {
0333 std::ostringstream h;
0334 h << "HTTP/1.1 400 Bad Request" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF << CRLF;
0335 std::string header = h.str();
0336 transport_->write(reinterpret_cast<const uint8_t*>(header.data()), static_cast<uint32_t>(header.length()));
0337 transport_->flush();
0338 transport_->close();
0339 }
0340
0341 void writeFrameHeader(Opcode opcode = Opcode::Continuation) {
0342 uint32_t headerSize = 1;
0343 uint32_t length = writeBuffer_.available_read();
0344 if (length < 126) {
0345 ++headerSize;
0346 } else if (length < 65536) {
0347 headerSize += 3;
0348 } else {
0349 headerSize += 9;
0350 }
0351
0352
0353 uint8_t* header = static_cast<uint8_t*>(alloca(headerSize));
0354 if (opcode == Opcode::Continuation) {
0355 opcode = binary ? Opcode::Binary : Opcode::Text;
0356 }
0357 header[0] = static_cast<uint8_t>(opcode) | 0x80;
0358 if (length < 126) {
0359 header[1] = static_cast<uint8_t>(length);
0360 } else if (length < 65536) {
0361 header[1] = 126;
0362 *reinterpret_cast<uint16_t*>(header + 2) = htons(length);
0363 } else {
0364 header[1] = 127;
0365 *reinterpret_cast<uint64_t*>(header + 2) = THRIFT_htonll(length);
0366 }
0367
0368 transport_->write(header, headerSize);
0369 }
0370
0371
0372 constexpr static const char* CRLF = "\r\n";
0373 std::string acceptKey_;
0374 bool connection_;
0375 bool secWebSocketKey_;
0376 bool secWebSocketVersion_;
0377 bool upgrade_;
0378 };
0379
0380
0381
0382
0383 class TBinaryWebSocketServerTransportFactory : public TTransportFactory {
0384 public:
0385 TBinaryWebSocketServerTransportFactory() = default;
0386
0387 ~TBinaryWebSocketServerTransportFactory() override = default;
0388
0389
0390
0391
0392 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
0393 return std::shared_ptr<TTransport>(new TWebSocketServer<true>(trans));
0394 }
0395 };
0396
0397
0398
0399
0400 class TTextWebSocketServerTransportFactory : public TTransportFactory {
0401 public:
0402 TTextWebSocketServerTransportFactory() = default;
0403
0404 ~TTextWebSocketServerTransportFactory() override = default;
0405
0406
0407
0408
0409 std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
0410 return std::shared_ptr<TTransport>(new TWebSocketServer<false>(trans));
0411 }
0412 };
0413 }
0414 }
0415 }
0416 #endif