Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 08:52:42

0001 //
0002 // Copyright (c) 2019-2025 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
0003 //
0004 // Distributed under the Boost Software License, Version 1.0. (See accompanying
0005 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
0006 //
0007 
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0010 
0011 #include <boost/mysql/any_address.hpp>
0012 #include <boost/mysql/error_code.hpp>
0013 #include <boost/mysql/string_view.hpp>
0014 
0015 #include <boost/mysql/detail/access.hpp>
0016 
0017 #include <boost/mysql/impl/internal/coroutine.hpp>
0018 #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
0019 
0020 #include <boost/asio/any_io_executor.hpp>
0021 #include <boost/asio/associated_immediate_executor.hpp>
0022 #include <boost/asio/cancellation_type.hpp>
0023 #include <boost/asio/compose.hpp>
0024 #include <boost/asio/connect.hpp>
0025 #include <boost/asio/dispatch.hpp>
0026 #include <boost/asio/error.hpp>
0027 #include <boost/asio/generic/stream_protocol.hpp>
0028 #include <boost/asio/ip/tcp.hpp>
0029 #include <boost/asio/local/stream_protocol.hpp>
0030 #include <boost/asio/ssl/context.hpp>
0031 #include <boost/asio/ssl/stream.hpp>
0032 #include <boost/core/span.hpp>
0033 #include <boost/optional/optional.hpp>
0034 
0035 #include <memory>
0036 #include <string>
0037 #include <utility>
0038 #include <vector>
0039 
0040 namespace boost {
0041 namespace mysql {
0042 namespace detail {
0043 
0044 struct variant_stream_state
0045 {
0046     asio::generic::stream_protocol::socket sock;
0047     ssl_context_with_default ssl_ctx;
0048     boost::optional<asio::ssl::stream<asio::generic::stream_protocol::socket&>> ssl;
0049 
0050     variant_stream_state(asio::any_io_executor ex, asio::ssl::context* ctx) : sock(ex), ssl_ctx(ctx) {}
0051 
0052     asio::ssl::stream<asio::generic::stream_protocol::socket&>& create_ssl_stream()
0053     {
0054         // The stream object must be re-created even if it already exists, since
0055         // once used for a connection (anytime after ssl::stream::handshake is called),
0056         // it can't be re-used for any subsequent connections
0057         ssl.emplace(sock, ssl_ctx.get());
0058         return *ssl;
0059     }
0060 };
0061 
// Kind of step that variant_stream_connect_algo::resume asks its driver
// (the sync or the async connect implementation) to perform next.
enum class vsconnect_action_type
{
    none = 0,       // the algorithm is finished; its result is stored in data.err
    resolve = 1,    // resolve the hostname/service referenced by data.resolve
    connect = 2,    // connect the socket to one of the endpoints in data.connect
    immediate = 3,  // we'll be performing an immediate completion
};
0069 
0070 struct vsconnect_action
0071 {
0072     vsconnect_action_type type;
0073 
0074     union data_t
0075     {
0076         error_code err;
0077         struct resolve_t
0078         {
0079             const std::string* hostname;
0080             const std::string* service;
0081         } resolve;
0082         span<const asio::generic::stream_protocol::endpoint> connect;
0083 
0084         data_t(error_code v) noexcept : err(v) {}
0085         data_t(resolve_t v) noexcept : resolve(v) {}
0086         data_t(span<const asio::generic::stream_protocol::endpoint> v) noexcept : connect(v) {}
0087     } data;
0088 
0089     struct immediate_tag
0090     {
0091     };
0092 
0093     vsconnect_action(immediate_tag) noexcept : type(vsconnect_action_type::immediate), data(error_code()) {}
0094     vsconnect_action(error_code v = {}) noexcept : type(vsconnect_action_type::none), data(v) {}
0095     vsconnect_action(data_t::resolve_t v) noexcept : type(vsconnect_action_type::resolve), data(v) {}
0096     vsconnect_action(span<const asio::generic::stream_protocol::endpoint> v) noexcept
0097         : type(vsconnect_action_type::connect), data(v)
0098     {
0099     }
0100 };
0101 
0102 class variant_stream_connect_algo
0103 {
0104     variant_stream_state* st_;
0105     const any_address* addr_;
0106     boost::optional<asio::ip::tcp::resolver> resolv_;
0107     std::vector<asio::generic::stream_protocol::endpoint> endpoints_;
0108     std::string service_;
0109     int resume_point_{0};
0110 
0111     const std::string& address() const { return access::get_impl(*addr_).address; }
0112     asio::any_io_executor get_executor() const { return st_->sock.get_executor(); }
0113 
0114 public:
0115     variant_stream_connect_algo(variant_stream_state& st, const any_address& addr) : st_(&st), addr_(&addr) {}
0116 
0117     asio::ip::tcp::resolver& resolver() { return *resolv_; }
0118     asio::generic::stream_protocol::socket& socket() { return st_->sock; }
0119 
0120     vsconnect_action resume(
0121         error_code ec,
0122         const asio::ip::tcp::resolver::results_type* resolver_results,
0123         asio::cancellation_type_t cancel_state
0124     )
0125     {
0126         // All errors are considered fatal
0127         if (ec)
0128             return ec;
0129 
0130         // If we received a terminal cancellation signal, exit with the appropriate error code.
0131         // In composed async operations, if the cancellation arrives after an intermediate operation
0132         // has completed, but before the handler is called, the operation finishes successfully,
0133         // but the cancellation state is set. This check covers this case.
0134         if (!!(cancel_state & asio::cancellation_type_t::terminal))
0135             return error_code(asio::error::operation_aborted);
0136 
0137         switch (resume_point_)
0138         {
0139         case 0:
0140 
0141             // Clean up any previous state
0142             st_->sock = asio::generic::stream_protocol::socket(get_executor());
0143 
0144             // Set up the endpoints vector
0145             if (addr_->type() == address_type::host_and_port)
0146             {
0147                 // Emplace the resolver
0148                 resolv_.emplace(get_executor());
0149 
0150                 // Resolve the endpoints
0151                 service_ = std::to_string(addr_->port());
0152                 BOOST_MYSQL_YIELD(resume_point_, 1, vsconnect_action({&address(), &service_}));
0153 
0154                 // Convert them to a vector of type-erased endpoints.
0155                 // This workarounds https://github.com/chriskohlhoff/asio/issues/1502
0156                 // and makes connect() uniform for TCP and UNIX
0157                 endpoints_.reserve(resolver_results->size());
0158                 for (const auto& entry : *resolver_results)
0159                 {
0160                     endpoints_.push_back(entry.endpoint());
0161                 }
0162             }
0163             else
0164             {
0165                 BOOST_ASSERT(addr_->type() == address_type::unix_path);
0166 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0167                 endpoints_.push_back(asio::local::stream_protocol::endpoint(address()));
0168 #else
0169                 BOOST_MYSQL_YIELD(resume_point_, 3, vsconnect_action::immediate_tag{});
0170                 return vsconnect_action(asio::error::operation_not_supported);
0171 #endif
0172             }
0173 
0174             // Actually connect
0175             BOOST_MYSQL_YIELD(resume_point_, 2, vsconnect_action{endpoints_});
0176 
0177             // If we're doing TCP, disable Naggle's algorithm
0178             if (addr_->type() == address_type::host_and_port)
0179             {
0180                 st_->sock.set_option(asio::ip::tcp::no_delay(true));
0181             }
0182 
0183             // Done
0184         }
0185 
0186         return {};
0187     }
0188 };
0189 
0190 // Implements the EngineStream concept (see stream_adaptor)
0191 class variant_stream
0192 {
0193 public:
0194     variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : st_(std::move(ex), ctx) {}
0195 
0196     bool supports_ssl() const { return true; }
0197 
0198     // Executor
0199     using executor_type = asio::any_io_executor;
0200     executor_type get_executor() { return st_.sock.get_executor(); }
0201 
0202     // SSL
0203     void ssl_handshake(error_code& ec)
0204     {
0205         st_.create_ssl_stream().handshake(asio::ssl::stream_base::client, ec);
0206     }
0207 
0208     template <class CompletionToken>
0209     void async_ssl_handshake(CompletionToken&& token)
0210     {
0211         st_.create_ssl_stream();
0212         st_.ssl->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
0213     }
0214 
0215     void ssl_shutdown(error_code& ec)
0216     {
0217         BOOST_ASSERT(st_.ssl.has_value());
0218         st_.ssl->shutdown(ec);
0219     }
0220 
0221     template <class CompletionToken>
0222     void async_ssl_shutdown(CompletionToken&& token)
0223     {
0224         BOOST_ASSERT(st_.ssl.has_value());
0225         st_.ssl->async_shutdown(std::forward<CompletionToken>(token));
0226     }
0227 
0228     // Reading
0229     std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
0230     {
0231         if (use_ssl)
0232         {
0233             BOOST_ASSERT(st_.ssl.has_value());
0234             return st_.ssl->read_some(buff, ec);
0235         }
0236         else
0237         {
0238             return st_.sock.read_some(buff, ec);
0239         }
0240     }
0241 
0242     template <class CompletionToken>
0243     void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
0244     {
0245         if (use_ssl)
0246         {
0247             BOOST_ASSERT(st_.ssl.has_value());
0248             st_.ssl->async_read_some(buff, std::forward<CompletionToken>(token));
0249         }
0250         else
0251         {
0252             st_.sock.async_read_some(buff, std::forward<CompletionToken>(token));
0253         }
0254     }
0255 
0256     // Writing
0257     std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
0258     {
0259         if (use_ssl)
0260         {
0261             BOOST_ASSERT(st_.ssl.has_value());
0262             return st_.ssl->write_some(buff, ec);
0263         }
0264         else
0265         {
0266             return st_.sock.write_some(buff, ec);
0267         }
0268     }
0269 
0270     template <class CompletionToken>
0271     void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
0272     {
0273         if (use_ssl)
0274         {
0275             BOOST_ASSERT(st_.ssl.has_value());
0276             return st_.ssl->async_write_some(buff, std::forward<CompletionToken>(token));
0277         }
0278         else
0279         {
0280             return st_.sock.async_write_some(buff, std::forward<CompletionToken>(token));
0281         }
0282     }
0283 
0284     // Connect and close
0285     void connect(const void* server_address, error_code& output_ec)
0286     {
0287         // Setup
0288         variant_stream_connect_algo algo(st_, *static_cast<const any_address*>(server_address));
0289         error_code ec;
0290         asio::ip::tcp::resolver::results_type resolver_results;
0291 
0292         // Run until complete
0293         while (true)
0294         {
0295             // The sync algorithm doesn't support cancellation
0296             auto act = algo.resume(ec, &resolver_results, asio::cancellation_type_t::none);
0297             switch (act.type)
0298             {
0299             case vsconnect_action_type::connect: asio::connect(st_.sock, act.data.connect, ec); break;
0300             case vsconnect_action_type::resolve:
0301                 resolver_results = algo.resolver()
0302                                        .resolve(*act.data.resolve.hostname, *act.data.resolve.service, ec);
0303                 break;
0304             case vsconnect_action_type::immediate: break;  // has effect only for async
0305             case vsconnect_action_type::none: output_ec = act.data.err; return;
0306             default: BOOST_ASSERT(false);  // LCOV_EXCL_LINE
0307             }
0308         }
0309     }
0310 
0311     template <class CompletionToken>
0312     void async_connect(const void* server_address, CompletionToken&& token)
0313     {
0314         asio::async_compose<CompletionToken, void(error_code)>(
0315             connect_op(*this, *static_cast<const any_address*>(server_address)),
0316             token,
0317             get_executor()
0318         );
0319     }
0320 
0321     void close(error_code& ec)
0322     {
0323         st_.sock.shutdown(asio::generic::stream_protocol::socket::shutdown_both, ec);
0324         st_.sock.close(ec);
0325     }
0326 
0327     // Exposed for testing
0328     const asio::generic::stream_protocol::socket& socket() const { return st_.sock; }
0329 
0330 private:
0331     variant_stream_state st_;
0332 
0333     struct connect_op
0334     {
0335         std::unique_ptr<variant_stream_connect_algo> algo_;
0336 
0337         connect_op(variant_stream& this_obj, const any_address& server_address)
0338             : algo_(new variant_stream_connect_algo(this_obj.st_, server_address))
0339         {
0340         }
0341 
0342         template <class Self>
0343         void operator()(
0344             Self& self,
0345             error_code ec = {},
0346             const asio::ip::tcp::resolver::results_type& resolver_results = {}
0347         )
0348         {
0349             auto act = algo_->resume(ec, &resolver_results, self.cancelled());
0350             switch (act.type)
0351             {
0352             case vsconnect_action_type::connect:
0353                 asio::async_connect(algo_->socket(), act.data.connect, std::move(self));
0354                 break;
0355             case vsconnect_action_type::resolve:
0356                 algo_->resolver()
0357                     .async_resolve(*act.data.resolve.hostname, *act.data.resolve.service, std::move(self));
0358                 break;
0359             case vsconnect_action_type::immediate:
0360                 asio::dispatch(
0361                     asio::get_associated_immediate_executor(self, self.get_io_executor()),
0362                     std::move(self)
0363                 );
0364                 break;
0365             case vsconnect_action_type::none:
0366                 algo_.reset();
0367                 self.complete(act.data.err);
0368                 break;
0369             default: BOOST_ASSERT(false);  // LCOV_EXCL_LINE
0370             }
0371         }
0372 
0373         // Signature for range connect
0374         template <class Self>
0375         void operator()(Self& self, error_code ec, asio::generic::stream_protocol::endpoint)
0376         {
0377             (*this)(self, ec, asio::ip::tcp::resolver::results_type{});
0378         }
0379     };
0380 };
0381 
0382 }  // namespace detail
0383 }  // namespace mysql
0384 }  // namespace boost
0385 
0386 #endif