Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-04-19 08:33:52

0001 //
0002 // Copyright (c) 2019-2024 Ruben Perez Hidalgo (rubenperez038 at gmail dot com)
0003 //
0004 // Distributed under the Boost Software License, Version 1.0. (See accompanying
0005 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
0006 //
0007 
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0010 
0011 #include <boost/mysql/character_set.hpp>
0012 #include <boost/mysql/diagnostics.hpp>
0013 #include <boost/mysql/error_code.hpp>
0014 #include <boost/mysql/handshake_params.hpp>
0015 #include <boost/mysql/mysql_collations.hpp>
0016 
0017 #include <boost/mysql/detail/algo_params.hpp>
0018 #include <boost/mysql/detail/next_action.hpp>
0019 #include <boost/mysql/detail/ok_view.hpp>
0020 
0021 #include <boost/mysql/impl/internal/auth/auth.hpp>
0022 #include <boost/mysql/impl/internal/coroutine.hpp>
0023 #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
0024 #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
0025 #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
0026 #include <boost/mysql/impl/internal/protocol/serialization.hpp>
0027 #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
0028 
0029 #include <cstdint>
0030 
0031 namespace boost {
0032 namespace mysql {
0033 namespace detail {
0034 
0035 inline capabilities conditional_capability(bool condition, std::uint32_t cap)
0036 {
0037     return capabilities(condition ? cap : 0);
0038 }
0039 
0040 inline error_code process_capabilities(
0041     const handshake_params& params,
0042     const server_hello& hello,
0043     capabilities& negotiated_caps,
0044     bool transport_supports_ssl
0045 )
0046 {
0047     auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
0048     capabilities server_caps = hello.server_capabilities;
0049     capabilities required_caps = mandatory_capabilities |
0050                                  conditional_capability(!params.database().empty(), CLIENT_CONNECT_WITH_DB) |
0051                                  conditional_capability(params.multi_queries(), CLIENT_MULTI_STATEMENTS) |
0052                                  conditional_capability(ssl == ssl_mode::require, CLIENT_SSL);
0053     if (required_caps.has(CLIENT_SSL) && !server_caps.has(CLIENT_SSL))
0054     {
0055         // This happens if the server doesn't have SSL configured. This special
0056         // error code helps users diagnosing their problem a lot (server_unsupported doesn't).
0057         return make_error_code(client_errc::server_doesnt_support_ssl);
0058     }
0059     else if (!server_caps.has_all(required_caps))
0060     {
0061         return make_error_code(client_errc::server_unsupported);
0062     }
0063     negotiated_caps = server_caps & (required_caps | optional_capabilities |
0064                                      conditional_capability(ssl == ssl_mode::enable, CLIENT_SSL));
0065     return error_code();
0066 }
0067 
0068 class handshake_algo
0069 {
0070     int resume_point_{0};
0071     diagnostics* diag_;
0072     handshake_params hparams_;
0073     auth_response auth_resp_;
0074     std::uint8_t sequence_number_{0};
0075     bool secure_channel_{false};
0076 
0077     // Attempts to map the collection_id to a character set. We try to be conservative
0078     // here, since servers will happily accept unknown collation IDs, silently defaulting
0079     // to the server's default character set (often latin1, which is not Unicode).
0080     static character_set collation_id_to_charset(std::uint16_t collation_id)
0081     {
0082         switch (collation_id)
0083         {
0084         case mysql_collations::utf8mb4_bin:
0085         case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
0086         case mysql_collations::ascii_general_ci:
0087         case mysql_collations::ascii_bin: return ascii_charset;
0088         default: return character_set{};
0089         }
0090     }
0091 
0092     // Once the handshake is processed, the capabilities are stored in the connection state
0093     bool use_ssl(const connection_state_data& st) const { return st.current_capabilities.has(CLIENT_SSL); }
0094 
0095     error_code process_handshake(connection_state_data& st, span<const std::uint8_t> buffer)
0096     {
0097         // Deserialize server hello
0098         server_hello hello{};
0099         auto err = deserialize_server_hello(buffer, hello, *diag_);
0100         if (err)
0101             return err;
0102 
0103         // Check capabilities
0104         capabilities negotiated_caps;
0105         err = process_capabilities(hparams_, hello, negotiated_caps, st.supports_ssl());
0106         if (err)
0107             return err;
0108 
0109         // Set capabilities & db flavor
0110         st.current_capabilities = negotiated_caps;
0111         st.flavor = hello.server;
0112 
0113         // If we're using SSL, mark the channel as secure
0114         secure_channel_ = secure_channel_ || use_ssl(st);
0115 
0116         // Compute auth response
0117         return compute_auth_response(
0118             hello.auth_plugin_name,
0119             hparams_.password(),
0120             hello.auth_plugin_data.to_span(),
0121             secure_channel_,
0122             auth_resp_
0123         );
0124     }
0125 
0126     // Response to that initial greeting
0127     ssl_request compose_ssl_request(const connection_state_data& st)
0128     {
0129         return ssl_request{
0130             st.current_capabilities,
0131             static_cast<std::uint32_t>(max_packet_size),
0132             hparams_.connection_collation(),
0133         };
0134     }
0135 
0136     login_request compose_login_request(const connection_state_data& st)
0137     {
0138         return login_request{
0139             st.current_capabilities,
0140             static_cast<std::uint32_t>(max_packet_size),
0141             hparams_.connection_collation(),
0142             hparams_.username(),
0143             auth_resp_.data,
0144             hparams_.database(),
0145             auth_resp_.plugin_name,
0146         };
0147     }
0148 
0149     // Processes auth_switch and auth_more_data messages, and leaves the result in auth_resp_
0150     error_code process_auth_switch(auth_switch msg)
0151     {
0152         return compute_auth_response(
0153             msg.plugin_name,
0154             hparams_.password(),
0155             msg.auth_data,
0156             secure_channel_,
0157             auth_resp_
0158         );
0159     }
0160 
0161     error_code process_auth_more_data(span<const std::uint8_t> data)
0162     {
0163         return compute_auth_response(
0164             auth_resp_.plugin_name,
0165             hparams_.password(),
0166             data,
0167             secure_channel_,
0168             auth_resp_
0169         );
0170     }
0171 
0172     // Composes an auth_switch_response message with the contents of auth_resp_
0173     auth_switch_response compose_auth_switch_response() const
0174     {
0175         return auth_switch_response{auth_resp_.data};
0176     }
0177 
0178     void on_success(connection_state_data& st, const ok_view& ok)
0179     {
0180         st.is_connected = true;
0181         st.backslash_escapes = ok.backslash_escapes();
0182         st.current_charset = collation_id_to_charset(hparams_.connection_collation());
0183     }
0184 
0185     error_code process_ok(connection_state_data& st)
0186     {
0187         ok_view res{};
0188         auto ec = deserialize_ok_packet(st.reader.message(), res);
0189         if (ec)
0190             return ec;
0191         on_success(st, res);
0192         return error_code();
0193     }
0194 
0195 public:
0196     handshake_algo(handshake_algo_params params) noexcept
0197         : diag_(params.diag), hparams_(params.hparams), secure_channel_(params.secure_channel)
0198     {
0199     }
0200 
0201     diagnostics& diag() { return *diag_; }
0202 
0203     next_action resume(connection_state_data& st, error_code ec)
0204     {
0205         if (ec)
0206             return ec;
0207 
0208         handhake_server_response resp(error_code{});
0209 
0210         switch (resume_point_)
0211         {
0212         case 0:
0213 
0214             // Setup
0215             diag_->clear();
0216             st.reset();
0217 
0218             // Read server greeting
0219             BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
0220 
0221             // Process server greeting
0222             ec = process_handshake(st, st.reader.message());
0223             if (ec)
0224                 return ec;
0225 
0226             // SSL
0227             if (use_ssl(st))
0228             {
0229                 // Send SSL request
0230                 BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
0231 
0232                 // SSL handshake
0233                 BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
0234 
0235                 // Mark the connection as using ssl
0236                 st.ssl = ssl_state::active;
0237             }
0238 
0239             // Compose and send handshake response
0240             BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
0241 
0242             // Auth message exchange
0243             while (true)
0244             {
0245                 // Receive response
0246                 BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
0247 
0248                 // Process it
0249                 resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, *diag_);
0250                 if (resp.type == handhake_server_response::type_t::ok)
0251                 {
0252                     // Auth success, quit
0253                     on_success(st, resp.data.ok);
0254                     return next_action();
0255                 }
0256                 else if (resp.type == handhake_server_response::type_t::error)
0257                 {
0258                     // Error, quit
0259                     return resp.data.err;
0260                 }
0261                 else if (resp.type == handhake_server_response::type_t::auth_switch)
0262                 {
0263                     // Compute response
0264                     ec = process_auth_switch(resp.data.auth_sw);
0265                     if (ec)
0266                         return ec;
0267 
0268                     BOOST_MYSQL_YIELD(
0269                         resume_point_,
0270                         6,
0271                         st.write(compose_auth_switch_response(), sequence_number_)
0272                     )
0273                 }
0274                 else if (resp.type == handhake_server_response::type_t::ok_follows)
0275                 {
0276                     // The next packet must be an OK packet. Read it
0277                     BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
0278 
0279                     // Process it
0280                     // Regardless of whether we succeeded or not, we're done
0281                     return process_ok(st);
0282                 }
0283                 else
0284                 {
0285                     BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
0286 
0287                     // Compute response
0288                     ec = process_auth_more_data(resp.data.more_data);
0289                     if (ec)
0290                         return ec;
0291 
0292                     // Write response
0293                     BOOST_MYSQL_YIELD(
0294                         resume_point_,
0295                         8,
0296                         st.write(compose_auth_switch_response(), sequence_number_)
0297                     )
0298                 }
0299             }
0300         }
0301 
0302         return next_action();
0303     }
0304 };
0305 
0306 }  // namespace detail
0307 }  // namespace mysql
0308 }  // namespace boost
0309 
0310 #endif