File indexing completed on 2025-04-19 08:33:52
0001
0002
0003
0004
0005
0006
0007
0008 #ifndef BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0009 #define BOOST_MYSQL_IMPL_INTERNAL_SANSIO_HANDSHAKE_HPP
0010
0011 #include <boost/mysql/character_set.hpp>
0012 #include <boost/mysql/diagnostics.hpp>
0013 #include <boost/mysql/error_code.hpp>
0014 #include <boost/mysql/handshake_params.hpp>
0015 #include <boost/mysql/mysql_collations.hpp>
0016
0017 #include <boost/mysql/detail/algo_params.hpp>
0018 #include <boost/mysql/detail/next_action.hpp>
0019 #include <boost/mysql/detail/ok_view.hpp>
0020
0021 #include <boost/mysql/impl/internal/auth/auth.hpp>
0022 #include <boost/mysql/impl/internal/coroutine.hpp>
0023 #include <boost/mysql/impl/internal/protocol/capabilities.hpp>
0024 #include <boost/mysql/impl/internal/protocol/db_flavor.hpp>
0025 #include <boost/mysql/impl/internal/protocol/deserialization.hpp>
0026 #include <boost/mysql/impl/internal/protocol/serialization.hpp>
0027 #include <boost/mysql/impl/internal/sansio/connection_state_data.hpp>
0028
0029 #include <cstdint>
0030
0031 namespace boost {
0032 namespace mysql {
0033 namespace detail {
0034
0035 inline capabilities conditional_capability(bool condition, std::uint32_t cap)
0036 {
0037 return capabilities(condition ? cap : 0);
0038 }
0039
0040 inline error_code process_capabilities(
0041 const handshake_params& params,
0042 const server_hello& hello,
0043 capabilities& negotiated_caps,
0044 bool transport_supports_ssl
0045 )
0046 {
0047 auto ssl = transport_supports_ssl ? params.ssl() : ssl_mode::disable;
0048 capabilities server_caps = hello.server_capabilities;
0049 capabilities required_caps = mandatory_capabilities |
0050 conditional_capability(!params.database().empty(), CLIENT_CONNECT_WITH_DB) |
0051 conditional_capability(params.multi_queries(), CLIENT_MULTI_STATEMENTS) |
0052 conditional_capability(ssl == ssl_mode::require, CLIENT_SSL);
0053 if (required_caps.has(CLIENT_SSL) && !server_caps.has(CLIENT_SSL))
0054 {
0055
0056
0057 return make_error_code(client_errc::server_doesnt_support_ssl);
0058 }
0059 else if (!server_caps.has_all(required_caps))
0060 {
0061 return make_error_code(client_errc::server_unsupported);
0062 }
0063 negotiated_caps = server_caps & (required_caps | optional_capabilities |
0064 conditional_capability(ssl == ssl_mode::enable, CLIENT_SSL));
0065 return error_code();
0066 }
0067
// handshake_algo: sans-io algorithm that drives the MySQL connection handshake.
//
// The object does not perform I/O directly. Each BOOST_MYSQL_YIELD in resume()
// hands a next_action back to the caller. That action is one of:
//   - a packet read built by st.read(),
//   - a packet write built by st.write(),
//   - next_action::ssl_handshake().
// The caller carries the action out and calls resume() again with its error_code.
// resume() finishes by returning next_action() on success, or a next_action built
// from an error_code otherwise. process_ok()'s result is returned as-is, so it
// may carry an empty error_code.
//
// Message flow driven by resume():
//   1. Read the server greeting (server hello), negotiate capabilities and
//      compute the initial auth response.
//   2. If CLIENT_SSL was negotiated, send an SSL request, then ask the caller
//      to run the TLS handshake.
//   3. Send the login request with credentials and the initial auth response.
//   4. Loop over server responses:
//      - OK: done.
//      - error: fail.
//      - auth switch / auth more data: compute and send a new auth response.
//      - "OK follows": read one more packet and parse it as an OK packet.
0068 class handshake_algo
0069 {
// Resume point of the BOOST_MYSQL_YIELD-based coroutine in resume(); 0 = not started.
0070 int resume_point_{0};
// Caller-provided diagnostics: cleared when the handshake starts and passed to the
// deserialization functions. Dereferenced without null checks (assumed non-null).
0071 diagnostics* diag_;
// Handshake parameters given at construction (username, password, database,
// connection collation, SSL mode, multi_queries).
// NOTE(review): if handshake_params holds non-owning views, the referenced data must
// outlive this object — confirm against handshake_params.hpp.
0072 handshake_params hparams_;
// Latest auth response (plugin name + payload) produced by compute_auth_response().
0073 auth_response auth_resp_;
// Packet sequence number, passed to every st.read()/st.write() call of the exchange.
0074 std::uint8_t sequence_number_{0};
// Whether the channel counts as secure for authentication. Initialized from the
// construction params and set to true once CLIENT_SSL is negotiated. Passed to
// compute_auth_response(), which presumably uses it to decide how the password may be sent.
0075 bool secure_channel_{false};
0076
0077 // Maps a collation ID to its character set, for the few collations recognized here
0078 // (utf8mb4_bin / utf8mb4_general_ci -> utf8mb4, ascii_general_ci / ascii_bin -> ascii).
0079 // Any other ID yields a default-constructed character_set{} instead of a guess.
0080 static character_set collation_id_to_charset(std::uint16_t collation_id)
0081 {
0082 switch (collation_id)
0083 {
0084 case mysql_collations::utf8mb4_bin:
0085 case mysql_collations::utf8mb4_general_ci: return utf8mb4_charset;
0086 case mysql_collations::ascii_general_ci:
0087 case mysql_collations::ascii_bin: return ascii_charset;
0088 default: return character_set{};
0089 }
0090 }
0091
0092 // True if CLIENT_SSL is among the negotiated capabilities (stored by process_handshake()).
0093 bool use_ssl(const connection_state_data& st) const { return st.current_capabilities.has(CLIENT_SSL); }
0094
// Processes the server greeting held in `buffer`, in this order:
//   - parse it;
//   - negotiate capabilities and store them, with the server flavor, in `st`;
//   - update secure_channel_;
//   - compute the initial auth response into auth_resp_.
// Returns the first error found, or the result of compute_auth_response().
0095 error_code process_handshake(connection_state_data& st, span<const std::uint8_t> buffer)
0096 {
0097 // Parse the server hello (greeting) packet.
0098 server_hello hello{};
0099 auto err = deserialize_server_hello(buffer, hello, *diag_);
0100 if (err)
0101 return err;
0102
0103 // Check that the server offers everything we need, and compute the negotiated set.
0104 capabilities negotiated_caps;
0105 err = process_capabilities(hparams_, hello, negotiated_caps, st.supports_ssl());
0106 if (err)
0107 return err;
0108
0109 // Record the negotiated capabilities and server flavor in the connection state.
0110 st.current_capabilities = negotiated_caps;
0111 st.flavor = hello.server;
0112
0113 // Once SSL is negotiated, the channel is considered secure for authentication.
0114 secure_channel_ = secure_channel_ || use_ssl(st);
0115
0116 // First auth response, using the plugin and challenge data sent in the greeting.
0117 return compute_auth_response(
0118 hello.auth_plugin_name,
0119 hparams_.password(),
0120 hello.auth_plugin_data.to_span(),
0121 secure_channel_,
0122 auth_resp_
0123 );
0124 }
0125
0126 // Builds the SSL request packet, sent before the TLS handshake is performed.
0127 ssl_request compose_ssl_request(const connection_state_data& st)
0128 {
0129 return ssl_request{
0130 st.current_capabilities,
0131 static_cast<std::uint32_t>(max_packet_size),
0132 hparams_.connection_collation(),
0133 };
0134 }
0135 // Builds the login request carrying credentials, database and the current auth response.
0136 login_request compose_login_request(const connection_state_data& st)
0137 {
0138 return login_request{
0139 st.current_capabilities,
0140 static_cast<std::uint32_t>(max_packet_size),
0141 hparams_.connection_collation(),
0142 hparams_.username(),
0143 auth_resp_.data,
0144 hparams_.database(),
0145 auth_resp_.plugin_name,
0146 };
0147 }
0148
0149 // Recomputes auth_resp_ for the plugin and challenge data named in an auth switch message.
0150 error_code process_auth_switch(auth_switch msg)
0151 {
0152 return compute_auth_response(
0153 msg.plugin_name,
0154 hparams_.password(),
0155 msg.auth_data,
0156 secure_channel_,
0157 auth_resp_
0158 );
0159 }
0160 // Recomputes auth_resp_ with the current plugin, feeding it the extra data sent by the server.
0161 error_code process_auth_more_data(span<const std::uint8_t> data)
0162 {
0163 return compute_auth_response(
0164 auth_resp_.plugin_name,
0165 hparams_.password(),
0166 data,
0167 secure_channel_,
0168 auth_resp_
0169 );
0170 }
0171
0172 // Wraps auth_resp_.data; used to answer both auth switch and auth-more-data messages.
0173 auth_switch_response compose_auth_switch_response() const
0174 {
0175 return auth_switch_response{auth_resp_.data};
0176 }
0177 // Applies a successful handshake to the connection state: connected flag, the
// server's backslash-escapes setting from the OK packet, and the character set of the
// configured collation (character_set{} if collation_id_to_charset() doesn't know it).
0178 void on_success(connection_state_data& st, const ok_view& ok)
0179 {
0180 st.is_connected = true;
0181 st.backslash_escapes = ok.backslash_escapes();
0182 st.current_charset = collation_id_to_charset(hparams_.connection_collation());
0183 }
0184 // Parses the last received message as an OK packet; on success calls on_success().
0185 error_code process_ok(connection_state_data& st)
0186 {
0187 ok_view res{};
0188 auto ec = deserialize_ok_packet(st.reader.message(), res);
0189 if (ec)
0190 return ec;
0191 on_success(st, res);
0192 return error_code();
0193 }
0194
0195 public:
// Stores the diagnostics pointer, handshake parameters and initial secure-channel flag.
// NOTE(review): params.diag is later dereferenced unconditionally; assumed non-null.
0196 handshake_algo(handshake_algo_params params) noexcept
0197 : diag_(params.diag), hparams_(params.hparams), secure_channel_(params.secure_channel)
0198 {
0199 }
0200 // Diagnostics object supplied at construction.
0201 diagnostics& diag() { return *diag_; }
0202
// Advances the handshake.
//   - Called first with an empty ec.
//   - Then called once per returned I/O action, with that action's result.
//   - Any non-empty ec aborts the handshake immediately.
0203 next_action resume(connection_state_data& st, error_code ec)
0204 {
0205 if (ec)
0206 return ec;
0207 // Declared outside the switch, because jumping to a case label past an initialized local is ill-formed (assumes BOOST_MYSQL_YIELD emits case labels — see coroutine.hpp).
0208 handhake_server_response resp(error_code{});
0209
0210 switch (resume_point_)
0211 {
0212 case 0:
0213 // First invocation: reset diagnostics and connection state.
0214
0215 diag_->clear();
0216 st.reset();
0217
0218 // Read the server greeting.
0219 BOOST_MYSQL_YIELD(resume_point_, 1, st.read(sequence_number_))
0220
0221 // Parse it, negotiate capabilities and compute the initial auth response.
0222 ec = process_handshake(st, st.reader.message());
0223 if (ec)
0224 return ec;
0225
0226 // Upgrade to TLS if CLIENT_SSL was negotiated.
0227 if (use_ssl(st))
0228 {
0229 // Send the SSL request (before TLS is established).
0230 BOOST_MYSQL_YIELD(resume_point_, 2, st.write(compose_ssl_request(st), sequence_number_))
0231
0232 // Ask the caller to run the TLS handshake.
0233 BOOST_MYSQL_YIELD(resume_point_, 3, next_action::ssl_handshake())
0234
0235 // Record in the connection state that TLS is now active.
0236 st.ssl = ssl_state::active;
0237 }
0238
0239 // Send the login request with credentials and the initial auth response.
0240 BOOST_MYSQL_YIELD(resume_point_, 4, st.write(compose_login_request(st), sequence_number_))
0241
0242 // Authentication exchange: loop until the handshake ends with OK or an error.
0243 while (true)
0244 {
0245 // Read the next server message.
0246 BOOST_MYSQL_YIELD(resume_point_, 5, st.read(sequence_number_))
0247
0248 // Classify it: OK, error, auth switch, OK-follows or auth more data.
0249 resp = deserialize_handshake_server_response(st.reader.message(), st.flavor, *diag_);
0250 if (resp.type == handhake_server_response::type_t::ok)
0251 {
0252 // Authenticated: update the connection state and finish.
0253 on_success(st, resp.data.ok);
0254 return next_action();
0255 }
0256 else if (resp.type == handhake_server_response::type_t::error)
0257 {
0258 // Error response: fail with the stored error code.
0259 return resp.data.err;
0260 }
0261 else if (resp.type == handhake_server_response::type_t::auth_switch)
0262 {
0263 // The server asked to switch auth plugin: compute the new response...
0264 ec = process_auth_switch(resp.data.auth_sw);
0265 if (ec)
0266 return ec;
0267 // ...and send it, then wait for the next server message.
0268 BOOST_MYSQL_YIELD(
0269 resume_point_,
0270 6,
0271 st.write(compose_auth_switch_response(), sequence_number_)
0272 )
0273 }
0274 else if (resp.type == handhake_server_response::type_t::ok_follows)
0275 {
0276 // The server announced that an OK packet follows: read it.
0277 BOOST_MYSQL_YIELD(resume_point_, 7, st.read(sequence_number_))
0278
0279 // Parse it as an OK packet; the handshake ends here whatever the outcome.
0280
0281 return process_ok(st);
0282 }
0283 else
0284 {
0285 BOOST_ASSERT(resp.type == handhake_server_response::type_t::auth_more_data);
0286
0287 // More auth data for the current plugin: compute the reply...
0288 ec = process_auth_more_data(resp.data.more_data);
0289 if (ec)
0290 return ec;
0291
0292 // ...and send it, then wait for the next server message.
0293 BOOST_MYSQL_YIELD(
0294 resume_point_,
0295 8,
0296 st.write(compose_auth_switch_response(), sequence_number_)
0297 )
0298 }
0299 }
0300 }
0301 // The loop above has no break, so this is only reached if resume_point_ matches no case label.
0302 return next_action();
0303 }
0304 };
0305
0306 }
0307 }
0308 }
0309
0310 #endif