diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp index ef5f6f5c20..9e077317c0 100644 --- a/modules/websocket/websocket_server.cpp +++ b/modules/websocket/websocket_server.cpp @@ -49,12 +49,51 @@ void WebSocketServer::_bind_methods() { ClassDB::bind_method(D_METHOD("get_peer_port", "id"), &WebSocketServer::get_peer_port); ClassDB::bind_method(D_METHOD("disconnect_peer", "id", "code", "reason"), &WebSocketServer::disconnect_peer, DEFVAL(1000), DEFVAL("")); + ClassDB::bind_method(D_METHOD("get_private_key"), &WebSocketServer::get_private_key); + ClassDB::bind_method(D_METHOD("set_private_key"), &WebSocketServer::set_private_key); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "private_key", PROPERTY_HINT_RESOURCE_TYPE, "CryptoKey", 0), "set_private_key", "get_private_key"); + + ClassDB::bind_method(D_METHOD("get_ssl_certificate"), &WebSocketServer::get_ssl_certificate); + ClassDB::bind_method(D_METHOD("set_ssl_certificate"), &WebSocketServer::set_ssl_certificate); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "ssl_certificate", PROPERTY_HINT_RESOURCE_TYPE, "X509Certificate", 0), "set_ssl_certificate", "get_ssl_certificate"); + + ClassDB::bind_method(D_METHOD("get_ca_chain"), &WebSocketServer::get_ca_chain); + ClassDB::bind_method(D_METHOD("set_ca_chain"), &WebSocketServer::set_ca_chain); + ADD_PROPERTY(PropertyInfo(Variant::OBJECT, "ca_chain", PROPERTY_HINT_RESOURCE_TYPE, "X509Certificate", 0), "set_ca_chain", "get_ca_chain"); + ADD_SIGNAL(MethodInfo("client_close_request", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::INT, "code"), PropertyInfo(Variant::STRING, "reason"))); ADD_SIGNAL(MethodInfo("client_disconnected", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::BOOL, "was_clean_close"))); ADD_SIGNAL(MethodInfo("client_connected", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::STRING, "protocol"))); ADD_SIGNAL(MethodInfo("data_received", PropertyInfo(Variant::INT, "id"))); } +Ref WebSocketServer::get_private_key() const { + return private_key; +} + +void WebSocketServer::set_private_key(Ref p_key) { + ERR_FAIL_COND(is_listening()); + private_key = p_key; +} + +Ref WebSocketServer::get_ssl_certificate() const { + return ssl_cert; +} + +void WebSocketServer::set_ssl_certificate(Ref p_cert) { + ERR_FAIL_COND(is_listening()); + ssl_cert = p_cert; +} + +Ref WebSocketServer::get_ca_chain() const { + return ca_chain; +} + +void WebSocketServer::set_ca_chain(Ref p_ca_chain) { + ERR_FAIL_COND(is_listening()); + ca_chain = p_ca_chain; +} + NetworkedMultiplayerPeer::ConnectionStatus WebSocketServer::get_connection_status() const { if (is_listening()) return CONNECTION_CONNECTED; diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h index 83c0c10419..2e5c706f33 100644 --- a/modules/websocket/websocket_server.h +++ b/modules/websocket/websocket_server.h @@ -31,6 +31,7 @@ #ifndef WEBSOCKET_H #define WEBSOCKET_H +#include "core/crypto/crypto.h" #include "core/reference.h" #include "websocket_multiplayer_peer.h" #include "websocket_peer.h" @@ -43,6 +44,10 @@ class WebSocketServer : public WebSocketMultiplayerPeer { protected: static void _bind_methods(); + Ref private_key; + Ref ssl_cert; + Ref ca_chain; + public: virtual void poll() = 0; virtual Error listen(int p_port, PoolVector p_protocols = PoolVector(), bool gd_mp_api = false) = 0; @@ -62,6 +67,15 @@ public: void _on_disconnect(int32_t p_peer_id, bool p_was_clean); void _on_close_request(int32_t p_peer_id, int p_code, String p_reason); + Ref get_private_key() const; + void set_private_key(Ref p_key); + + Ref get_ssl_certificate() const; + void set_ssl_certificate(Ref p_cert); + + Ref get_ca_chain() const; + void set_ca_chain(Ref p_ca_chain); + virtual Error set_buffers(int p_in_buffer, int p_in_packets, int p_out_buffer, int p_out_packets) = 0; WebSocketServer(); diff --git a/modules/websocket/wsl_server.cpp b/modules/websocket/wsl_server.cpp index dbf153bd4a..e3101d99d7 100644 --- a/modules/websocket/wsl_server.cpp +++ b/modules/websocket/wsl_server.cpp @@ -35,6 +35,7 @@ #include "core/project_settings.h" WSLServer::PendingPeer::PendingPeer() { + use_ssl = false; time = 0; has_request = false; response_sent = 0; @@ -100,6 +101,16 @@ bool WSLServer::PendingPeer::_parse_request(const PoolStringArray p_protocols) { Error WSLServer::PendingPeer::do_handshake(PoolStringArray p_protocols) { if (OS::get_singleton()->get_ticks_msec() - time > WSL_SERVER_TIMEOUT) return ERR_TIMEOUT; + if (use_ssl) { + Ref ssl = static_cast >(connection); + if (ssl.is_null()) + return FAILED; + ssl->poll(); + if (ssl->get_status() == StreamPeerSSL::STATUS_HANDSHAKING) + return ERR_BUSY; + else if (ssl->get_status() != StreamPeerSSL::STATUS_CONNECTED) + return FAILED; + } if (!has_request) { int read = 0; while (true) { @@ -210,7 +221,15 @@ void WSLServer::poll() { continue; // Conn will go out-of-scope and be closed. Ref peer = memnew(PendingPeer); - peer->connection = conn; + if (private_key.is_valid() && ssl_cert.is_valid()) { + Ref ssl = Ref(StreamPeerSSL::create()); + ssl->set_blocking_handshake_enabled(false); + ssl->accept_stream(conn, private_key, ssl_cert, ca_chain); + peer->connection = ssl; + peer->use_ssl = true; + } else { + peer->connection = conn; + } peer->tcp = conn; peer->time = OS::get_singleton()->get_ticks_msec(); _pending.push_back(peer); diff --git a/modules/websocket/wsl_server.h b/modules/websocket/wsl_server.h index de6d6d77f6..2ac873cba5 100644 --- a/modules/websocket/wsl_server.h +++ b/modules/websocket/wsl_server.h @@ -36,6 +36,7 @@ #include "websocket_server.h" #include "wsl_peer.h" +#include "core/io/stream_peer_ssl.h" #include "core/io/stream_peer_tcp.h" #include "core/io/tcp_server.h" @@ -54,6 +55,7 @@ private: public: Ref tcp; Ref connection; + bool use_ssl; int time; uint8_t req_buf[WSL_MAX_HEADER_SIZE];