From 5d91e87c64f757fb1c20f43a0a30db3c73073625 Mon Sep 17 00:00:00 2001 From: Fabio Alessandrelli Date: Sun, 23 Sep 2018 21:14:20 +0200 Subject: [PATCH] Implement WebSocket close notify. --- modules/websocket/emws_client.cpp | 5 ++-- modules/websocket/emws_client.h | 2 +- modules/websocket/emws_peer.cpp | 8 +++-- modules/websocket/emws_peer.h | 2 +- modules/websocket/emws_server.cpp | 2 +- modules/websocket/emws_server.h | 2 +- modules/websocket/lws_client.cpp | 17 ++++++++--- modules/websocket/lws_client.h | 2 +- modules/websocket/lws_peer.cpp | 41 ++++++++++++++++++++++++-- modules/websocket/lws_peer.h | 7 ++++- modules/websocket/lws_server.cpp | 30 +++++++++++++++---- modules/websocket/lws_server.h | 2 +- modules/websocket/websocket_client.cpp | 8 ++++- modules/websocket/websocket_client.h | 3 +- modules/websocket/websocket_peer.cpp | 2 +- modules/websocket/websocket_peer.h | 2 +- modules/websocket/websocket_server.cpp | 8 ++++- modules/websocket/websocket_server.h | 3 +- 18 files changed, 116 insertions(+), 30 deletions(-) diff --git a/modules/websocket/emws_client.cpp b/modules/websocket/emws_client.cpp index 836b564de8..413ef661b1 100644 --- a/modules/websocket/emws_client.cpp +++ b/modules/websocket/emws_client.cpp @@ -55,6 +55,7 @@ EMSCRIPTEN_KEEPALIVE void _esws_on_error(void *obj) { EMSCRIPTEN_KEEPALIVE void _esws_on_close(void *obj, int code, char *reason, int was_clean) { EMWSClient *client = static_cast(obj); + client->_on_close_request(code, String(reason)); client->_is_connecting = false; client->_on_disconnect(); } @@ -182,9 +183,9 @@ NetworkedMultiplayerPeer::ConnectionStatus EMWSClient::get_connection_status() c return CONNECTION_DISCONNECTED; }; -void EMWSClient::disconnect_from_host() { +void EMWSClient::disconnect_from_host(int p_code, String p_reason) { - _peer->close(); + _peer->close(p_code, p_reason); }; IP_Address EMWSClient::get_connected_host() const { diff --git a/modules/websocket/emws_client.h b/modules/websocket/emws_client.h index 6c2fa23b53..b20633baff 100644 --- a/modules/websocket/emws_client.h +++ b/modules/websocket/emws_client.h @@ -48,7 +48,7 @@ public: Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, PoolVector p_protocol = PoolVector()); Ref get_peer(int p_peer_id) const; - void disconnect_from_host(); + void disconnect_from_host(int p_code = 1000, String p_reason = ""); IP_Address get_connected_host() const; uint16_t get_connected_port() const; virtual ConnectionStatus get_connection_status() const; diff --git a/modules/websocket/emws_peer.cpp b/modules/websocket/emws_peer.cpp index e0b987b4d7..61b2e3b01e 100644 --- a/modules/websocket/emws_peer.cpp +++ b/modules/websocket/emws_peer.cpp @@ -130,15 +130,17 @@ bool EMWSPeer::is_connected_to_host() const { return peer_sock != -1; }; -void EMWSPeer::close() { +void EMWSPeer::close(int p_code, String p_reason) { if (peer_sock != -1) { /* clang-format off */ EM_ASM({ var sock = Module.IDHandler.get($0); - sock.close(); + var code = $1; + var reason = UTF8ToString($2); + sock.close(code, reason); Module.IDHandler.remove($0); - }, peer_sock); + }, peer_sock, p_code); /* clang-format on */ } peer_sock = -1; diff --git a/modules/websocket/emws_peer.h b/modules/websocket/emws_peer.h index e06f725265..a4b2c8f50b 100644 --- a/modules/websocket/emws_peer.h +++ b/modules/websocket/emws_peer.h @@ -63,7 +63,7 @@ public: virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size); virtual int get_max_packet_size() const { return PACKET_BUFFER_SIZE; }; - virtual void close(); + virtual void close(int p_code = 1000, String p_reason = ""); virtual bool is_connected_to_host() const; virtual IP_Address get_connected_host() const; virtual uint16_t get_connected_port() const; diff --git a/modules/websocket/emws_server.cpp b/modules/websocket/emws_server.cpp index 3eb93e4152..ad4a758c0f 100644 --- a/modules/websocket/emws_server.cpp +++ b/modules/websocket/emws_server.cpp @@ -68,7 +68,7 @@ int EMWSServer::get_peer_port(int p_peer_id) const { return 0; } -void EMWSServer::disconnect_peer(int p_peer_id) { +void EMWSServer::disconnect_peer(int p_peer_id, int p_code, String p_reason) { } EMWSServer::EMWSServer() { diff --git a/modules/websocket/emws_server.h b/modules/websocket/emws_server.h index 9ec4ce72c8..74b689a29b 100644 --- a/modules/websocket/emws_server.h +++ b/modules/websocket/emws_server.h @@ -48,7 +48,7 @@ public: Ref get_peer(int p_id) const; IP_Address get_peer_address(int p_peer_id) const; int get_peer_port(int p_peer_id) const; - void disconnect_peer(int p_peer_id); + void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = ""); virtual void poll(); virtual PoolVector get_protocols() const; diff --git a/modules/websocket/lws_client.cpp b/modules/websocket/lws_client.cpp index 09f6422058..a568b97a1a 100644 --- a/modules/websocket/lws_client.cpp +++ b/modules/websocket/lws_client.cpp @@ -137,6 +137,13 @@ int LWSClient::_handle_cb(struct lws *wsi, enum lws_callback_reasons reason, voi destroy_context(); return -1; // We should close the connection (would probably happen anyway) + case LWS_CALLBACK_WS_PEER_INITIATED_CLOSE: { + int code; + String reason = peer->get_close_reason(in, len, code); + _on_close_request(code, reason); + return 0; + } + case LWS_CALLBACK_CLIENT_CLOSED: peer->close(); destroy_context(); @@ -150,8 +157,10 @@ int LWSClient::_handle_cb(struct lws *wsi, enum lws_callback_reasons reason, voi break; case LWS_CALLBACK_CLIENT_WRITEABLE: - if (peer_data->force_close) + if (peer_data->force_close) { + peer->send_close_status(wsi); return -1; + } peer->write_wsi(); break; @@ -179,13 +188,12 @@ NetworkedMultiplayerPeer::ConnectionStatus LWSClient::get_connection_status() co return CONNECTION_CONNECTING; } -void LWSClient::disconnect_from_host() { +void LWSClient::disconnect_from_host(int p_code, String p_reason) { if (context == NULL) return; - _peer->close(); - destroy_context(); + _peer->close(p_code, p_reason); }; IP_Address LWSClient::get_connected_host() const { @@ -208,6 +216,7 @@ LWSClient::~LWSClient() { invalidate_lws_ref(); // We do not want any more callback disconnect_from_host(); + destroy_context(); _peer = Ref(); }; diff --git a/modules/websocket/lws_client.h b/modules/websocket/lws_client.h index 8850683cb5..1bbc19f352 100644 --- a/modules/websocket/lws_client.h +++ b/modules/websocket/lws_client.h @@ -46,7 +46,7 @@ class LWSClient : public WebSocketClient { public: Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, PoolVector p_protocol = PoolVector()); Ref get_peer(int p_peer_id) const; - void disconnect_from_host(); + void disconnect_from_host(int p_code = 1000, String p_reason = ""); IP_Address get_connected_host() const; uint16_t get_connected_port() const; virtual ConnectionStatus get_connection_status() const; diff --git a/modules/websocket/lws_peer.cpp b/modules/websocket/lws_peer.cpp index 0989357258..43ffcfdea1 100644 --- a/modules/websocket/lws_peer.cpp +++ b/modules/websocket/lws_peer.cpp @@ -178,11 +178,48 @@ bool LWSPeer::is_connected_to_host() const { return wsi != NULL; }; -void LWSPeer::close() { +String LWSPeer::get_close_reason(void *in, size_t len, int &r_code) { + String s; + r_code = 0; + if (len < 2) // From docs this should not happen + return s; + + const uint8_t *b = (const uint8_t *)in; + r_code = b[0] << 8 | b[1]; + + if (len > 2) { + const char *utf8 = (const char *)&b[2]; + s.parse_utf8(utf8, len - 2); + } + return s; +} + +void LWSPeer::send_close_status(struct lws *p_wsi) { + if (close_code == -1) + return; + + int len = close_reason.size(); + ERR_FAIL_COND(len > 123); // Maximum allowed reason size in bytes + + lws_close_status code = (lws_close_status)close_code; + unsigned char *reason = len > 0 ? (unsigned char *)close_reason.utf8().ptrw() : NULL; + + lws_close_reason(p_wsi, code, reason, len); + + close_code = -1; + close_reason = ""; +} + +void LWSPeer::close(int p_code, String p_reason) { if (wsi != NULL) { + close_code = p_code; + close_reason = p_reason; PeerData *data = ((PeerData *)lws_wsi_user(wsi)); data->force_close = true; - lws_callback_on_writable(wsi); // notify that we want to disconnect + lws_callback_on_writable(wsi); // Notify that we want to disconnect + } else { + close_code = -1; + close_reason = ""; } wsi = NULL; rbw.resize(0); diff --git a/modules/websocket/lws_peer.h b/modules/websocket/lws_peer.h index d7d46e3076..05a3880577 100644 --- a/modules/websocket/lws_peer.h +++ b/modules/websocket/lws_peer.h @@ -53,6 +53,9 @@ private: WriteMode write_mode; bool _was_string; + int close_code; + String close_reason; + public: struct PeerData { uint32_t peer_id; @@ -71,7 +74,7 @@ public: virtual Error put_packet(const uint8_t *p_buffer, int p_buffer_size); virtual int get_max_packet_size() const { return PACKET_BUFFER_SIZE; }; - virtual void close(); + virtual void close(int p_code = 1000, String p_reason = ""); virtual bool is_connected_to_host() const; virtual IP_Address get_connected_host() const; virtual uint16_t get_connected_port() const; @@ -83,6 +86,8 @@ public: void set_wsi(struct lws *wsi); Error read_wsi(void *in, size_t len); Error write_wsi(); + void send_close_status(struct lws *wsi); + String get_close_reason(void *in, size_t len, int &r_code); LWSPeer(); ~LWSPeer(); diff --git a/modules/websocket/lws_server.cpp b/modules/websocket/lws_server.cpp index 4a614f6933..d13a8024fb 100644 --- a/modules/websocket/lws_server.cpp +++ b/modules/websocket/lws_server.cpp @@ -90,11 +90,24 @@ int LWSServer::_handle_cb(struct lws *wsi, enum lws_callback_reasons reason, voi peer_data->peer_id = id; peer_data->force_close = false; - _on_connect(id, lws_get_protocol(wsi)->name); break; } + case LWS_CALLBACK_WS_PEER_INITIATED_CLOSE: { + if (peer_data == NULL) + return 0; + + int32_t id = peer_data->peer_id; + if (_peer_map.has(id)) { + int code; + Ref peer = _peer_map[id]; + String reason = peer->get_close_reason(in, len, code); + _on_close_request(id, code, reason); + } + return 0; + } + case LWS_CALLBACK_CLOSED: { if (peer_data == NULL) return 0; @@ -118,10 +131,15 @@ int LWSServer::_handle_cb(struct lws *wsi, enum lws_callback_reasons reason, voi } case LWS_CALLBACK_SERVER_WRITEABLE: { - if (peer_data->force_close) - return -1; - int id = peer_data->peer_id; + if (peer_data->force_close) { + if (_peer_map.has(id)) { + Ref peer = _peer_map[id]; + peer->send_close_status(wsi); + } + return -1; + } + if (_peer_map.has(id)) static_cast >(_peer_map[id])->write_wsi(); break; @@ -164,10 +182,10 @@ int LWSServer::get_peer_port(int p_peer_id) const { return _peer_map[p_peer_id]->get_connected_port(); } -void LWSServer::disconnect_peer(int p_peer_id) { +void LWSServer::disconnect_peer(int p_peer_id, int p_code, String p_reason) { ERR_FAIL_COND(!has_peer(p_peer_id)); - get_peer(p_peer_id)->close(); + get_peer(p_peer_id)->close(p_code, p_reason); } LWSServer::LWSServer() { diff --git a/modules/websocket/lws_server.h b/modules/websocket/lws_server.h index 9e3fb9b775..346773ebc4 100644 --- a/modules/websocket/lws_server.h +++ b/modules/websocket/lws_server.h @@ -54,7 +54,7 @@ public: Ref get_peer(int p_id) const; IP_Address get_peer_address(int p_peer_id) const; int get_peer_port(int p_peer_id) const; - void disconnect_peer(int p_peer_id); + void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = ""); virtual void poll() { _lws_poll(); } LWSServer(); diff --git a/modules/websocket/websocket_client.cpp b/modules/websocket/websocket_client.cpp index 7701163085..8211a7c361 100644 --- a/modules/websocket/websocket_client.cpp +++ b/modules/websocket/websocket_client.cpp @@ -107,6 +107,11 @@ void WebSocketClient::_on_connect(String p_protocol) { } } +void WebSocketClient::_on_close_request(int p_code, String p_reason) { + + emit_signal("server_close_request", p_code, p_reason); +} + void WebSocketClient::_on_disconnect() { if (_is_multiplayer) { @@ -127,7 +132,7 @@ void WebSocketClient::_on_error() { void WebSocketClient::_bind_methods() { ClassDB::bind_method(D_METHOD("connect_to_url", "url", "protocols", "gd_mp_api"), &WebSocketClient::connect_to_url, DEFVAL(PoolVector()), DEFVAL(false)); - ClassDB::bind_method(D_METHOD("disconnect_from_host"), &WebSocketClient::disconnect_from_host); + ClassDB::bind_method(D_METHOD("disconnect_from_host", "code", "reason"), &WebSocketClient::disconnect_from_host, DEFVAL(1000), DEFVAL("")); ClassDB::bind_method(D_METHOD("set_verify_ssl_enabled", "enabled"), &WebSocketClient::set_verify_ssl_enabled); ClassDB::bind_method(D_METHOD("is_verify_ssl_enabled"), &WebSocketClient::is_verify_ssl_enabled); @@ -135,6 +140,7 @@ void WebSocketClient::_bind_methods() { ADD_SIGNAL(MethodInfo("data_received")); ADD_SIGNAL(MethodInfo("connection_established", PropertyInfo(Variant::STRING, "protocol"))); + ADD_SIGNAL(MethodInfo("server_close_request", PropertyInfo(Variant::INT, "code"), PropertyInfo(Variant::STRING, "reason"))); ADD_SIGNAL(MethodInfo("connection_closed")); ADD_SIGNAL(MethodInfo("connection_error")); } diff --git a/modules/websocket/websocket_client.h b/modules/websocket/websocket_client.h index 6165f37d40..086e90318f 100644 --- a/modules/websocket/websocket_client.h +++ b/modules/websocket/websocket_client.h @@ -53,7 +53,7 @@ public: virtual void poll() = 0; virtual Error connect_to_host(String p_host, String p_path, uint16_t p_port, bool p_ssl, PoolVector p_protocol = PoolVector()) = 0; - virtual void disconnect_from_host() = 0; + virtual void disconnect_from_host(int p_code = 1000, String p_reason = "") = 0; virtual IP_Address get_connected_host() const = 0; virtual uint16_t get_connected_port() const = 0; @@ -62,6 +62,7 @@ public: void _on_peer_packet(); void _on_connect(String p_protocol); + void _on_close_request(int p_code, String p_reason); void _on_disconnect(); void _on_error(); diff --git a/modules/websocket/websocket_peer.cpp b/modules/websocket/websocket_peer.cpp index 61f783e377..3ecb32ce19 100644 --- a/modules/websocket/websocket_peer.cpp +++ b/modules/websocket/websocket_peer.cpp @@ -42,7 +42,7 @@ void WebSocketPeer::_bind_methods() { ClassDB::bind_method(D_METHOD("set_write_mode", "mode"), &WebSocketPeer::set_write_mode); ClassDB::bind_method(D_METHOD("is_connected_to_host"), &WebSocketPeer::is_connected_to_host); ClassDB::bind_method(D_METHOD("was_string_packet"), &WebSocketPeer::was_string_packet); - ClassDB::bind_method(D_METHOD("close"), &WebSocketPeer::close); + ClassDB::bind_method(D_METHOD("close", "code", "reason"), &WebSocketPeer::close, DEFVAL(1000), DEFVAL("")); ClassDB::bind_method(D_METHOD("get_connected_host"), &WebSocketPeer::get_connected_host); ClassDB::bind_method(D_METHOD("get_connected_port"), &WebSocketPeer::get_connected_port); diff --git a/modules/websocket/websocket_peer.h b/modules/websocket/websocket_peer.h index ad451e9cc7..5918fda3c2 100644 --- a/modules/websocket/websocket_peer.h +++ b/modules/websocket/websocket_peer.h @@ -58,7 +58,7 @@ public: virtual WriteMode get_write_mode() const = 0; virtual void set_write_mode(WriteMode p_mode) = 0; - virtual void close() = 0; + virtual void close(int p_code = 1000, String p_reason = "") = 0; virtual bool is_connected_to_host() const = 0; virtual IP_Address get_connected_host() const = 0; diff --git a/modules/websocket/websocket_server.cpp b/modules/websocket/websocket_server.cpp index 53dd7b51b7..ac3a8ca87d 100644 --- a/modules/websocket/websocket_server.cpp +++ b/modules/websocket/websocket_server.cpp @@ -46,8 +46,9 @@ void WebSocketServer::_bind_methods() { ClassDB::bind_method(D_METHOD("has_peer", "id"), &WebSocketServer::has_peer); ClassDB::bind_method(D_METHOD("get_peer_address", "id"), &WebSocketServer::get_peer_address); ClassDB::bind_method(D_METHOD("get_peer_port", "id"), &WebSocketServer::get_peer_port); - ClassDB::bind_method(D_METHOD("disconnect_peer", "id"), &WebSocketServer::disconnect_peer); + ClassDB::bind_method(D_METHOD("disconnect_peer", "id", "code", "reason"), &WebSocketServer::disconnect_peer, DEFVAL(1000), DEFVAL("")); + ADD_SIGNAL(MethodInfo("client_close_request", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::INT, "code"), PropertyInfo(Variant::STRING, "reason"))); ADD_SIGNAL(MethodInfo("client_disconnected", PropertyInfo(Variant::INT, "id"))); ADD_SIGNAL(MethodInfo("client_connected", PropertyInfo(Variant::INT, "id"), PropertyInfo(Variant::STRING, "protocol"))); ADD_SIGNAL(MethodInfo("data_received", PropertyInfo(Variant::INT, "id"))); @@ -95,3 +96,8 @@ void WebSocketServer::_on_disconnect(int32_t p_peer_id) { emit_signal("client_disconnected", p_peer_id); } } + +void WebSocketServer::_on_close_request(int32_t p_peer_id, int p_code, String p_reason) { + + emit_signal("client_close_request", p_peer_id, p_code, p_reason); +} diff --git a/modules/websocket/websocket_server.h b/modules/websocket/websocket_server.h index 64935f8a58..0f044c17ad 100644 --- a/modules/websocket/websocket_server.h +++ b/modules/websocket/websocket_server.h @@ -54,11 +54,12 @@ public: virtual IP_Address get_peer_address(int p_peer_id) const = 0; virtual int get_peer_port(int p_peer_id) const = 0; - virtual void disconnect_peer(int p_peer_id) = 0; + virtual void disconnect_peer(int p_peer_id, int p_code = 1000, String p_reason = "") = 0; void _on_peer_packet(int32_t p_peer_id); void _on_connect(int32_t p_peer_id, String p_protocol); void _on_disconnect(int32_t p_peer_id); + void _on_close_request(int32_t p_peer_id, int p_code, String p_reason); WebSocketServer(); ~WebSocketServer();