diff --git a/include/ircd/net/acceptor.h b/include/ircd/net/acceptor.h index 9946af9a3..a4a78151c 100644 --- a/include/ircd/net/acceptor.h +++ b/include/ircd/net/acceptor.h @@ -61,8 +61,8 @@ struct ircd::net::acceptor void configure(const json::object &opts); // Handshake stack - bool handle_sni(SSL &, int &ad); - string_view handle_alpn(SSL &, const vector_view &in); + bool handle_sni(socket &, int &ad); + string_view handle_alpn(socket &, const vector_view &in); void check_handshake_error(const error_code &ec, socket &) const; void handshake(const error_code &, const std::shared_ptr, const decltype(handshaking)::const_iterator) noexcept; diff --git a/ircd/net_listener.cc b/ircd/net_listener.cc index a77ceec9d..51b0e5658 100644 --- a/ircd/net_listener.cc +++ b/ircd/net_listener.cc @@ -575,6 +575,8 @@ noexcept try sock->set_timeout(milliseconds(timeout)); sock->ssl.async_handshake(handshake_type, ios::handle(desc, std::move(handshake))); + assert(!openssl::get_app_data(*sock)); + openssl::set_app_data(*sock, sock.get()); } catch(const ctx::interrupted &e) { @@ -705,6 +707,7 @@ noexcept try assert(bool(sock)); assert(!handshaking.empty()); assert(it != end(handshaking)); + assert(openssl::get_app_data(*sock) == sock.get()); #ifdef RB_DEBUG const auto *const current_cipher @@ -730,6 +733,7 @@ noexcept try #endif handshaking.erase(it); + openssl::set_app_data(*sock, nullptr); check_handshake_error(ec, *sock); sock->cancel_timeout(); assert(bool(cb)); @@ -816,7 +820,7 @@ const } ircd::string_view -ircd::net::acceptor::handle_alpn(SSL &ssl, +ircd::net::acceptor::handle_alpn(socket &socket, const vector_view &in) { if(empty(in)) @@ -824,7 +828,8 @@ ircd::net::acceptor::handle_alpn(SSL &ssl, log::debug { - log, "%s offered %zu ALPN protocols", + log, "%s %s offered %zu ALPN protocols", + loghead(socket), loghead(*this), size(in), }; @@ -835,7 +840,7 @@ ircd::net::acceptor::handle_alpn(SSL &ssl, log::debug { log, "%s ALPN protocol %zu of %zu: '%s'", - loghead(*this), + loghead(socket), i, size(in), in[i], @@ -891,9 +896,19 @@ noexcept try protos, p }; + assert(s); + assert(ircd::openssl::get_app_data(*s)); + if(unlikely(!ircd::openssl::get_app_data(*s))) + return SSL_TLSEXT_ERR_ALERT_FATAL; + + auto &socket + { + *static_cast(ircd::openssl::get_app_data(*s)) + }; + const ircd::string_view sel { - acceptor.handle_alpn(*s, vec) + acceptor.handle_alpn(socket, vec) }; if(!sel) @@ -920,13 +935,13 @@ catch(...) } bool -ircd::net::acceptor::handle_sni(SSL &ssl, +ircd::net::acceptor::handle_sni(socket &socket, int &client_server) try { const string_view &name { - openssl::server_name(ssl) + openssl::server_name(socket) }; if(!name) @@ -946,7 +961,8 @@ try { log::dwarning { - log, "%s unrecognized SNI '%s' offered.", + log, "%s %s unrecognized SNI '%s' offered.", + loghead(socket), loghead(*this), name, }; @@ -956,9 +972,10 @@ try log::debug { - log, "%s offered SNI '%s'", + log, "%s %s offered SNI '%s'", + loghead(socket), loghead(*this), - name + name, }; return true; @@ -967,9 +984,10 @@ catch(const sni_warning &e) { log::warning { - log, "%s during SNI :%s", + log, "%s %s during SNI :%s", + loghead(socket), loghead(*this), - e.what() + e.what(), }; throw; @@ -978,9 +996,10 @@ catch(const std::exception &e) { log::error { - log, "%s during SNI :%s", + log, "%s %s during SNI :%s", + loghead(socket), loghead(*this), - e.what() + e.what(), }; throw; @@ -1003,7 +1022,17 @@ noexcept try *reinterpret_cast(a) }; - return acceptor.handle_sni(*s, *i)? + assert(s); + assert(ircd::openssl::get_app_data(*s)); + if(unlikely(!ircd::openssl::get_app_data(*s))) + return SSL_TLSEXT_ERR_ALERT_FATAL; + + auto &socket + { + *static_cast(ircd::openssl::get_app_data(*s)) + }; + + return acceptor.handle_sni(socket, *i)? SSL_TLSEXT_ERR_OK: SSL_TLSEXT_ERR_NOACK; }