File indexing completed on 2025-09-15 08:42:57
0001
0002
0003
0004
0005
0006
0007
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0010
0011 #include <boost/mysql/character_set.hpp>
0012 #include <boost/mysql/diagnostics.hpp>
0013 #include <boost/mysql/error_code.hpp>
0014 #include <boost/mysql/handshake_params.hpp>
0015 #include <boost/mysql/mysql_collations.hpp>
0016
0017 #include <boost/mysql/detail/algo_params.hpp>
0018 #include <boost/mysql/detail/next_action.hpp>
0019 #include <boost/mysql/detail/ok_view.hpp>
0020
0021 #include <boost/mysql/impl/internal/auth/auth.hpp>
0022 #include <boost/mysql/impl/internal/coroutine.hpp>
0023 #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
0024 #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
0025 #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
0026 #include <boost/mysql/impl/internal/protocol/serialization.hpp>
0027 #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
0028
0029 #include <cstdint>
0030
0031 namespace boost {
0032 namespace mysql {
0033 namespace detail {
0034
0035 inline capabilities conditional_capability(bool condition, std::uint32_t cap)
0036 {
0037 return capabilities(condition ? cap : 0);
0038 }
0039
0040 inline error_code process_capabilities(
0041 const handshake_params& params,
0042 const server_hello& hello,
0043 capabilities& negotiated_caps,
0044 bool transport_supports_ssl
0045 )
0046 {
0047 auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
0048 capabilities server_caps = hello.server_capabilities;
0049 capabilities required_caps = mandatory_capabilities |
0050 conditional_capability(!params.database().empty(), CLIENT_CONNECT_WITH_DB) |
0051 conditional_capability(params.multi_queries(), CLIENT_MULTI_STATEMENTS) |
0052 conditional_capability(ssl == ssl_mode::require, CLIENT_SSL);
0053 if (required_caps.has(CLIENT_SSL) && !server_caps.has(CLIENT_SSL))
0054 {
0055
0056
0057 return make_error_code(client_errc::server_doesnt_support_ssl);
0058 }
0059 else if (!server_caps.has_all(required_caps))
0060 {
0061 return make_error_code(client_errc::server_unsupported);
0062 }
0063 negotiated_caps = server_caps & (required_caps | optional_capabilities |
0064 conditional_capability(ssl == ssl_mode::enable, CLIENT_SSL));
0065 return error_code();
0066 }
0067
0068 class handshake_algo
0069 {
0070 int resume_point_{0};
0071 handshake_params hparams_;
0072 auth_response auth_resp_;
0073 std::uint8_t sequence_number_{0};
0074 bool secure_channel_{false};
0075
0076
0077
0078
0079 static character_set collation_id_to_charset(std::uint16_t collation_id)
0080 {
0081 switch (collation_id)
0082 {
0083 case mysql_collations::utf8mb4_bin:
0084 case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
0085 case mysql_collations::ascii_general_ci:
0086 case mysql_collations::ascii_bin: return ascii_charset;
0087 default: return character_set{};
0088 }
0089 }
0090
0091
0092 bool use_ssl(const connection_state_data& st) const { return st.current_capabilities.has(CLIENT_SSL); }
0093
0094 error_code process_handshake(
0095 connection_state_data& st,
0096 diagnostics& diag,
0097 span<const std::uint8_t> buffer
0098 )
0099 {
0100
0101 server_hello hello{};
0102 auto err = deserialize_server_hello(buffer, hello, diag);
0103 if (err)
0104 return err;
0105
0106
0107 capabilities negotiated_caps;
0108 err = process_capabilities(hparams_, hello, negotiated_caps, st.tls_supported);
0109 if (err)
0110 return err;
0111
0112
0113 st.current_capabilities = negotiated_caps;
0114 st.flavor = hello.server;
0115 st.connection_id = hello.connection_id;
0116
0117
0118 secure_channel_ = secure_channel_ || use_ssl(st);
0119
0120
0121 return compute_auth_response(
0122 hello.auth_plugin_name,
0123 hparams_.password(),
0124 hello.auth_plugin_data.to_span(),
0125 secure_channel_,
0126 auth_resp_
0127 );
0128 }
0129
0130
0131 ssl_request compose_ssl_request(const connection_state_data& st)
0132 {
0133 return ssl_request{
0134 st.current_capabilities,
0135 static_cast<std::uint32_t>(max_packet_size),
0136 hparams_.connection_collation(),
0137 };
0138 }
0139
0140 login_request compose_login_request(const connection_state_data& st)
0141 {
0142 return login_request{
0143 st.current_capabilities,
0144 static_cast<std::uint32_t>(max_packet_size),
0145 hparams_.connection_collation(),
0146 hparams_.username(),
0147 auth_resp_.data,
0148 hparams_.database(),
0149 auth_resp_.plugin_name,
0150 };
0151 }
0152
0153
0154 error_code process_auth_switch(auth_switch msg)
0155 {
0156 return compute_auth_response(
0157 msg.plugin_name,
0158 hparams_.password(),
0159 msg.auth_data,
0160 secure_channel_,
0161 auth_resp_
0162 );
0163 }
0164
0165 error_code process_auth_more_data(span<const std::uint8_t> data)
0166 {
0167 return compute_auth_response(
0168 auth_resp_.plugin_name,
0169 hparams_.password(),
0170 data,
0171 secure_channel_,
0172 auth_resp_
0173 );
0174 }
0175
0176
0177 auth_switch_response compose_auth_switch_response() const
0178 {
0179 return auth_switch_response{auth_resp_.data};
0180 }
0181
0182 void on_success(connection_state_data& st, const ok_view& ok)
0183 {
0184 st.status = connection_status::ready;
0185 st.backslash_escapes = ok.backslash_escapes();
0186 st.current_charset = collation_id_to_charset(hparams_.connection_collation());
0187 }
0188
0189 error_code process_ok(connection_state_data& st)
0190 {
0191 ok_view res{};
0192 auto ec = deserialize_ok_packet(st.reader.message(), res);
0193 if (ec)
0194 return ec;
0195 on_success(st, res);
0196 return error_code();
0197 }
0198
0199 public:
0200 handshake_algo(handshake_algo_params params) noexcept
0201 : hparams_(params.hparams), secure_channel_(params.secure_channel)
0202 {
0203 }
0204
0205 next_action resume(connection_state_data& st, diagnostics& diag, error_code ec)
0206 {
0207 if (ec)
0208 return ec;
0209
0210 handhake_server_response resp(error_code{});
0211
0212 switch (resume_point_)
0213 {
0214 case 0:
0215
0216
0217 st.reset();
0218
0219
0220 BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
0221
0222
0223 ec = process_handshake(st, diag, st.reader.message());
0224 if (ec)
0225 return ec;
0226
0227
0228 if (use_ssl(st))
0229 {
0230
0231 BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
0232
0233
0234 BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
0235
0236
0237 st.tls_active = true;
0238 }
0239
0240
0241 BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
0242
0243
0244 while (true)
0245 {
0246
0247 BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
0248
0249
0250 resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, diag);
0251 if (resp.type == handhake_server_response::type_t::ok)
0252 {
0253
0254 on_success(st, resp.data.ok);
0255 return next_action();
0256 }
0257 else if (resp.type == handhake_server_response::type_t::error)
0258 {
0259
0260 return resp.data.err;
0261 }
0262 else if (resp.type == handhake_server_response::type_t::auth_switch)
0263 {
0264
0265 ec = process_auth_switch(resp.data.auth_sw);
0266 if (ec)
0267 return ec;
0268
0269 BOOST_MYSQL_YIELD(
0270 resume_point_,
0271 6,
0272 st.write(compose_auth_switch_response(), sequence_number_)
0273 )
0274 }
0275 else if (resp.type == handhake_server_response::type_t::ok_follows)
0276 {
0277
0278 BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
0279
0280
0281
0282 return process_ok(st);
0283 }
0284 else
0285 {
0286 BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
0287
0288
0289 ec = process_auth_more_data(resp.data.more_data);
0290 if (ec)
0291 return ec;
0292
0293
0294 BOOST_MYSQL_YIELD(
0295 resume_point_,
0296 8,
0297 st.write(compose_auth_switch_response(), sequence_number_)
0298 )
0299 }
0300 }
0301 }
0302
0303 return next_action();
0304 }
0305 };
0306
0307 }
0308 }
0309 }
0310
0311 #endif