File indexing completed on 2025-04-17 08:39:01
0001
0002
0003
0004
0005
0006
0007
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_VARIANT_STREAM_HPP
0010
0011 #include <boost/mysql/any_address.hpp>
0012 #include <boost/mysql/error_code.hpp>
0013 #include <boost/mysql/string_view.hpp>
0014
0015 #include <boost/mysql/detail/config.hpp>
0016 #include <boost/mysql/detail/connect_params_helpers.hpp>
0017
0018 #include <boost/mysql/impl/internal/coroutine.hpp>
0019 #include <boost/mysql/impl/internal/ssl_context_with_default.hpp>
0020
0021 #include <boost/asio/any_io_executor.hpp>
0022 #include <boost/asio/compose.hpp>
0023 #include <boost/asio/connect.hpp>
0024 #include <boost/asio/error.hpp>
0025 #include <boost/asio/ip/tcp.hpp>
0026 #include <boost/asio/local/stream_protocol.hpp>
0027 #include <boost/asio/post.hpp>
0028 #include <boost/asio/ssl/context.hpp>
0029 #include <boost/asio/ssl/stream.hpp>
0030 #include <boost/optional/optional.hpp>
0031 #include <boost/variant2/variant.hpp>
0032
0033 #include <string>
0034 #include <utility>
0035
0036 namespace boost {
0037 namespace mysql {
0038 namespace detail {
0039
0040
0041
0042
0043 #if defined(BOOST_ASIO_HAS_STD_STRING_VIEW)
0044 inline std::string_view cast_asio_sv_param(string_view input) noexcept { return input; }
0045 #elif defined(BOOST_ASIO_HAS_STD_EXPERIMENTAL_STRING_VIEW)
0046 inline std::experimental::string_view cast_asio_sv_param(string_view input) noexcept
0047 {
0048 return {input.data(), input.size()};
0049 }
0050 #else
0051 inline std::string cast_asio_sv_param(string_view input) { return input; }
0052 #endif
0053
0054
0055 class variant_stream
0056 {
0057 public:
0058 variant_stream(asio::any_io_executor ex, asio::ssl::context* ctx) : ex_(std::move(ex)), ssl_ctx_(ctx) {}
0059
0060 bool supports_ssl() const { return true; }
0061
0062 void set_endpoint(const void* value) { address_ = static_cast<const any_address*>(value); }
0063
0064
0065 using executor_type = asio::any_io_executor;
0066 executor_type get_executor() { return ex_; }
0067
0068
0069 void ssl_handshake(error_code& ec)
0070 {
0071 create_ssl_stream();
0072 ssl_->handshake(asio::ssl::stream_base::client, ec);
0073 }
0074
0075 template <class CompletionToken>
0076 void async_ssl_handshake(CompletionToken&& token)
0077 {
0078 create_ssl_stream();
0079 ssl_->async_handshake(asio::ssl::stream_base::client, std::forward<CompletionToken>(token));
0080 }
0081
0082 void ssl_shutdown(error_code& ec)
0083 {
0084 BOOST_ASSERT(ssl_.has_value());
0085 ssl_->shutdown(ec);
0086 }
0087
0088 template <class CompletionToken>
0089 void async_ssl_shutdown(CompletionToken&& token)
0090 {
0091 BOOST_ASSERT(ssl_.has_value());
0092 ssl_->async_shutdown(std::forward<CompletionToken>(token));
0093 }
0094
0095
0096 std::size_t read_some(asio::mutable_buffer buff, bool use_ssl, error_code& ec)
0097 {
0098 if (use_ssl)
0099 {
0100 BOOST_ASSERT(ssl_.has_value());
0101 return ssl_->read_some(buff, ec);
0102 }
0103 else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0104 {
0105 return tcp_sock->sock.read_some(buff, ec);
0106 }
0107 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0108 else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0109 {
0110 return unix_sock->read_some(buff, ec);
0111 }
0112 #endif
0113 else
0114 {
0115 BOOST_ASSERT(false);
0116 return 0u;
0117 }
0118 }
0119
0120 template <class CompletionToken>
0121 void async_read_some(asio::mutable_buffer buff, bool use_ssl, CompletionToken&& token)
0122 {
0123 if (use_ssl)
0124 {
0125 BOOST_ASSERT(ssl_.has_value());
0126 ssl_->async_read_some(buff, std::forward<CompletionToken>(token));
0127 }
0128 else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0129 {
0130 tcp_sock->sock.async_read_some(buff, std::forward<CompletionToken>(token));
0131 }
0132 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0133 else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0134 {
0135 unix_sock->async_read_some(buff, std::forward<CompletionToken>(token));
0136 }
0137 #endif
0138 else
0139 {
0140 BOOST_ASSERT(false);
0141 }
0142 }
0143
0144
0145 std::size_t write_some(boost::asio::const_buffer buff, bool use_ssl, error_code& ec)
0146 {
0147 if (use_ssl)
0148 {
0149 BOOST_ASSERT(ssl_.has_value());
0150 return ssl_->write_some(buff, ec);
0151 }
0152 else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0153 {
0154 return tcp_sock->sock.write_some(buff, ec);
0155 }
0156 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0157 else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0158 {
0159 return unix_sock->write_some(buff, ec);
0160 }
0161 #endif
0162 else
0163 {
0164 BOOST_ASSERT(false);
0165 return 0u;
0166 }
0167 }
0168
0169 template <class CompletionToken>
0170 void async_write_some(boost::asio::const_buffer buff, bool use_ssl, CompletionToken&& token)
0171 {
0172 if (use_ssl)
0173 {
0174 BOOST_ASSERT(ssl_.has_value());
0175 return ssl_->async_write_some(buff, std::forward<CompletionToken>(token));
0176 }
0177 else if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0178 {
0179 return tcp_sock->sock.async_write_some(buff, std::forward<CompletionToken>(token));
0180 }
0181 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0182 else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0183 {
0184 return unix_sock->async_write_some(buff, std::forward<CompletionToken>(token));
0185 }
0186 #endif
0187 else
0188 {
0189 BOOST_ASSERT(false);
0190 }
0191 }
0192
0193
0194 void connect(error_code& ec)
0195 {
0196 ec = setup_stream();
0197 if (ec)
0198 return;
0199
0200 if (address_->type() == address_type::host_and_port)
0201 {
0202
0203 auto& tcp_sock = variant2::unsafe_get<1>(sock_);
0204 auto endpoints = tcp_sock.resolv.resolve(
0205 cast_asio_sv_param(address_->hostname()),
0206 std::to_string(address_->port()),
0207 ec
0208 );
0209 if (ec)
0210 return;
0211
0212
0213 asio::connect(tcp_sock.sock, std::move(endpoints), ec);
0214 if (ec)
0215 return;
0216
0217
0218 set_tcp_nodelay();
0219 }
0220 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0221 else
0222 {
0223 BOOST_ASSERT(address_->type() == address_type::unix_path);
0224
0225
0226 auto& unix_sock = variant2::unsafe_get<2>(sock_);
0227 unix_sock.connect(cast_asio_sv_param(address_->unix_socket_path()), ec);
0228 }
0229 #endif
0230 }
0231
0232 template <class CompletionToken>
0233 void async_connect(CompletionToken&& token)
0234 {
0235 asio::async_compose<CompletionToken, void(error_code)>(connect_op(*this), token, ex_);
0236 }
0237
0238 void close(error_code& ec)
0239 {
0240 if (auto* tcp_sock = variant2::get_if<socket_and_resolver>(&sock_))
0241 {
0242 tcp_sock->sock.close(ec);
0243 }
0244 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0245 else if (auto* unix_sock = variant2::get_if<unix_socket>(&sock_))
0246 {
0247 unix_sock->close(ec);
0248 }
0249 #endif
0250 }
0251
0252
0253 const asio::ip::tcp::socket& tcp_socket() const { return variant2::get<socket_and_resolver>(sock_).sock; }
0254
0255 private:
0256 struct socket_and_resolver
0257 {
0258 asio::ip::tcp::socket sock;
0259 asio::ip::tcp::resolver resolv;
0260
0261 socket_and_resolver(asio::any_io_executor ex) : sock(ex), resolv(std::move(ex)) {}
0262 };
0263
0264 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0265 using unix_socket = asio::local::stream_protocol::socket;
0266 #endif
0267
0268 const any_address* address_{};
0269 asio::any_io_executor ex_;
0270 variant2::variant<
0271 variant2::monostate,
0272 socket_and_resolver
0273 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0274 ,
0275 unix_socket
0276 #endif
0277 >
0278 sock_;
0279 ssl_context_with_default ssl_ctx_;
0280 boost::optional<asio::ssl::stream<asio::ip::tcp::socket&>> ssl_;
0281
0282 error_code setup_stream()
0283 {
0284 if (address_->type() == address_type::host_and_port)
0285 {
0286
0287 sock_.emplace<socket_and_resolver>(ex_);
0288 }
0289
0290 else if (address_->type() == address_type::unix_path)
0291 {
0292 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0293
0294 sock_.emplace<unix_socket>(ex_);
0295 #else
0296 return asio::error::operation_not_supported;
0297 #endif
0298 }
0299
0300 return error_code();
0301 }
0302
0303 void set_tcp_nodelay() { variant2::unsafe_get<1u>(sock_).sock.set_option(asio::ip::tcp::no_delay(true)); }
0304
0305 void create_ssl_stream()
0306 {
0307
0308
0309
0310 BOOST_ASSERT(variant2::holds_alternative<socket_and_resolver>(sock_));
0311 ssl_.emplace(variant2::unsafe_get<1>(sock_).sock, ssl_ctx_.get());
0312 }
0313
0314 struct connect_op
0315 {
0316 int resume_point_{0};
0317 variant_stream& this_obj_;
0318 error_code stored_ec_;
0319
0320 connect_op(variant_stream& this_obj) noexcept : this_obj_(this_obj) {}
0321
0322 template <class Self>
0323 void operator()(Self& self, error_code ec = {}, asio::ip::tcp::resolver::results_type endpoints = {})
0324 {
0325 if (ec)
0326 {
0327 self.complete(ec);
0328 return;
0329 }
0330
0331 switch (resume_point_)
0332 {
0333 case 0:
0334
0335
0336 stored_ec_ = this_obj_.setup_stream();
0337 if (stored_ec_)
0338 {
0339 BOOST_MYSQL_YIELD(resume_point_, 1, asio::post(this_obj_.ex_, std::move(self)))
0340 self.complete(stored_ec_);
0341 return;
0342 }
0343
0344 if (this_obj_.address_->type() == address_type::host_and_port)
0345 {
0346
0347 BOOST_MYSQL_YIELD(
0348 resume_point_,
0349 2,
0350 variant2::unsafe_get<1>(this_obj_.sock_)
0351 .resolv.async_resolve(
0352 cast_asio_sv_param(this_obj_.address_->hostname()),
0353 std::to_string(this_obj_.address_->port()),
0354 std::move(self)
0355 )
0356 )
0357
0358
0359 BOOST_MYSQL_YIELD(
0360 resume_point_,
0361 3,
0362 asio::async_connect(
0363 variant2::unsafe_get<1>(this_obj_.sock_).sock,
0364 std::move(endpoints),
0365 std::move(self)
0366 )
0367 )
0368
0369
0370
0371 }
0372 #ifdef BOOST_ASIO_HAS_LOCAL_SOCKETS
0373 else
0374 {
0375 BOOST_ASSERT(this_obj_.address_->type() == address_type::unix_path);
0376
0377
0378 BOOST_MYSQL_YIELD(
0379 resume_point_,
0380 4,
0381 variant2::unsafe_get<2>(this_obj_.sock_)
0382 .async_connect(
0383 cast_asio_sv_param(this_obj_.address_->unix_socket_path()),
0384 std::move(self)
0385 )
0386 )
0387
0388 self.complete(error_code());
0389 }
0390 #endif
0391 }
0392 }
0393
0394 template <class Self>
0395 void operator()(Self& self, error_code ec, asio::ip::tcp::endpoint)
0396 {
0397 if (!ec)
0398 {
0399
0400 this_obj_.set_tcp_nodelay();
0401 }
0402 self.complete(ec);
0403 }
0404 };
0405 };
0406
0407 }
0408 }
0409 }
0410
0411 #endif