Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2026-04-17 08:35:04

0001 /*
0002  * Licensed to the Apache Software Foundation (ASF) under one
0003  * or more contributor license agreements. See the NOTICE file
0004  * distributed with this work for additional information
0005  * regarding copyright ownership. The ASF licenses this file
0006  * to you under the Apache License, Version 2.0 (the
0007  * "License"); you may not use this file except in compliance
0008  * with the License. You may obtain a copy of the License at
0009  *
0010  *   http://www.apache.org/licenses/LICENSE-2.0
0011  *
0012  * Unless required by applicable law or agreed to in writing,
0013  * software distributed under the License is distributed on an
0014  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
0015  * KIND, either express or implied. See the License for the
0016  * specific language governing permissions and limitations
0017  * under the License.
0018  */
0019 
0020 #ifndef _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_
0021 #define _THRIFT_TRANSPORT_TWEBSOCKETSERVER_H_ 1
0022 
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
0026 
0027 #include <openssl/sha.h>
0028 
0029 #include <thrift/config.h>
0030 #include <thrift/protocol/TProtocol.h>
0031 #include <thrift/transport/TSocket.h>
0032 #include <thrift/transport/THttpServer.h>
0033 #if defined(_MSC_VER) || defined(__MINGW32__)
0034 #include <Shlwapi.h>
0035 #define THRIFT_strncasecmp(str1, str2, len) _strnicmp(str1, str2, len)
0036 #define THRIFT_strcasestr(haystack, needle) StrStrIA(haystack, needle)
0037 #else
0038 #define THRIFT_strncasecmp(str1, str2, len) strncasecmp(str1, str2, len)
0039 #define THRIFT_strcasestr(haystack, needle) strcasestr(haystack, needle)
0040 #endif
0041 #if defined(__CYGWIN__)
0042 #include <alloca.h>
0043 #endif
0044 
0045 using std::string;
0046 
0047 namespace apache {
0048 namespace thrift {
0049 namespace transport {
0050 
// Encodes `length` bytes starting at `data` as text; defined outside this header.
// TWebSocketServer uses it to turn the 20-byte SHA-1 digest of the client key into
// the value sent in the "Sec-WebSocket-Accept" response header.
// NOTE(review): presumably standard padded Base64 (RFC 4648), as the WebSocket
// handshake requires — confirm against the definition.
std::string base64Encode(unsigned char* data, int length);
0052 
0053 template <bool binary>
0054 class TWebSocketServer : public THttpServer {
0055 public:
0056   TWebSocketServer(std::shared_ptr<TTransport> transport, std::shared_ptr<TConfiguration> config = nullptr)
0057     : THttpServer(transport, config) {
0058       resetHandshake();
0059   }
0060 
0061   ~TWebSocketServer() override = default;
0062 
0063   uint32_t readAll_virt(uint8_t* buf, uint32_t len) override {
0064     // If we do not have a good handshake, the client will attempt one.
0065     if (!handshakeComplete()) {
0066       resetHandshake();
0067       THttpServer::read(buf, len);
0068       // If we did not get everything we expected, the handshake failed
0069       // and we need to send a 400 response back.
0070       if (!handshakeComplete()) {
0071         sendBadRequest();
0072         return 0;
0073       }
0074       // Otherwise, send back the 101 response.
0075       THttpServer::flush();
0076     }
0077 
0078     uint32_t want = len;
0079     auto have = readBuffer_.available_read();
0080 
0081     // If we have some data in the buffer, copy it out and return it.
0082     // We have to return it without attempting to read more, since we aren't
0083     // guaranteed that the underlying transport actually has more data, so
0084     // attempting to read from it could block.
0085     if (have > 0 && have >= want) {
0086       return readBuffer_.read(buf, want);
0087     }
0088 
0089     // Read another frame.
0090     if (!readFrame()) {
0091       // EOF.  No frame available.
0092       return 0;
0093     }
0094 
0095     // Hand over whatever we have.
0096     uint32_t give = (std::min)(want, readBuffer_.available_read());
0097     return readBuffer_.read(buf, give);
0098   }
0099 
0100   void flush() override {
0101     resetConsumedMessageSize();
0102     writeFrameHeader();
0103     uint8_t* buffer;
0104     uint32_t length;
0105     writeBuffer_.getBuffer(&buffer, &length);
0106     transport_->write(buffer, length);
0107     transport_->flush();
0108     writeBuffer_.resetBuffer();
0109   }
0110 
0111 protected:
0112   std::string getHeader(uint32_t len) override {
0113     THRIFT_UNUSED_VARIABLE(len);
0114     std::ostringstream h;
0115     h << "HTTP/1.1 101 Switching Protocols" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF
0116       << "Upgrade: websocket" << CRLF << "Connection: Upgrade" << CRLF
0117       << "Sec-WebSocket-Accept: " << acceptKey_ << CRLF << CRLF;
0118     return h.str();
0119   }
0120 
0121   void parseHeader(char* header) override {
0122     char* colon = strchr(header, ':');
0123     if (colon == nullptr) {
0124       return;
0125     }
0126     size_t sz = colon - header;
0127     char* value = colon + 1;
0128 
0129     if (THRIFT_strncasecmp(header, "Upgrade", sz) == 0) {
0130       if (THRIFT_strcasestr(value, "websocket") != nullptr) {
0131         upgrade_ = true;
0132       }
0133     } else if (THRIFT_strncasecmp(header, "Connection", sz) == 0) {
0134       if (THRIFT_strcasestr(value, "Upgrade") != nullptr) {
0135         connection_ = true;
0136       }
0137     } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Key", sz) == 0) {
0138       std::string toHash = value + 1;
0139       toHash += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
0140       unsigned char hash[20];
0141       SHA1((const unsigned char*)toHash.c_str(), toHash.length(), hash);
0142       acceptKey_ = base64Encode(hash, 20);
0143       secWebSocketKey_ = true;
0144     } else if (THRIFT_strncasecmp(header, "Sec-WebSocket-Version", sz) == 0) {
0145       if (THRIFT_strcasestr(value, "13") != nullptr) {
0146         secWebSocketVersion_ = true;
0147       }
0148     }
0149   }
0150 
0151   bool parseStatusLine(char* status) override {
0152     char* method = status;
0153 
0154     char* path = strchr(method, ' ');
0155     if (path == nullptr) {
0156       throw TTransportException(string("Bad Status: ") + status);
0157     }
0158 
0159     *path = '\0';
0160     while (*(++path) == ' ') {
0161     };
0162 
0163     char* http = strchr(path, ' ');
0164     if (http == nullptr) {
0165       throw TTransportException(string("Bad Status: ") + status);
0166     }
0167     *http = '\0';
0168 
0169     if (strcmp(method, "GET") == 0) {
0170       // GET method ok, looking for content.
0171       return true;
0172     }
0173     throw TTransportException(string("Bad Status (unsupported method): ") + status);
0174   }
0175 
0176 private:
0177   enum class CloseCode : uint16_t {
0178     NormalClosure = 1000,
0179     GoingAway = 1001,
0180     ProtocolError = 1002,
0181     UnsupportedDataType = 1003,
0182     NoStatusCode = 1005,
0183     AbnormalClosure = 1006,
0184     InvalidData = 1007,
0185     PolicyViolation = 1008,
0186     MessageTooBig = 1009,
0187     ExtensionExpected = 1010,
0188     UnexpectedError = 1011,
0189     NotSecure = 1015
0190   };
0191 
0192   enum class Opcode : uint8_t {
0193     Continuation = 0x0,
0194     Text = 0x1,
0195     Binary = 0x2,
0196     Close = 0x8,
0197     Ping = 0x9,
0198     Pong = 0xA
0199   };
0200 
0201   void failConnection(CloseCode reason) {
0202     writeFrameHeader(Opcode::Close);
0203     auto buffer = htons(static_cast<uint16_t>(reason));
0204     transport_->write(reinterpret_cast<const uint8_t*>(&buffer), 2);
0205     transport_->flush();
0206     transport_->close();
0207   }
0208 
0209   bool handshakeComplete() {
0210     return upgrade_ && connection_ && secWebSocketKey_ && secWebSocketVersion_;
0211   }
0212 
0213   void pong() {
0214     writeFrameHeader(Opcode::Pong);
0215     uint8_t* buffer;
0216     uint32_t size;
0217     readBuffer_.getBuffer(&buffer, &size);
0218     transport_->write(buffer, size);
0219     transport_->flush();
0220   }
0221 
0222   bool readFrame() {
0223     uint8_t headerBuffer[8];
0224 
0225     auto read = transport_->read(headerBuffer, 2);
0226     if (read < 2) {
0227       return false;
0228     }
0229     // Since Thrift has its own message end marker and we read frame by frame,
0230     // it doesn't really matter if the frame is marked as FIN.
0231     // Capture it only for debugging only.
0232     auto fin = (headerBuffer[0] & 0x80) != 0;
0233     THRIFT_UNUSED_VARIABLE(fin);
0234 
0235     // RSV1, RSV2, RSV3
0236     if ((headerBuffer[0] & 0x70) != 0) {
0237       failConnection(CloseCode::ProtocolError);
0238       throw TTransportException(TTransportException::CORRUPTED_DATA,
0239                                 "Reserved bits must be zeroes");
0240     }
0241 
0242     auto opcode = (Opcode)(headerBuffer[0] & 0x0F);
0243 
0244     // Mask
0245     if ((headerBuffer[1] & 0x80) == 0) {
0246       failConnection(CloseCode::ProtocolError);
0247       throw TTransportException(TTransportException::CORRUPTED_DATA,
0248                                 "Messages from the client must be masked");
0249     }
0250 
0251     // Read the length
0252     uint64_t payloadLength = headerBuffer[1] & 0x7F;
0253     if (payloadLength == 126) {
0254       read = transport_->read(headerBuffer, 2);
0255       if (read < 2) {
0256         return false;
0257       }
0258       payloadLength = ntohs(*reinterpret_cast<uint16_t*>(headerBuffer));
0259     } else if (payloadLength == 127) {
0260       read = transport_->read(headerBuffer, 8);
0261       if (read < 8) {
0262         return false;
0263       }
0264       payloadLength = THRIFT_ntohll(*reinterpret_cast<uint64_t*>(headerBuffer));
0265       if ((payloadLength & 0x8000000000000000) != 0) {
0266         failConnection(CloseCode::ProtocolError);
0267         throw TTransportException(
0268             TTransportException::CORRUPTED_DATA,
0269             "The most significant bit of the payload length must be zero");
0270       }
0271     }
0272 
0273     // size_t is smaller than a ulong on a 32-bit system
0274     if (payloadLength > UINT32_MAX) {
0275       failConnection(CloseCode::MessageTooBig);
0276       return false;
0277     }
0278 
0279     auto length = static_cast<uint32_t>(payloadLength);
0280 
0281     if (length > 0) {
0282       // Read the masking key
0283       read = transport_->read(headerBuffer, 4);
0284       if (read < 4) {
0285         return false;
0286       }
0287 
0288       readBuffer_.resetBuffer(length);
0289       uint8_t* buffer = readBuffer_.getWritePtr(length);
0290       read = transport_->read(buffer, length);
0291       readBuffer_.wroteBytes(read);
0292       if (read < length) {
0293         return false;
0294       }
0295       
0296       // Unmask the data
0297       for (size_t i = 0; i < length; i++) {
0298         buffer[i] ^= headerBuffer[i % 4];
0299       }
0300 
0301       T_DEBUG("FIN=%d, Opcode=%X, length=%d, payload=%s", fin, opcode, length,
0302               binary ? readBuffer_.toHexString() : cast(string) readBuffer_);
0303     }
0304 
0305     switch (opcode) {
0306     case Opcode::Close:
0307       if (length >= 2) {
0308         uint8_t buffer[2];
0309         readBuffer_.read(buffer, 2);
0310         CloseCode closeCode = static_cast<CloseCode>(ntohs(*reinterpret_cast<uint16_t*>(buffer)));
0311         THRIFT_UNUSED_VARIABLE(closeCode);
0312         string closeReason = readBuffer_.readAsString(length - 2);
0313         T_DEBUG("Connection closed: %d %s", closeCode, closeReason);
0314       }
0315       transport_->close();
0316       return false;
0317     case Opcode::Ping:
0318       pong();
0319       return readFrame();
0320     default:
0321       return true;
0322     }
0323   }
0324 
0325   void resetHandshake() {
0326     connection_ = false;
0327     secWebSocketKey_ = false;
0328     secWebSocketVersion_ = false;
0329     upgrade_ = false;
0330   }
0331 
0332   void sendBadRequest() {
0333     std::ostringstream h;
0334     h << "HTTP/1.1 400 Bad Request" << CRLF << "Server: Thrift/" << PACKAGE_VERSION << CRLF << CRLF;
0335     std::string header = h.str();
0336     transport_->write(reinterpret_cast<const uint8_t*>(header.data()), static_cast<uint32_t>(header.length()));
0337     transport_->flush();
0338     transport_->close();
0339   }
0340 
0341   void writeFrameHeader(Opcode opcode = Opcode::Continuation) {
0342     uint32_t headerSize = 1;
0343     uint32_t length = writeBuffer_.available_read();
0344     if (length < 126) {
0345       ++headerSize;
0346     } else if (length < 65536) {
0347       headerSize += 3;
0348     } else {
0349       headerSize += 9;
0350     }
0351     // The server does not mask the response
0352 
0353     uint8_t* header = static_cast<uint8_t*>(alloca(headerSize));
0354     if (opcode == Opcode::Continuation) {
0355       opcode = binary ? Opcode::Binary : Opcode::Text;
0356     }
0357     header[0] = static_cast<uint8_t>(opcode) | 0x80;
0358     if (length < 126) {
0359       header[1] = static_cast<uint8_t>(length);
0360     } else if (length < 65536) {
0361       header[1] = 126;
0362       *reinterpret_cast<uint16_t*>(header + 2) = htons(length);
0363     } else {
0364       header[1] = 127;
0365       *reinterpret_cast<uint64_t*>(header + 2) = THRIFT_htonll(length);
0366     }
0367 
0368     transport_->write(header, headerSize);
0369   }
0370 
0371   // Add constant here to avoid a linker error on Windows
0372   constexpr static const char* CRLF = "\r\n";
0373   std::string acceptKey_;
0374   bool connection_;
0375   bool secWebSocketKey_;
0376   bool secWebSocketVersion_;
0377   bool upgrade_;
0378 };
0379 
0380 /**
0381  * Wraps a transport into binary WebSocket protocol
0382  */
0383 class TBinaryWebSocketServerTransportFactory : public TTransportFactory {
0384 public:
0385   TBinaryWebSocketServerTransportFactory() = default;
0386 
0387   ~TBinaryWebSocketServerTransportFactory() override = default;
0388 
0389   /**
0390    * Wraps the transport into a buffered one.
0391    */
0392   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
0393     return std::shared_ptr<TTransport>(new TWebSocketServer<true>(trans));
0394   }
0395 };
0396 
0397 /**
0398  * Wraps a transport into text WebSocket protocol
0399  */
0400 class TTextWebSocketServerTransportFactory : public TTransportFactory {
0401 public:
0402   TTextWebSocketServerTransportFactory() = default;
0403 
0404   ~TTextWebSocketServerTransportFactory() override = default;
0405 
0406   /**
0407    * Wraps the transport into a buffered one.
0408    */
0409   std::shared_ptr<TTransport> getTransport(std::shared_ptr<TTransport> trans) override {
0410     return std::shared_ptr<TTransport>(new TWebSocketServer<false>(trans));
0411   }
0412 };
0413 } // namespace transport
0414 } // namespace thrift
0415 } // namespace apache
0416 #endif