Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-15 08:42:57

0001 //
0002 // Copyright (c) 2019-2025 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
0003 //
0004 // Distributed under the Boost Software License, Version 1.0. (See accompanying
0005 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
0006 //
0007 
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0010 
0011 #include <boost/mysql/character_set.hpp>
0012 #include <boost/mysql/diagnostics.hpp>
0013 #include <boost/mysql/error_code.hpp>
0014 #include <boost/mysql/handshake_params.hpp>
0015 #include <boost/mysql/mysql_collations.hpp>
0016 
0017 #include <boost/mysql/detail/algo_params.hpp>
0018 #include <boost/mysql/detail/next_action.hpp>
0019 #include <boost/mysql/detail/ok_view.hpp>
0020 
0021 #include <boost/mysql/impl/internal/auth/auth.hpp>
0022 #include <boost/mysql/impl/internal/coroutine.hpp>
0023 #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
0024 #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
0025 #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
0026 #include <boost/mysql/impl/internal/protocol/serialization.hpp>
0027 #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
0028 
0029 #include <cstdint>
0030 
0031 namespace boost {
0032 namespace mysql {
0033 namespace detail {
0034 
0035 inline capabilities conditional_capability(bool condition, std::uint32_t cap)
0036 {
0037     return capabilities(condition ? cap : 0);
0038 }
0039 
0040 inline error_code process_capabilities(
0041     const handshake_params& params,
0042     const server_hello& hello,
0043     capabilities& negotiated_caps,
0044     bool transport_supports_ssl
0045 )
0046 {
0047     auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
0048     capabilities server_caps = hello.server_capabilities;
0049     capabilities required_caps = mandatory_capabilities |
0050                                  conditional_capability(!params.database().empty(), CLIENT_CONNECT_WITH_DB) |
0051                                  conditional_capability(params.multi_queries(), CLIENT_MULTI_STATEMENTS) |
0052                                  conditional_capability(ssl == ssl_mode::require, CLIENT_SSL);
0053     if (required_caps.has(CLIENT_SSL) && !server_caps.has(CLIENT_SSL))
0054     {
0055         // This happens if the server doesn't have SSL configured. This special
0056         // error code helps users diagnosing their problem a lot (server_unsupported doesn't).
0057         return make_error_code(client_errc::server_doesnt_support_ssl);
0058     }
0059     else if (!server_caps.has_all(required_caps))
0060     {
0061         return make_error_code(client_errc::server_unsupported);
0062     }
0063     negotiated_caps = server_caps & (required_caps | optional_capabilities |
0064                                      conditional_capability(ssl == ssl_mode::enable, CLIENT_SSL));
0065     return error_code();
0066 }
0067 
0068 class handshake_algo
0069 {
0070     int resume_point_{0};
0071     handshake_params hparams_;
0072     auth_response auth_resp_;
0073     std::uint8_t sequence_number_{0};
0074     bool secure_channel_{false};
0075 
0076     // Attempts to map the collection_id to a character set. We try to be conservative
0077     // here, since servers will happily accept unknown collation IDs, silently defaulting
0078     // to the server's default character set (often latin1, which is not Unicode).
0079     static character_set collation_id_to_charset(std::uint16_t collation_id)
0080     {
0081         switch (collation_id)
0082         {
0083         case mysql_collations::utf8mb4_bin:
0084         case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
0085         case mysql_collations::ascii_general_ci:
0086         case mysql_collations::ascii_bin: return ascii_charset;
0087         default: return character_set{};
0088         }
0089     }
0090 
0091     // Once the handshake is processed, the capabilities are stored in the connection state
0092     bool use_ssl(const connection_state_data& st) const { return st.current_capabilities.has(CLIENT_SSL); }
0093 
0094     error_code process_handshake(
0095         connection_state_data& st,
0096         diagnostics& diag,
0097         span<const std::uint8_t> buffer
0098     )
0099     {
0100         // Deserialize server hello
0101         server_hello hello{};
0102         auto err = deserialize_server_hello(buffer, hello, diag);
0103         if (err)
0104             return err;
0105 
0106         // Check capabilities
0107         capabilities negotiated_caps;
0108         err = process_capabilities(hparams_, hello, negotiated_caps, st.tls_supported);
0109         if (err)
0110             return err;
0111 
0112         // Set capabilities, db flavor and connection ID
0113         st.current_capabilities = negotiated_caps;
0114         st.flavor = hello.server;
0115         st.connection_id = hello.connection_id;
0116 
0117         // If we're using SSL, mark the channel as secure
0118         secure_channel_ = secure_channel_ || use_ssl(st);
0119 
0120         // Compute auth response
0121         return compute_auth_response(
0122             hello.auth_plugin_name,
0123             hparams_.password(),
0124             hello.auth_plugin_data.to_span(),
0125             secure_channel_,
0126             auth_resp_
0127         );
0128     }
0129 
0130     // Response to that initial greeting
0131     ssl_request compose_ssl_request(const connection_state_data& st)
0132     {
0133         return ssl_request{
0134             st.current_capabilities,
0135             static_cast<std::uint32_t>(max_packet_size),
0136             hparams_.connection_collation(),
0137         };
0138     }
0139 
0140     login_request compose_login_request(const connection_state_data& st)
0141     {
0142         return login_request{
0143             st.current_capabilities,
0144             static_cast<std::uint32_t>(max_packet_size),
0145             hparams_.connection_collation(),
0146             hparams_.username(),
0147             auth_resp_.data,
0148             hparams_.database(),
0149             auth_resp_.plugin_name,
0150         };
0151     }
0152 
0153     // Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
0154     error_code process_auth_switch(auth_switch msg)
0155     {
0156         return compute_auth_response(
0157             msg.plugin_name,
0158             hparams_.password(),
0159             msg.auth_data,
0160             secure_channel_,
0161             auth_resp_
0162         );
0163     }
0164 
0165     error_code process_auth_more_data(span<const std::uint8_t> data)
0166     {
0167         return compute_auth_response(
0168             auth_resp_.plugin_name,
0169             hparams_.password(),
0170             data,
0171             secure_channel_,
0172             auth_resp_
0173         );
0174     }
0175 
0176     // Composes an auth_switch_response message with the contents of auth_resp_
0177     auth_switch_response compose_auth_switch_response() const
0178     {
0179         return auth_switch_response{auth_resp_.data};
0180     }
0181 
0182     void on_success(connection_state_data& st, const ok_view& ok)
0183     {
0184         st.status = connection_status::ready;
0185         st.backslash_escapes = ok.backslash_escapes();
0186         st.current_charset = collation_id_to_charset(hparams_.connection_collation());
0187     }
0188 
0189     error_code process_ok(connection_state_data& st)
0190     {
0191         ok_view res{};
0192         auto ec = deserialize_ok_packet(st.reader.message(), res);
0193         if (ec)
0194             return ec;
0195         on_success(st, res);
0196         return error_code();
0197     }
0198 
0199 public:
0200     handshake_algo(handshake_algo_params params) noexcept
0201         : hparams_(params.hparams), secure_channel_(params.secure_channel)
0202     {
0203     }
0204 
0205     next_action resume(connection_state_data& st, diagnostics& diag, error_code ec)
0206     {
0207         if (ec)
0208             return ec;
0209 
0210         handhake_server_response resp(error_code{});
0211 
0212         switch (resume_point_)
0213         {
0214         case 0:
0215             // Handshake wipes out state, so no state checks are performed.
0216             // Setup
0217             st.reset();
0218 
0219             // Read server greeting
0220             BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
0221 
0222             // Process server greeting
0223             ec = process_handshake(st, diag, st.reader.message());
0224             if (ec)
0225                 return ec;
0226 
0227             // SSL
0228             if (use_ssl(st))
0229             {
0230                 // Send SSL request
0231                 BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
0232 
0233                 // SSL handshake
0234                 BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
0235 
0236                 // Mark the connection as using ssl
0237                 st.tls_active = true;
0238             }
0239 
0240             // Compose and send handshake response
0241             BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
0242 
0243             // Auth message exchange
0244             while (true)
0245             {
0246                 // Receive response
0247                 BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
0248 
0249                 // Process it
0250                 resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
0251                 if (resp.type == handhake_server_response::type_t::ok)
0252                 {
0253                     // Auth success, quit
0254                     on_success(st, resp.data.ok);
0255                     return next_action();
0256                 }
0257                 else if (resp.type == handhake_server_response::type_t::error)
0258                 {
0259                     // Error, quit
0260                     return resp.data.err;
0261                 }
0262                 else if (resp.type == handhake_server_response::type_t::auth_switch)
0263                 {
0264                     // Compute response
0265                     ec = process_auth_switch(resp.data.auth_sw);
0266                     if (ec)
0267                         return ec;
0268 
0269                     BOOST_MYSQL_YIELD(
0270                         resume_point_,
0271                         6,
0272                         st.write(compose_auth_switch_response(), sequence_number_)
0273                     )
0274                 }
0275                 else if (resp.type == handhake_server_response::type_t::ok_follows)
0276                 {
0277                     // The next packet must be an OK packet. Read it
0278                     BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
0279 
0280                     // Process it
0281                     // Regardless of whether we succeeded or not, we're done
0282                     return process_ok(st);
0283                 }
0284                 else
0285                 {
0286                     BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
0287 
0288                     // Compute response
0289                     ec = process_auth_more_data(resp.data.more_data);
0290                     if (ec)
0291                         return ec;
0292 
0293                     // Write response
0294                     BOOST_MYSQL_YIELD(
0295                         resume_point_,
0296                         8,
0297                         st.write(compose_auth_switch_response(), sequence_number_)
0298                     )
0299                 }
0300             }
0301         }
0302 
0303         return next_action();
0304     }
0305 };
0306 
0307 }  // namespace detail
0308 }  // namespace mysql
0309 }  // namespace boost
0310 
0311 #endif