Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-17 08:39:01

0001 //
0002 // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
0003 //
0004 // Distributed under the Boost Software License, Version 1.0. (See accompanying
0005 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
0006 //
0007 
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0010 
0011 #include <boost/mysql/any_address.hpp>
0012 #include <boost/mysql/error_code.hpp>
0013 #include <boost/mysql/string_view.hpp>
0014 
0015 #include <boost/mysql/detail/config.hpp>
0016 #include <boost/mysql/detail/connect_params_helpers.hpp>
0017 
0018 #include <boost/mysql/impl/internal/coroutine.hpp>
0019 #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
0020 
0021 #include <boost/asio/any_io_executor.hpp>
0022 #include <boost/asio/compose.hpp>
0023 #include <boost/asio/connect.hpp>
0024 #include <boost/asio/error.hpp>
0025 #include <boost/asio/ip/tcp.hpp>
0026 #include <boost/asio/local/stream_protocol.hpp>
0027 #include <boost/asio/post.hpp>
0028 #include <boost/asio/ssl/context.hpp>
0029 #include <boost/asio/ssl/stream.hpp>
0030 #include <boost/optional/optional.hpp>
0031 #include <boost/variant2/variant.hpp>
0032 
0033 #include <string>
0034 #include <utility>
0035 
0036 namespace boost {
0037 namespace mysql {
0038 namespace detail {
0039 
0040 // Asio defines a "string view parameter" to be either const std::string&,
0041 // std::experimental::string_view or std::string_view. Casting from the Boost
0042 // version doesn't work for std::experimental::string_view
0043 #if defined(BOOST_ASIO_HAS_STD_STRING_VIEW)
0044 inline std::string_view cast_asio_sv_param(string_view input) noexcept { return input; }
0045 #elif defined(BOOST_ASIO_HAS_STD_EXPERIMENTAL_STRING_VIEW)
0046 inline std::experimental::string_view cast_asio_sv_param(string_view input) noexcept
0047 {
0048     return {input.data(), input.size()};
0049 }
0050 #else
0051 inline std::string cast_asio_sv_param(string_view input) { return input; }
0052 #endif
0053 
0054 // Implements the EngineStream concept (see stream_adaptor)
0055 class variant_stream
0056 {
0057 public:
0058     variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : ex_(std::move(ex)), ssl_ctx_(ctx) {}
0059 
0060     bool supports_ssl() const { return true; }
0061 
0062     void set_endpoint(const void* value) { address_ = static_cast<const any_address*>(value); }
0063 
0064     // Executor
0065     using executor_type = asio::any_io_executor;
0066     executor_type get_executor() { return ex_; }
0067 
0068     // SSL
0069     void ssl_handshake(error_code& ec)
0070     {
0071         create_ssl_stream();
0072         ssl_->handshake(asio::ssl::stream_base::client, ec);
0073     }
0074 
0075     template <class CompletionToken>
0076     void async_ssl_handshake(CompletionToken&& token)
0077     {
0078         create_ssl_stream();
0079         ssl_->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
0080     }
0081 
0082     void ssl_shutdown(error_code& ec)
0083     {
0084         BOOST_ASSERT(ssl_.has_value());
0085         ssl_->shutdown(ec);
0086     }
0087 
0088     template <class CompletionToken>
0089     void async_ssl_shutdown(CompletionToken&& token)
0090     {
0091         BOOST_ASSERT(ssl_.has_value());
0092         ssl_->async_shutdown(std::forward<CompletionToken>(token));
0093     }
0094 
0095     // Reading
0096     std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
0097     {
0098         if (use_ssl)
0099         {
0100             BOOST_ASSERT(ssl_.has_value());
0101             return ssl_->read_some(buff, ec);
0102         }
0103         else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0104         {
0105             return tcp_sock->sock.read_some(buff, ec);
0106         }
0107 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0108         else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0109         {
0110             return unix_sock->read_some(buff, ec);
0111         }
0112 #endif
0113         else
0114         {
0115             BOOST_ASSERT(false);
0116             return 0u;
0117         }
0118     }
0119 
0120     template <class CompletionToken>
0121     void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
0122     {
0123         if (use_ssl)
0124         {
0125             BOOST_ASSERT(ssl_.has_value());
0126             ssl_->async_read_some(buff, std::forward<CompletionToken>(token));
0127         }
0128         else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0129         {
0130             tcp_sock->sock.async_read_some(buff, std::forward<CompletionToken>(token));
0131         }
0132 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0133         else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0134         {
0135             unix_sock->async_read_some(buff, std::forward<CompletionToken>(token));
0136         }
0137 #endif
0138         else
0139         {
0140             BOOST_ASSERT(false);
0141         }
0142     }
0143 
0144     // Writing
0145     std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
0146     {
0147         if (use_ssl)
0148         {
0149             BOOST_ASSERT(ssl_.has_value());
0150             return ssl_->write_some(buff, ec);
0151         }
0152         else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0153         {
0154             return tcp_sock->sock.write_some(buff, ec);
0155         }
0156 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0157         else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0158         {
0159             return unix_sock->write_some(buff, ec);
0160         }
0161 #endif
0162         else
0163         {
0164             BOOST_ASSERT(false);
0165             return 0u;
0166         }
0167     }
0168 
0169     template <class CompletionToken>
0170     void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
0171     {
0172         if (use_ssl)
0173         {
0174             BOOST_ASSERT(ssl_.has_value());
0175             return ssl_->async_write_some(buff, std::forward<CompletionToken>(token));
0176         }
0177         else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0178         {
0179             return tcp_sock->sock.async_write_some(buff, std::forward<CompletionToken>(token));
0180         }
0181 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0182         else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0183         {
0184             return unix_sock->async_write_some(buff, std::forward<CompletionToken>(token));
0185         }
0186 #endif
0187         else
0188         {
0189             BOOST_ASSERT(false);
0190         }
0191     }
0192 
0193     // Connect and close
0194     void connect(error_code& ec)
0195     {
0196         ec = setup_stream();
0197         if (ec)
0198             return;
0199 
0200         if (address_->type() == address_type::host_and_port)
0201         {
0202             // Resolve endpoints
0203             auto& tcp_sock = variant2::unsafe_get<1>(sock_);
0204             auto endpoints = tcp_sock.resolv.resolve(
0205                 cast_asio_sv_param(address_->hostname()),
0206                 std::to_string(address_->port()),
0207                 ec
0208             );
0209             if (ec)
0210                 return;
0211 
0212             // Connect stream
0213             asio::connect(tcp_sock.sock, std::move(endpoints), ec);
0214             if (ec)
0215                 return;
0216 
0217             // Disable Naggle's algorithm
0218             set_tcp_nodelay();
0219         }
0220 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0221         else
0222         {
0223             BOOST_ASSERT(address_->type() == address_type::unix_path);
0224 
0225             // Just connect the stream
0226             auto& unix_sock = variant2::unsafe_get<2>(sock_);
0227             unix_sock.connect(cast_asio_sv_param(address_->unix_socket_path()), ec);
0228         }
0229 #endif
0230     }
0231 
0232     template <class CompletionToken>
0233     void async_connect(CompletionToken&& token)
0234     {
0235         asio::async_compose<CompletionToken, void(error_code)>(connect_op(*this), token, ex_);
0236     }
0237 
0238     void close(error_code& ec)
0239     {
0240         if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0241         {
0242             tcp_sock->sock.close(ec);
0243         }
0244 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0245         else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0246         {
0247             unix_sock->close(ec);
0248         }
0249 #endif
0250     }
0251 
0252     // Exposed for testing
0253     const asio::ip::tcp::socket& tcp_socket() const { return variant2::get<socket_and_resolver>(sock_).sock; }
0254 
0255 private:
0256     struct socket_and_resolver
0257     {
0258         asio::ip::tcp::socket sock;
0259         asio::ip::tcp::resolver resolv;
0260 
0261         socket_and_resolver(asio::any_io_executor ex) : sock(ex), resolv(std::move(ex)) {}
0262     };
0263 
0264 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0265     using unix_socket = asio::local::stream_protocol::socket;
0266 #endif
0267 
0268     const any_address* address_{};
0269     asio::any_io_executor ex_;
0270     variant2::variant<
0271         variant2::monostate,
0272         socket_and_resolver
0273 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0274         ,
0275         unix_socket
0276 #endif
0277         >
0278         sock_;
0279     ssl_context_with_default ssl_ctx_;
0280     boost::optional<asio::ssl::stream<asio::ip::tcp::socket&>> ssl_;
0281 
0282     error_code setup_stream()
0283     {
0284         if (address_->type() == address_type::host_and_port)
0285         {
0286             // Clean up any previous state
0287             sock_.emplace<socket_and_resolver>(ex_);
0288         }
0289 
0290         else if (address_->type() == address_type::unix_path)
0291         {
0292 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0293             // Clean up any previous state
0294             sock_.emplace<unix_socket>(ex_);
0295 #else
0296             return asio::error::operation_not_supported;
0297 #endif
0298         }
0299 
0300         return error_code();
0301     }
0302 
0303     void set_tcp_nodelay() { variant2::unsafe_get<1u>(sock_).sock.set_option(asio::ip::tcp::no_delay(true)); }
0304 
0305     void create_ssl_stream()
0306     {
0307         // The stream object must be re-created even if it already exists, since
0308         // once used for a connection (anytime after ssl::stream::handshake is called),
0309         // it can't be re-used for any subsequent connections
0310         BOOST_ASSERT(variant2::holds_alternative<socket_and_resolver>(sock_));
0311         ssl_.emplace(variant2::unsafe_get<1>(sock_).sock, ssl_ctx_.get());
0312     }
0313 
0314     struct connect_op
0315     {
0316         int resume_point_{0};
0317         variant_stream& this_obj_;
0318         error_code stored_ec_;
0319 
0320         connect_op(variant_stream& this_obj) noexcept : this_obj_(this_obj) {}
0321 
0322         template <class Self>
0323         void operator()(Self& self, error_code ec = {}, asio::ip::tcp::resolver::results_type endpoints = {})
0324         {
0325             if (ec)
0326             {
0327                 self.complete(ec);
0328                 return;
0329             }
0330 
0331             switch (resume_point_)
0332             {
0333             case 0:
0334 
0335                 // Setup stream
0336                 stored_ec_ = this_obj_.setup_stream();
0337                 if (stored_ec_)
0338                 {
0339                     BOOST_MYSQL_YIELD(resume_point_, 1, asio::post(this_obj_.ex_, std::move(self)))
0340                     self.complete(stored_ec_);
0341                     return;
0342                 }
0343 
0344                 if (this_obj_.address_->type() == address_type::host_and_port)
0345                 {
0346                     // Resolve endpoints
0347                     BOOST_MYSQL_YIELD(
0348                         resume_point_,
0349                         2,
0350                         variant2::unsafe_get<1>(this_obj_.sock_)
0351                             .resolv.async_resolve(
0352                                 cast_asio_sv_param(this_obj_.address_->hostname()),
0353                                 std::to_string(this_obj_.address_->port()),
0354                                 std::move(self)
0355                             )
0356                     )
0357 
0358                     // Connect stream
0359                     BOOST_MYSQL_YIELD(
0360                         resume_point_,
0361                         3,
0362                         asio::async_connect(
0363                             variant2::unsafe_get<1>(this_obj_.sock_).sock,
0364                             std::move(endpoints),
0365                             std::move(self)
0366                         )
0367                     )
0368 
0369                     // The final handler requires a void(error_code, tcp::endpoint signature),
0370                     // which this function can't implement. See operator() overload below.
0371                 }
0372 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0373                 else
0374                 {
0375                     BOOST_ASSERT(this_obj_.address_->type() == address_type::unix_path);
0376 
0377                     // Just connect the stream
0378                     BOOST_MYSQL_YIELD(
0379                         resume_point_,
0380                         4,
0381                         variant2::unsafe_get<2>(this_obj_.sock_)
0382                             .async_connect(
0383                                 cast_asio_sv_param(this_obj_.address_->unix_socket_path()),
0384                                 std::move(self)
0385                             )
0386                     )
0387 
0388                     self.complete(error_code());
0389                 }
0390 #endif
0391             }
0392         }
0393 
0394         template <class Self>
0395         void operator()(Self& self, error_code ec, asio::ip::tcp::endpoint)
0396         {
0397             if (!ec)
0398             {
0399                 // Disable Naggle's algorithm
0400                 this_obj_.set_tcp_nodelay();
0401             }
0402             self.complete(ec);
0403         }
0404     };
0405 };
0406 
0407 }  // namespace detail
0408 }  // namespace mysql
0409 }  // namespace boost
0410 
0411 #endif