From 4c131b2c78bae793509bea776107a8183274a709 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 15:47:47 +0000 Subject: [PATCH 01/82] Implement v2 API for send_join --- synapse/federation/federation_client.py | 46 +++++++++++++++++++++---- synapse/federation/transport/client.py | 13 ++++++- 2 files changed, 51 insertions(+), 8 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 545d719652..50ae40504d 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -664,13 +664,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) + content = self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) @@ -737,6 +731,44 @@ class FederationClient(FederationBase): return self._try_destination_list("send_join", destinations, send_request) + @defer.inlineCallbacks + def _do_send_join(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_join_v2( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() + + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if not err.errcode in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_join_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] + @defer.inlineCallbacks def send_invite(self, destination, room_id, event_id, pdu): room_version = yield self.store.get_room_version(room_id) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 920fa86853..ba68f7c0b4 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -267,7 +267,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_join(self, destination, room_id, event_id, content): + def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -276,6 +276,17 @@ class TransportLayerClient(object): return response + @defer.inlineCallbacks + @log_function + def send_join_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_join/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, path=path, data=content + ) + + return response + @defer.inlineCallbacks @log_function def send_leave(self, destination, room_id, event_id, content): From 92527d7b2186a06c204c3c4bff47207252c5dea2 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:20:53 +0000 Subject: [PATCH 02/82] Add missing yield --- synapse/federation/federation_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 50ae40504d..4a8e65c292 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -664,7 +664,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - content = self._do_send_join(destination, pdu) + content = yield self._do_send_join(destination, pdu) logger.debug("Got content: %s", content) From 1e202a90f15a8f518e4350075d40d0423b64318d Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:26:53 +0000 Subject: [PATCH 03/82] Implement v2 API for send_leave --- synapse/federation/federation_client.py | 41 ++++++++++++++++++++++--- synapse/federation/transport/client.py | 20 +++++++++++- 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 4a8e65c292..289017b2ed 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -879,17 +879,50 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_leave( + content = yield self._do_send_leave(destination, pdu) + + logger.debug("Got content: %s", content) + return None + + return self._try_destination_list("send_leave", destinations, send_request) + + @defer.inlineCallbacks + def _do_send_leave(self, destination, pdu): + time_now = self._clock.time_msec() + + try: + content = yield self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - logger.debug("Got content: %s", content) - return None + return content + except HttpResponseException as e: + if e.code in [400, 404]: + err = e.to_synapse_error() - return self._try_destination_list("send_leave", destinations, send_request) + # If we receive an error response that isn't a generic error, or an + # unrecognised endpoint error, we assume that the remote understands + # the v2 invite API and this is a legitimate error. + if not err.errcode in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + raise err + else: + raise e.to_synapse_error() + + logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") + + resp = yield self.transport_layer.send_leave_v1( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + # We expect the v1 API to respond with [200, content], so we only return the + # content. + return resp[1] def get_public_rooms( self, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index ba68f7c0b4..df2b5dc91b 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -289,7 +289,7 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def send_leave(self, destination, room_id, event_id, content): + def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) response = yield self.client.put_json( @@ -305,6 +305,24 @@ class TransportLayerClient(object): return response + @defer.inlineCallbacks + @log_function + def send_leave_v2(self, destination, room_id, event_id, content): + path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) + + response = yield self.client.put_json( + destination=destination, + path=path, + data=content, + # we want to do our best to send this through. The problem is + # that if it fails, we won't retry it later, so if the remote + # server was just having a momentary blip, the room will be out of + # sync. + ignore_backoff=True, + ) + + return response + @defer.inlineCallbacks @log_function def send_invite_v1(self, destination, room_id, event_id, content): From 74897de01fc271ee04ce0638654d991f5fb8e2fa Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:40:45 +0000 Subject: [PATCH 04/82] Add server-side support to the v2 API --- synapse/federation/federation_server.py | 17 ++++++-------- synapse/federation/transport/server.py | 30 +++++++++++++++++++++---- 2 files changed, 33 insertions(+), 14 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index d942d77a72..0ce83cd882 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -376,15 +376,12 @@ class FederationServer(FederationBase): res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - return ( - 200, - { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - }, - ) + return { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + } async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) @@ -411,7 +408,7 @@ class FederationServer(FederationBase): pdu = await self._check_sigs_and_hash(room_version, pdu) await self.handler.on_send_leave_request(origin, pdu) - return 200, {} + return {} async def on_event_auth(self, origin, room_id, event_id): with (await self._server_linearizer.queue((origin, room_id))): diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index d6c23f22bd..5263e292c1 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -506,7 +506,15 @@ class FederationMakeLeaveServlet(BaseFederationServlet): return 200, content -class FederationSendLeaveServlet(BaseFederationServlet): +class FederationV1SendLeaveServlet(BaseFederationServlet): + PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + + async def on_PUT(self, origin, content, query, room_id, event_id): + content = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, content) + + +class FederationV2SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): @@ -521,9 +529,21 @@ class FederationEventAuthServlet(BaseFederationServlet): return await self.handler.on_event_auth(origin, context, event_id) -class FederationSendJoinServlet(BaseFederationServlet): +class FederationV1SendJoinServlet(BaseFederationServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + async def on_PUT(self, origin, content, query, context, event_id): + # TODO(paul): assert that context/event_id parsed from path actually + # match those given in content + content = await self.handler.on_send_join_request(origin, content, context) + return 200, (200, content) + + +class FederationV2SendJoinServlet(BaseFederationServlet): + PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + async def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content @@ -1367,8 +1387,10 @@ FEDERATION_SERVLET_CLASSES = ( FederationMakeJoinServlet, FederationMakeLeaveServlet, FederationEventServlet, - FederationSendJoinServlet, - FederationSendLeaveServlet, + FederationV1SendJoinServlet, + FederationV2SendJoinServlet, + FederationV1SendLeaveServlet, + FederationV2SendLeaveServlet, FederationV1InviteServlet, FederationV2InviteServlet, FederationQueryAuthServlet, From 5e18dc7955180086870013b55bf5ae326fa6c695 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:46:09 +0000 Subject: [PATCH 05/82] Fix prefix for v2/send_leave --- synapse/federation/transport/server.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5263e292c1..551a162eb7 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -517,6 +517,8 @@ class FederationV1SendLeaveServlet(BaseFederationServlet): class FederationV2SendLeaveServlet(BaseFederationServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" + PREFIX = FEDERATION_V2_PREFIX + async def on_PUT(self, origin, content, query, room_id, event_id): content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content From edc4c7d4c55b1843c85d452ff92bda17dfff47bc Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:51:54 +0000 Subject: [PATCH 06/82] Lint --- synapse/federation/federation_server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 0ce83cd882..08a913e08a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -378,9 +378,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return { "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], + "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], } async def on_make_leave_request(self, origin, room_id, user_id): From 21056ad12a9dbfadad085802c7a0096d3b071681 Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:53:29 +0000 Subject: [PATCH 07/82] Changelog --- changelog.d/6349.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6349.feature diff --git a/changelog.d/6349.feature b/changelog.d/6349.feature new file mode 100644 index 0000000000..56c4fbf78e --- /dev/null +++ b/changelog.d/6349.feature @@ -0,0 +1 @@ +Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)). From 94cdd6fffed90cceaf0396a66174ddd5c990c8eb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Mon, 11 Nov 2019 16:56:55 +0000 Subject: [PATCH 08/82] Lint --- synapse/federation/federation_client.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 289017b2ed..23c08104b3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -751,7 +751,7 @@ class FederationClient(FederationBase): # If we receive an error response that isn't a generic error, or an # unrecognised endpoint error, we assume that the remote understands # the v2 invite API and this is a legitimate error. - if not err.errcode in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: raise err else: raise e.to_synapse_error() @@ -878,7 +878,6 @@ class FederationClient(FederationBase): @defer.inlineCallbacks def send_request(destination): - time_now = self._clock.time_msec() content = yield self._do_send_leave(destination, pdu) logger.debug("Got content: %s", content) @@ -906,7 +905,7 @@ class FederationClient(FederationBase): # If we receive an error response that isn't a generic error, or an # unrecognised endpoint error, we assume that the remote understands # the v2 invite API and this is a legitimate error. - if not err.errcode in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: + if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]: raise err else: raise e.to_synapse_error() From f166a8d1f52f1fda1a64d8cd50f17275d63c8843 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Dec 2019 13:42:49 +0000 Subject: [PATCH 09/82] Remove SnapshotCache in favour of ResponseCache --- synapse/handlers/initial_sync.py | 19 +++--- synapse/util/caches/snapshot_cache.py | 94 --------------------------- tests/util/test_snapshot_cache.py | 63 ------------------ 3 files changed, 8 insertions(+), 168 deletions(-) delete mode 100644 synapse/util/caches/snapshot_cache.py delete mode 100644 tests/util/test_snapshot_cache.py diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 81dce96f4b..73c110a92b 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig from synapse.types import StreamToken, UserID from synapse.util import unwrapFirstError from synapse.util.async_helpers import concurrently_execute -from synapse.util.caches.snapshot_cache import SnapshotCache +from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client from ._base import BaseHandler @@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler): self.state = hs.get_state_handler() self.clock = hs.get_clock() self.validator = EventValidator() - self.snapshot_cache = SnapshotCache() + self.snapshot_cache = ResponseCache(hs, "initial_sync_cache") self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state @@ -79,17 +79,14 @@ class InitialSyncHandler(BaseHandler): as_client_event, include_archived, ) - now_ms = self.clock.time_msec() - result = self.snapshot_cache.get(now_ms, key) - if result is not None: - return result - return self.snapshot_cache.set( - now_ms, + return self.snapshot_cache.wrap( key, - self._snapshot_all_rooms( - user_id, pagin_config, as_client_event, include_archived - ), + self._snapshot_all_rooms, + user_id, + pagin_config, + as_client_event, + include_archived, ) @defer.inlineCallbacks diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py deleted file mode 100644 index 8318db8d2c..0000000000 --- a/synapse/util/caches/snapshot_cache.py +++ /dev/null @@ -1,94 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.util.async_helpers import ObservableDeferred - - -class SnapshotCache(object): - """Cache for snapshots like the response of /initialSync. - The response of initialSync only has to be a recent snapshot of the - server state. It shouldn't matter to clients if it is a few minutes out - of date. - - This caches a deferred response. Until the deferred completes it will be - returned from the cache. This means that if the client retries the request - while the response is still being computed, that original response will be - used rather than trying to compute a new response. - - Once the deferred completes it will removed from the cache after 5 minutes. - We delay removing it from the cache because a client retrying its request - could race with us finishing computing the response. - - Rather than tracking precisely how long something has been in the cache we - keep two generations of completed responses. Every 5 minutes discard the - old generation, move the new generation to the old generation, and set the - new generation to be empty. This means that a result will be in the cache - somewhere between 5 and 10 minutes. - """ - - DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes. - - def __init__(self): - self.pending_result_cache = {} # Request that haven't finished yet. - self.prev_result_cache = {} # The older requests that have finished. - self.next_result_cache = {} # The newer requests that have finished. - self.time_last_rotated_ms = 0 - - def rotate(self, time_now_ms): - # Rotate once if the cache duration has passed since the last rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms += self.DURATION_MS - - # Rotate again if the cache duration has passed twice since the last - # rotation. - if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS: - self.prev_result_cache = self.next_result_cache - self.next_result_cache = {} - self.time_last_rotated_ms = time_now_ms - - def get(self, time_now_ms, key): - self.rotate(time_now_ms) - # This cache is intended to deduplicate requests, so we expect it to be - # missed most of the time. So we just lookup the key in all of the - # dictionaries rather than trying to short circuit the lookup if the - # key is found. - result = self.prev_result_cache.get(key) - result = self.next_result_cache.get(key, result) - result = self.pending_result_cache.get(key, result) - if result is not None: - return result.observe() - else: - return None - - def set(self, time_now_ms, key, deferred): - self.rotate(time_now_ms) - - result = ObservableDeferred(deferred) - - self.pending_result_cache[key] = result - - def shuffle_along(r): - # When the deferred completes we shuffle it along to the first - # generation of the result cache. So that it will eventually - # expire from the rotation of that cache. - self.next_result_cache[key] = result - self.pending_result_cache.pop(key, None) - return r - - result.addBoth(shuffle_along) - - return result.observe() diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py deleted file mode 100644 index 1a44f72425..0000000000 --- a/tests/util/test_snapshot_cache.py +++ /dev/null @@ -1,63 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet.defer import Deferred - -from synapse.util.caches.snapshot_cache import SnapshotCache - -from .. import unittest - - -class SnapshotCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = SnapshotCache() - self.cache.DURATION_MS = 1 - - def test_get_set(self): - # Check that getting a missing key returns None - self.assertEquals(self.cache.get(0, "key"), None) - - # Check that setting a key with a deferred returns - # a deferred that resolves when the initial deferred does - d = Deferred() - set_result = self.cache.set(0, "key", d) - self.assertIsNotNone(set_result) - self.assertFalse(set_result.called) - - # Check that getting the key before the deferred has resolved - # returns a deferred that resolves when the initial deferred does. - get_result_at_10 = self.cache.get(10, "key") - self.assertIsNotNone(get_result_at_10) - self.assertFalse(get_result_at_10.called) - - # Check that the returned deferreds resolve when the initial deferred - # does. - d.callback("v") - self.assertTrue(set_result.called) - self.assertTrue(get_result_at_10.called) - - # Check that getting the key after the deferred has resolved - # before the cache expires returns a resolved deferred. - get_result_at_11 = self.cache.get(11, "key") - self.assertIsNotNone(get_result_at_11) - if isinstance(get_result_at_11, Deferred): - # The cache may return the actual result rather than a deferred - self.assertTrue(get_result_at_11.called) - - # Check that getting the key after the deferred has resolved - # after the cache expires returns None - get_result_at_12 = self.cache.get(12, "key") - self.assertIsNone(get_result_at_12) From a1f8ea9051e7e7c34ed1677871585cc68e3d0311 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Dec 2019 13:46:45 +0000 Subject: [PATCH 10/82] Port synapse.handlers.initial_sync to async/await --- synapse/handlers/initial_sync.py | 96 +++++++++++++++----------------- 1 file changed, 44 insertions(+), 52 deletions(-) diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 73c110a92b..44ec3e66ae 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -89,8 +89,7 @@ class InitialSyncHandler(BaseHandler): include_archived, ) - @defer.inlineCallbacks - def _snapshot_all_rooms( + async def _snapshot_all_rooms( self, user_id=None, pagin_config=None, @@ -102,7 +101,7 @@ class InitialSyncHandler(BaseHandler): if include_archived: memberships.append(Membership.LEAVE) - room_list = yield self.store.get_rooms_for_user_where_membership_is( + room_list = await self.store.get_rooms_for_user_where_membership_is( user_id=user_id, membership_list=memberships ) @@ -110,33 +109,32 @@ class InitialSyncHandler(BaseHandler): rooms_ret = [] - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() presence_stream = self.hs.get_event_sources().sources["presence"] pagination_config = PaginationConfig(from_token=now_token) - presence, _ = yield presence_stream.get_pagination_rows( + presence, _ = await presence_stream.get_pagination_rows( user, pagination_config.get_source_config("presence"), None ) receipt_stream = self.hs.get_event_sources().sources["receipt"] - receipt, _ = yield receipt_stream.get_pagination_rows( + receipt, _ = await receipt_stream.get_pagination_rows( user, pagination_config.get_source_config("receipt"), None ) - tags_by_room = yield self.store.get_tags_for_user(user_id) + tags_by_room = await self.store.get_tags_for_user(user_id) - account_data, account_data_by_room = yield self.store.get_account_data_for_user( + account_data, account_data_by_room = await self.store.get_account_data_for_user( user_id ) - public_room_ids = yield self.store.get_public_room_ids() + public_room_ids = await self.store.get_public_room_ids() limit = pagin_config.limit if limit is None: limit = 10 - @defer.inlineCallbacks - def handle_room(event): + async def handle_room(event): d = { "room_id": event.room_id, "membership": event.membership, @@ -149,8 +147,8 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["inviter"] = event.sender - invite_event = yield self.store.get_event(event.event_id) - d["invite"] = yield self._event_serializer.serialize_event( + invite_event = await self.store.get_event(event.event_id) + d["invite"] = await self._event_serializer.serialize_event( invite_event, time_now, as_client_event ) @@ -174,7 +172,7 @@ class InitialSyncHandler(BaseHandler): lambda states: states[event.event_id] ) - (messages, token), current_state = yield make_deferred_yieldable( + (messages, token), current_state = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -188,7 +186,7 @@ class InitialSyncHandler(BaseHandler): ) ).addErrback(unwrapFirstError) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages ) @@ -198,7 +196,7 @@ class InitialSyncHandler(BaseHandler): d["messages"] = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( messages, time_now=time_now, as_client_event=as_client_event ) ), @@ -206,7 +204,7 @@ class InitialSyncHandler(BaseHandler): "end": end_token.to_string(), } - d["state"] = yield self._event_serializer.serialize_events( + d["state"] = await self._event_serializer.serialize_events( current_state.values(), time_now=time_now, as_client_event=as_client_event, @@ -229,7 +227,7 @@ class InitialSyncHandler(BaseHandler): except Exception: logger.exception("Failed to get snapshot") - yield concurrently_execute(handle_room, room_list, 10) + await concurrently_execute(handle_room, room_list, 10) account_data_events = [] for account_data_type, content in account_data.items(): @@ -253,8 +251,7 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def room_initial_sync(self, requester, room_id, pagin_config=None): + async def room_initial_sync(self, requester, room_id, pagin_config=None): """Capture the a snapshot of a room. If user is currently a member of the room this will be what is currently in the room. If the user left the room this will be what was in the room when they left. @@ -271,32 +268,32 @@ class InitialSyncHandler(BaseHandler): A JSON serialisable dict with the snapshot of the room. """ - blocked = yield self.store.is_room_blocked(room_id) + blocked = await self.store.is_room_blocked(room_id) if blocked: raise SynapseError(403, "This room has been blocked on this server") user_id = requester.user.to_string() - membership, member_event_id = yield self._check_in_room_or_world_readable( + membership, member_event_id = await self._check_in_room_or_world_readable( room_id, user_id ) is_peeking = member_event_id is None if membership == Membership.JOIN: - result = yield self._room_initial_sync_joined( + result = await self._room_initial_sync_joined( user_id, room_id, pagin_config, membership, is_peeking ) elif membership == Membership.LEAVE: - result = yield self._room_initial_sync_parted( + result = await self._room_initial_sync_parted( user_id, room_id, pagin_config, membership, member_event_id, is_peeking ) account_data_events = [] - tags = yield self.store.get_tags_for_room(user_id, room_id) + tags = await self.store.get_tags_for_room(user_id, room_id) if tags: account_data_events.append({"type": "m.tag", "content": {"tags": tags}}) - account_data = yield self.store.get_account_data_for_room(user_id, room_id) + account_data = await self.store.get_account_data_for_room(user_id, room_id) for account_data_type, content in account_data.items(): account_data_events.append({"type": account_data_type, "content": content}) @@ -304,11 +301,10 @@ class InitialSyncHandler(BaseHandler): return result - @defer.inlineCallbacks - def _room_initial_sync_parted( + async def _room_initial_sync_parted( self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking ): - room_state = yield self.state_store.get_state_for_events([member_event_id]) + room_state = await self.state_store.get_state_for_events([member_event_id]) room_state = room_state[member_event_id] @@ -316,13 +312,13 @@ class InitialSyncHandler(BaseHandler): if limit is None: limit = 10 - stream_token = yield self.store.get_stream_token_for_event(member_event_id) + stream_token = await self.store.get_stream_token_for_event(member_event_id) - messages, token = yield self.store.get_recent_events_for_room( + messages, token = await self.store.get_recent_events_for_room( room_id, limit=limit, end_token=stream_token ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -336,13 +332,13 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), }, "state": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( room_state.values(), time_now ) ), @@ -350,19 +346,18 @@ class InitialSyncHandler(BaseHandler): "receipts": [], } - @defer.inlineCallbacks - def _room_initial_sync_joined( + async def _room_initial_sync_joined( self, user_id, room_id, pagin_config, membership, is_peeking ): - current_state = yield self.state.get_current_state(room_id=room_id) + current_state = await self.state.get_current_state(room_id=room_id) # TODO: These concurrently time_now = self.clock.time_msec() - state = yield self._event_serializer.serialize_events( + state = await self._event_serializer.serialize_events( current_state.values(), time_now ) - now_token = yield self.hs.get_event_sources().get_current_token() + now_token = await self.hs.get_event_sources().get_current_token() limit = pagin_config.limit if pagin_config else None if limit is None: @@ -377,28 +372,26 @@ class InitialSyncHandler(BaseHandler): presence_handler = self.hs.get_presence_handler() - @defer.inlineCallbacks - def get_presence(): + async def get_presence(): # If presence is disabled, return an empty list if not self.hs.config.use_presence: return [] - states = yield presence_handler.get_states( + states = await presence_handler.get_states( [m.user_id for m in room_members], as_event=True ) return states - @defer.inlineCallbacks - def get_receipts(): - receipts = yield self.store.get_linearized_receipts_for_room( + async def get_receipts(): + receipts = await self.store.get_linearized_receipts_for_room( room_id, to_key=now_token.receipt_key ) if not receipts: receipts = [] return receipts - presence, receipts, (messages, token) = yield make_deferred_yieldable( + presence, receipts, (messages, token) = await make_deferred_yieldable( defer.gatherResults( [ run_in_background(get_presence), @@ -414,7 +407,7 @@ class InitialSyncHandler(BaseHandler): ).addErrback(unwrapFirstError) ) - messages = yield filter_events_for_client( + messages = await filter_events_for_client( self.storage, user_id, messages, is_peeking=is_peeking ) @@ -427,7 +420,7 @@ class InitialSyncHandler(BaseHandler): "room_id": room_id, "messages": { "chunk": ( - yield self._event_serializer.serialize_events(messages, time_now) + await self._event_serializer.serialize_events(messages, time_now) ), "start": start_token.to_string(), "end": end_token.to_string(), @@ -441,18 +434,17 @@ class InitialSyncHandler(BaseHandler): return ret - @defer.inlineCallbacks - def _check_in_room_or_world_readable(self, room_id, user_id): + async def _check_in_room_or_world_readable(self, room_id, user_id): try: # check_user_was_in_room will return the most recent membership # event for the user if: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. - member_event = yield self.auth.check_user_was_in_room(room_id, user_id) + member_event = await self.auth.check_user_was_in_room(room_id, user_id) return member_event.membership, member_event.event_id except AuthError: - visibility = yield self.state_handler.get_current_state( + visibility = await self.state_handler.get_current_state( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( From aeaeb72ee41076b0f08b07ea4686d1c38acf2d6b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 9 Dec 2019 13:48:14 +0000 Subject: [PATCH 11/82] Newsfile --- changelog.d/6496.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6496.misc diff --git a/changelog.d/6496.misc b/changelog.d/6496.misc new file mode 100644 index 0000000000..19c6e926b8 --- /dev/null +++ b/changelog.d/6496.misc @@ -0,0 +1 @@ +Port synapse.handlers.initial_sync to async/await. From 353396e3a7c00c2634e3a175e9aeac98c6a4598f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:12:56 +0000 Subject: [PATCH 12/82] Port handlers.account_data to async/await. --- synapse/handlers/account_data.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 2d7e6df6e4..fe44d62a1c 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - class AccountDataEventSource(object): def __init__(self, hs): @@ -23,15 +21,14 @@ class AccountDataEventSource(object): def get_current_key(self, direction="f"): return self.store.get_max_account_data_stream_id() - @defer.inlineCallbacks - def get_new_events(self, user, from_key, **kwargs): + async def get_new_events(self, user, from_key, **kwargs): user_id = user.to_string() last_stream_id = from_key - current_stream_id = yield self.store.get_max_account_data_stream_id() + current_stream_id = await self.store.get_max_account_data_stream_id() results = [] - tags = yield self.store.get_updated_tags(user_id, last_stream_id) + tags = await self.store.get_updated_tags(user_id, last_stream_id) for room_id, room_tags in tags.items(): results.append( @@ -41,7 +38,7 @@ class AccountDataEventSource(object): ( account_data, room_account_data, - ) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id) + ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id) for account_data_type, content in account_data.items(): results.append({"type": account_data_type, "content": content}) @@ -54,6 +51,5 @@ class AccountDataEventSource(object): return results, current_stream_id - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): + async def get_pagination_rows(self, user, config, key): return [], config.to_id From 9a2223d4c8c2a46f7ab072597bdba2614bc3e6ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:22:12 +0000 Subject: [PATCH 13/82] Fix make_deferred_yieldable to work with coroutines --- synapse/logging/context.py | 9 ++++++++- tests/util/test_logcontext.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 2c1fb9ddac..9f484ce59e 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -23,6 +23,7 @@ them. See doc/log_contexts.rst for details on how this works. """ +import inspect import logging import threading import types @@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs): def make_deferred_yieldable(deferred): - """Given a deferred, make it follow the Synapse logcontext rules: + """Given a deferred (or coroutine), make it follow the Synapse logcontext + rules: If the deferred has completed (or is not actually a Deferred), essentially does nothing (just returns another completed deferred with the @@ -624,6 +626,11 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ + if inspect.isawaitable(deferred): + # If we're given a coroutine we need to convert it to a deferred so that + # we can attach callbacks (and not immediately return). + deferred = defer.ensureDeferred(deferred) + if not isinstance(deferred, defer.Deferred): return deferred diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 8b8455c8b7..281b32c4b8 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.request, "foo-bar") + @defer.inlineCallbacks + def test_make_deferred_yieldable_with_await(self): + # an async function which retuns an incomplete coroutine, but doesn't + # follow the synapse rules. + + async def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + await d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on From f5bb1531b7307bc2d826789746e7c82fa4dbf36c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:23:52 +0000 Subject: [PATCH 14/82] Newsfile --- changelog.d/6505.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6505.misc diff --git a/changelog.d/6505.misc b/changelog.d/6505.misc new file mode 100644 index 0000000000..3a75b2d9dd --- /dev/null +++ b/changelog.d/6505.misc @@ -0,0 +1 @@ +Make `make_deferred_yieldable` to work with async/await. From b1e7012deea2d254ecbf92d1fed429c34a65db54 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:29:44 +0000 Subject: [PATCH 15/82] Newsfile --- changelog.d/6506.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6506.misc diff --git a/changelog.d/6506.misc b/changelog.d/6506.misc new file mode 100644 index 0000000000..99d7a70bcf --- /dev/null +++ b/changelog.d/6506.misc @@ -0,0 +1 @@ +Remove `SnapshotCache` in favour of `ResponseCache`. From ffeafade4879a1135ab6795d4ba0c11131f82f01 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 13:17:39 +0000 Subject: [PATCH 16/82] Update comment --- synapse/logging/context.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 9f484ce59e..6747f29e6a 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -627,8 +627,10 @@ def make_deferred_yieldable(deferred): (This is more-or-less the opposite operation to run_in_background.) """ if inspect.isawaitable(deferred): - # If we're given a coroutine we need to convert it to a deferred so that - # we can attach callbacks (and not immediately return). + # If we're given a coroutine we convert it to a deferred so that we + # run it and find out if it immediately finishes, it it does then we + # don't need to fiddle with log contexts at all and can return + # immediately. deferred = defer.ensureDeferred(deferred) if not isinstance(deferred, defer.Deferred): From 663238aeb4ff1df2e857a12355fdf7cf76dcd6d9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 17:42:14 +0000 Subject: [PATCH 17/82] Phone home stats DB reporting should not assume a single DB. --- synapse/app/homeserver.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..032010600a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -519,8 +519,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): # Database version # - stats["database_engine"] = hs.database_engine.module.__name__ - stats["database_server_version"] = hs.database_engine.server_version + # This only reports info about the *main* database. + stats["database_engine"] = hs.get_datastore().db.engine.module.__name__ + stats["database_server_version"] = hs.get_datastore().db.engine.server_version + logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) try: yield hs.get_proxied_http_client().put_json( From accd343f9104653c0e89d79101ae6ae767f7aa35 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 13:22:42 +0000 Subject: [PATCH 18/82] Newsfile --- changelog.d/6510.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6510.misc diff --git a/changelog.d/6510.misc b/changelog.d/6510.misc new file mode 100644 index 0000000000..214f06539b --- /dev/null +++ b/changelog.d/6510.misc @@ -0,0 +1 @@ +Change phone home stats to not assume there is a single database and report information about the database used by the main data store. From 28b758fa0f396aafa92cdc96ff716469a3b083d4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 17:42:26 +0000 Subject: [PATCH 19/82] Silence mypy errors for files outside those specified --- mypy.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 1d77c0ecc8..a66434b76b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,7 +1,7 @@ [mypy] namespace_packages = True plugins = mypy_zope:plugin -follow_imports = normal +follow_imports = silent check_untyped_defs = True show_error_codes = True show_traceback = True From 4643bb2a37dcd6302fd5b9e185ea5d0f4eb0df8c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 13:36:00 +0000 Subject: [PATCH 20/82] Newsfile --- changelog.d/6512.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6512.misc diff --git a/changelog.d/6512.misc b/changelog.d/6512.misc new file mode 100644 index 0000000000..37a8099eec --- /dev/null +++ b/changelog.d/6512.misc @@ -0,0 +1 @@ +Silence mypy errors for files outside those specified. From bc5cb8bfe8c5b0b551de9a97912cd00caeaf6b48 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 6 Dec 2019 17:42:18 +0000 Subject: [PATCH 21/82] Remove database config parsing from apps. --- synapse/app/admin_cmd.py | 5 ----- synapse/app/appservice.py | 5 ----- synapse/app/client_reader.py | 5 ----- synapse/app/event_creator.py | 5 ----- synapse/app/federation_reader.py | 5 ----- synapse/app/federation_sender.py | 5 ----- synapse/app/frontend_proxy.py | 5 ----- synapse/app/homeserver.py | 7 +------ synapse/app/media_repository.py | 5 ----- synapse/app/pusher.py | 5 ----- synapse/app/synchrotron.py | 5 ----- synapse/app/user_dir.py | 5 ----- synapse/server.py | 10 +++++++++- tests/test_types.py | 19 ++++++++----------- tests/utils.py | 2 -- 15 files changed, 18 insertions(+), 75 deletions(-) diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 04751a6a5e..51a909419f 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.room import RoomStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.logcontext import LoggingContext from synapse.util.versionstring import get_version_string @@ -229,14 +228,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = AdminCmdServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py index 02b900f382..e82e0f11e3 100644 --- a/synapse/app/appservice.py +++ b/synapse/app/appservice.py @@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -143,8 +142,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.notify_appservices: sys.stderr.write( "\nThe appservices must be disabled in the main synapse process" @@ -159,10 +156,8 @@ def start(config_options): ps = AppserviceServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py index dadb487d5f..3edfe19567 100644 --- a/synapse/app/client_reader.py +++ b/synapse/app/client_reader.py @@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -181,14 +180,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = ClientReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py index d110599a35..d0ddbe38fc 100644 --- a/synapse/app/event_creator.py +++ b/synapse/app/event_creator.py @@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import ( ) from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -180,14 +179,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = EventCreatorServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py index 418c086254..311523e0ed 100644 --- a/synapse/app/federation_reader.py +++ b/synapse/app/federation_reader.py @@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.key.v2 import KeyApiV2Resource from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -162,14 +161,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FederationReaderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py index f24920a7d6..83c436229c 100644 --- a/synapse/app/federation_sender.py +++ b/synapse/app/federation_sender.py @@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler from synapse.replication.tcp.streams._base import ReceiptsStream from synapse.server import HomeServer from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer from synapse.util.httpresourcetree import create_resource_tree @@ -174,8 +173,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.send_federation: sys.stderr.write( "\nThe send_federation must be disabled in the main synapse process" @@ -190,10 +187,8 @@ def start(config_options): ss = FederationSenderServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py index e647459d0e..30e435eead 100644 --- a/synapse/app/frontend_proxy.py +++ b/synapse/app/frontend_proxy.py @@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.client.v2_alpha._base import client_patterns from synapse.server import HomeServer -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -234,14 +233,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = FrontendProxyServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index df65d0a989..79341784c5 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import IncorrectDatabaseSetup, create_engine +from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree @@ -328,15 +328,10 @@ def setup(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection - hs = SynapseHomeServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) synapse.config.logger.setup_logging(hs, config, use_worker_options=False) diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 2c6dd3ef02..4c80f257e2 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -157,14 +156,10 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = MediaRepositoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py index 01a5ffc363..e157cbf64b 100644 --- a/synapse/app/pusher.py +++ b/synapse/app/pusher.py @@ -36,7 +36,6 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.versionstring import get_version_string @@ -198,14 +197,10 @@ def start(config_options): # Force the pushers to start since they will be disabled in the main config config.start_pushers = True - database_engine = create_engine(config.database_config) - ps = PusherServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ps, config, use_worker_options=True) diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py index 288ee64b42..dd2132e608 100644 --- a/synapse/app/synchrotron.py +++ b/synapse/app/synchrotron.py @@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet from synapse.rest.client.v2_alpha import sync from synapse.server import HomeServer from synapse.storage.data_stores.main.presence import UserPresenceState -from synapse.storage.engines import create_engine from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.stringutils import random_string @@ -437,14 +436,10 @@ def start(config_options): synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - ss = SynchrotronServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, application_service_handler=SynchrotronApplicationService(), ) diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py index c01fb34a9b..1257098f92 100644 --- a/synapse/app/user_dir.py +++ b/synapse/app/user_dir.py @@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory from synapse.server import HomeServer from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.storage.database import Database -from synapse.storage.engines import create_engine from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole @@ -200,8 +199,6 @@ def start(config_options): events.USE_FROZEN_DICTS = config.use_frozen_dicts - database_engine = create_engine(config.database_config) - if config.update_user_directory: sys.stderr.write( "\nThe update_user_directory must be disabled in the main synapse process" @@ -216,10 +213,8 @@ def start(config_options): ss = UserDirectoryServer( config.server_name, - db_config=config.database_config, config=config, version_string="Synapse/" + get_version_string(synapse), - database_engine=database_engine, ) setup_logging(ss, config, use_worker_options=True) diff --git a/synapse/server.py b/synapse/server.py index 2db3dab221..a8b5459ff3 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage +from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -209,7 +210,7 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, reactor=None, **kwargs): + def __init__(self, hostname, config, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. @@ -219,6 +220,7 @@ class HomeServer(object): self._reactor = reactor self.hostname = hostname + self.config = config self._building = {} self._listening_services = [] self.start_time = None @@ -229,6 +231,12 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() + self.database_engine = create_engine(config.database_config) + config.database_config.setdefault("args", {})[ + "cp_openfun" + ] = self.database_engine.on_new_connection + self.db_config = config.database_config + self.datastores = None # Other kwargs are explicit dependencies diff --git a/tests/test_types.py b/tests/test_types.py index 9ab5f829b0..8d97c751ea 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart from tests import unittest -from tests.utils import TestHomeServer - -mock_homeserver = TestHomeServer(hostname="my.domain") -class UserIDTestCase(unittest.TestCase): +class UserIDTestCase(unittest.HomeserverTestCase): def test_parse(self): - user = UserID.from_string("@1234abcd:my.domain") + user = UserID.from_string("@1234abcd:test") self.assertEquals("1234abcd", user.localpart) - self.assertEquals("my.domain", user.domain) - self.assertEquals(True, mock_homeserver.is_mine(user)) + self.assertEquals("test", user.domain) + self.assertEquals(True, self.hs.is_mine(user)) def test_pase_empty(self): with self.assertRaises(SynapseError): @@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase): self.assertTrue(userA != userB) -class RoomAliasTestCase(unittest.TestCase): +class RoomAliasTestCase(unittest.HomeserverTestCase): def test_parse(self): - room = RoomAlias.from_string("#channel:my.domain") + room = RoomAlias.from_string("#channel:test") self.assertEquals("channel", room.localpart) - self.assertEquals("my.domain", room.domain) - self.assertEquals(True, mock_homeserver.is_mine(room)) + self.assertEquals("test", room.domain) + self.assertEquals(True, self.hs.is_mine(room)) def test_build(self): room = RoomAlias("channel", "my.domain") diff --git a/tests/utils.py b/tests/utils.py index c57da59191..585f305b9a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -260,9 +260,7 @@ def setup_test_homeserver( hs = homeserverToUse( name, config=config, - db_config=config.database_config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, From 67c991b78fd38df687f28bb22f66f34fe2c5bb9f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 13:32:34 +0000 Subject: [PATCH 22/82] Fix upgrade db script --- scripts-dev/update_database | 33 ++++++--------------------------- 1 file changed, 6 insertions(+), 27 deletions(-) diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 1776d202c5..23017c21f8 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -35,21 +34,11 @@ logger = logging.getLogger("update_database") class MockHomeserver(HomeServer): DATASTORE_CLASS = DataStore - def __init__(self, config, database_engine, db_conn, **kwargs): + def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, - reactor=reactor, - config=config, - database_engine=database_engine, - **kwargs + config.server_name, reactor=reactor, config=config, **kwargs ) - self.database_engine = database_engine - self.db_conn = db_conn - - def get_db_conn(self): - return self.db_conn - if __name__ == "__main__": parser = argparse.ArgumentParser( @@ -85,24 +74,14 @@ if __name__ == "__main__": config = HomeServerConfig() config.parse_config_dict(hs_config, "", "") - # Create the database engine and a connection to it. - database_engine = create_engine(config.database_config) - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in config.database_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) + # Instantiate and initialise the homeserver object. + hs = MockHomeserver(config) + db_conn = hs.get_db_conn() # Update the database to the latest schema. - prepare_database(db_conn, database_engine, config=config) + prepare_database(db_conn, hs.database_engine, config=config) db_conn.commit() - # Instantiate and initialise the homeserver object. - hs = MockHomeserver( - config, database_engine, db_conn, db_config=config.database_config, - ) # setup instantiates the store within the homeserver object. hs.setup() store = hs.get_datastore() From d630c8234928c008e12c79ebef10cee570779fa5 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 13:34:17 +0000 Subject: [PATCH 23/82] Newsfile --- changelog.d/6511.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6511.misc diff --git a/changelog.d/6511.misc b/changelog.d/6511.misc new file mode 100644 index 0000000000..19ce435e68 --- /dev/null +++ b/changelog.d/6511.misc @@ -0,0 +1 @@ +Move database config from apps into HomeServer object. From 257ef2c727abbb289c460b930e8bac18ab80e053 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:13:15 +0000 Subject: [PATCH 24/82] Port handlers.account_validity to async/await. --- synapse/handlers/account_data.py | 2 +- synapse/handlers/account_validity.py | 86 ++++++++++----------- tests/rest/client/v2_alpha/test_register.py | 3 +- 3 files changed, 42 insertions(+), 49 deletions(-) diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index fe44d62a1c..20ec1ca01b 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -25,7 +25,7 @@ class AccountDataEventSource(object): user_id = user.to_string() last_stream_id = from_key - current_stream_id = await self.store.get_max_account_data_stream_id() + current_stream_id = self.store.get_max_account_data_stream_id() results = [] tags = await self.store.get_updated_tags(user_id, last_stream_id) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index d04e0fe576..829f52eca1 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -18,8 +18,7 @@ import email.utils import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText - -from twisted.internet import defer +from typing import List from synapse.api.errors import StoreError from synapse.logging.context import make_deferred_yieldable @@ -78,42 +77,39 @@ class AccountValidityHandler(object): # run as a background process to make sure that the database transactions # have a logcontext to report to return run_as_background_process( - "send_renewals", self.send_renewal_emails + "send_renewals", self._send_renewal_emails ) self.clock.looping_call(send_emails, 30 * 60 * 1000) - @defer.inlineCallbacks - def send_renewal_emails(self): + async def _send_renewal_emails(self): """Gets the list of users whose account is expiring in the amount of time configured in the ``renew_at`` parameter from the ``account_validity`` configuration, and sends renewal emails to all of these users as long as they have an email 3PID attached to their account. """ - expiring_users = yield self.store.get_users_expiring_soon() + expiring_users = await self.store.get_users_expiring_soon() if expiring_users: for user in expiring_users: - yield self._send_renewal_email( + await self._send_renewal_email( user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"] ) - @defer.inlineCallbacks - def send_renewal_email_to_user(self, user_id): - expiration_ts = yield self.store.get_expiration_ts_for_user(user_id) - yield self._send_renewal_email(user_id, expiration_ts) + async def send_renewal_email_to_user(self, user_id: str): + expiration_ts = await self.store.get_expiration_ts_for_user(user_id) + await self._send_renewal_email(user_id, expiration_ts) - @defer.inlineCallbacks - def _send_renewal_email(self, user_id, expiration_ts): + async def _send_renewal_email(self, user_id: str, expiration_ts: int): """Sends out a renewal email to every email address attached to the given user with a unique link allowing them to renew their account. Args: - user_id (str): ID of the user to send email(s) to. - expiration_ts (int): Timestamp in milliseconds for the expiration date of + user_id: ID of the user to send email(s) to. + expiration_ts: Timestamp in milliseconds for the expiration date of this user's account (used in the email templates). """ - addresses = yield self._get_email_addresses_for_user(user_id) + addresses = await self._get_email_addresses_for_user(user_id) # Stop right here if the user doesn't have at least one email address. # In this case, they will have to ask their server admin to renew their @@ -125,7 +121,7 @@ class AccountValidityHandler(object): return try: - user_display_name = yield self.store.get_profile_displayname( + user_display_name = await self.store.get_profile_displayname( UserID.from_string(user_id).localpart ) if user_display_name is None: @@ -133,7 +129,7 @@ class AccountValidityHandler(object): except StoreError: user_display_name = user_id - renewal_token = yield self._get_renewal_token(user_id) + renewal_token = await self._get_renewal_token(user_id) url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % ( self.hs.config.public_baseurl, renewal_token, @@ -165,7 +161,7 @@ class AccountValidityHandler(object): logger.info("Sending renewal email to %s", address) - yield make_deferred_yieldable( + await make_deferred_yieldable( self.sendmail( self.hs.config.email_smtp_host, self._raw_from, @@ -180,19 +176,18 @@ class AccountValidityHandler(object): ) ) - yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) + await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True) - @defer.inlineCallbacks - def _get_email_addresses_for_user(self, user_id): + async def _get_email_addresses_for_user(self, user_id: str) -> List[str]: """Retrieve the list of email addresses attached to a user's account. Args: - user_id (str): ID of the user to lookup email addresses for. + user_id: ID of the user to lookup email addresses for. Returns: - defer.Deferred[list[str]]: Email addresses for this account. + Email addresses for this account. """ - threepids = yield self.store.user_get_threepids(user_id) + threepids = await self.store.user_get_threepids(user_id) addresses = [] for threepid in threepids: @@ -201,16 +196,15 @@ class AccountValidityHandler(object): return addresses - @defer.inlineCallbacks - def _get_renewal_token(self, user_id): + async def _get_renewal_token(self, user_id: str) -> str: """Generates a 32-byte long random string that will be inserted into the user's renewal email's unique link, then saves it into the database. Args: - user_id (str): ID of the user to generate a string for. + user_id: ID of the user to generate a string for. Returns: - defer.Deferred[str]: The generated string. + The generated string. Raises: StoreError(500): Couldn't generate a unique string after 5 attempts. @@ -219,52 +213,52 @@ class AccountValidityHandler(object): while attempts < 5: try: renewal_token = stringutils.random_string(32) - yield self.store.set_renewal_token_for_user(user_id, renewal_token) + await self.store.set_renewal_token_for_user(user_id, renewal_token) return renewal_token except StoreError: attempts += 1 raise StoreError(500, "Couldn't generate a unique string as refresh string.") - @defer.inlineCallbacks - def renew_account(self, renewal_token): + async def renew_account(self, renewal_token: str) -> bool: """Renews the account attached to a given renewal token by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. + renewal_token: Token sent with the renewal request. Returns: - bool: Whether the provided token is valid. + Whether the provided token is valid. """ try: - user_id = yield self.store.get_user_from_renewal_token(renewal_token) + user_id = await self.store.get_user_from_renewal_token(renewal_token) except StoreError: - defer.returnValue(False) + return False logger.debug("Renewing an account for user %s", user_id) - yield self.renew_account_for_user(user_id) + await self.renew_account_for_user(user_id) - defer.returnValue(True) + return True - @defer.inlineCallbacks - def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False): + async def renew_account_for_user( + self, user_id: str, expiration_ts: int = None, email_sent: bool = False + ) -> int: """Renews the account attached to a given user by pushing back the expiration date by the current validity period in the server's configuration. Args: - renewal_token (str): Token sent with the renewal request. - expiration_ts (int): New expiration date. Defaults to now + validity period. - email_sent (bool): Whether an email has been sent for this validity period. + renewal_token: Token sent with the renewal request. + expiration_ts: New expiration date. Defaults to now + validity period. + email_sen: Whether an email has been sent for this validity period. Defaults to False. Returns: - defer.Deferred[int]: New expiration date for this account, as a timestamp - in milliseconds since epoch. + New expiration date for this account, as a timestamp in + milliseconds since epoch. """ if expiration_ts is None: expiration_ts = self.clock.time_msec() + self._account_validity.period - yield self.store.set_account_validity_for_user( + await self.store.set_account_validity_for_user( user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent ) diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index c0d0d2b44e..d0c997e385 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): # Email config. self.email_attempts = [] - def sendmail(*args, **kwargs): + async def sendmail(*args, **kwargs): self.email_attempts.append((args, kwargs)) - return config["email"] = { "enable_notifs": True, From 3f97b4c16bbe8a32fc465bd59421f3ba879d2124 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 10 Dec 2019 11:17:13 +0000 Subject: [PATCH 25/82] Newsfile --- changelog.d/6504.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6504.misc diff --git a/changelog.d/6504.misc b/changelog.d/6504.misc new file mode 100644 index 0000000000..7c873459af --- /dev/null +++ b/changelog.d/6504.misc @@ -0,0 +1 @@ +Port handlers.account_data and handlers.account_validity to async/await. From 424fd58237d4a40c6f94772be136298d326fcd69 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Dec 2019 15:09:45 +0000 Subject: [PATCH 26/82] Remove redundant code from event authorisation implementation. (#6502) --- changelog.d/6502.removal | 1 + synapse/event_auth.py | 8 ++------ 2 files changed, 3 insertions(+), 6 deletions(-) create mode 100644 changelog.d/6502.removal diff --git a/changelog.d/6502.removal b/changelog.d/6502.removal new file mode 100644 index 0000000000..0b72261d58 --- /dev/null +++ b/changelog.d/6502.removal @@ -0,0 +1 @@ +Remove redundant code from event authorisation implementation. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index ec3243b27b..c940b84470 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru Returns: if the auth checks pass. """ + assert isinstance(auth_events, dict) + if do_size_check: _check_size_limits(event) @@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not event.signatures.get(event_id_domain): raise AuthError(403, "Event not signed by sending server") - if auth_events is None: - # Oh, we don't know what the state of the room was, so we - # are trusting that this is allowed (at least for now) - logger.warning("Trusting event: %s", event.event_id) - return - if event.type == EventTypes.Create: sender_domain = get_domain_from_id(event.sender) room_id_domain = get_domain_from_id(event.room_id) From c3dda2874d78790525f47e502aaed22b64961873 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Dec 2019 16:22:00 +0000 Subject: [PATCH 27/82] Refactor get_events_from_store_or_dest to return a dict (#6501) There was a bunch of unnecessary conversion back and forth between dict and list going on here. We can simplify a bunch of the code. --- changelog.d/6501.misc | 1 + synapse/federation/federation_client.py | 44 +++++++++---------------- 2 files changed, 16 insertions(+), 29 deletions(-) create mode 100644 changelog.d/6501.misc diff --git a/changelog.d/6501.misc b/changelog.d/6501.misc new file mode 100644 index 0000000000..255f45a9c3 --- /dev/null +++ b/changelog.d/6501.misc @@ -0,0 +1 @@ +Refactor get_events_from_store_or_dest to return a dict. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 709449c9e3..73e1dda6a3 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,8 +18,6 @@ import copy import itertools import logging -from six.moves import range - from prometheus_client import Counter from twisted.internet import defer @@ -41,7 +39,7 @@ from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.utils import log_function -from synapse.util import unwrapFirstError +from synapse.util import batch_iter, unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -331,10 +329,12 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest( - destination, room_id, set(state_event_ids + auth_event_ids) + desired_events = set(state_event_ids + auth_event_ids) + event_map = yield self.get_events_from_store_or_dest( + destination, room_id, desired_events ) + failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( "Failed to fetch missing state/auth events for %s: %s", @@ -342,8 +342,6 @@ class FederationClient(FederationBase): failed_to_fetch, ) - event_map = {ev.event_id: ev for ev in fetched_events} - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] @@ -358,23 +356,18 @@ class FederationClient(FederationBase): Args: destination (str) room_id (str) - event_ids (list) + event_ids (Iterable[str]) Returns: - Deferred: A deferred resolving to a 2-tuple where the first is a list of - events and the second is a list of event ids that we failed to fetch. + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event """ - seen_events = yield self.store.get_events(event_ids, allow_rejected=True) - signed_events = list(seen_events.values()) + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) - failed_to_fetch = set() - - missing_events = set(event_ids) - for k in seen_events: - missing_events.discard(k) + missing_events = set(event_ids) - fetched_events.keys() if not missing_events: - return signed_events, failed_to_fetch + return fetched_events logger.debug( "Fetching unknown state/auth events %s for room %s", @@ -384,11 +377,8 @@ class FederationClient(FederationBase): room_version = yield self.store.get_room_version(room_id) - batch_size = 20 - missing_events = list(missing_events) - for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i : i + batch_size]) - + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): deferreds = [ run_in_background( self.get_pdu, @@ -404,13 +394,9 @@ class FederationClient(FederationBase): ) for success, result in res: if success and result: - signed_events.append(result) - batch.discard(result.event_id) + fetched_events[result.event_id] = result - # We removed all events we successfully fetched from `batch` - failed_to_fetch.update(batch) - - return signed_events, failed_to_fetch + return fetched_events @defer.inlineCallbacks @log_function From 40eda849338b6e47a5804b4cf7000e9d2417c4d8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Dec 2019 16:22:29 +0000 Subject: [PATCH 28/82] Fix race which caused deleted devices to reappear (#6514) Stop the `update_client_ips` background job from recreating deleted devices. --- changelog.d/6514.bugfix | 1 + .../storage/data_stores/main/client_ips.py | 8 +-- tests/storage/test_client_ips.py | 49 +++++++++++-------- 3 files changed, 35 insertions(+), 23 deletions(-) create mode 100644 changelog.d/6514.bugfix diff --git a/changelog.d/6514.bugfix b/changelog.d/6514.bugfix new file mode 100644 index 0000000000..6dc1985c24 --- /dev/null +++ b/changelog.d/6514.bugfix @@ -0,0 +1 @@ +Fix race which occasionally caused deleted devices to reappear. diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 320c5b0f07..add3037b69 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): # Technically an access token might not be associated with # a device so we need to check. if device_id: - self.db.simple_upsert_txn( + # this is always an update rather than an upsert: the row should + # already exist, and if it doesn't, that may be because it has been + # deleted, and we don't want to re-create it. + self.db.simple_update_txn( txn, table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - values={ + updatevalues={ "user_agent": user_agent, "last_seen": last_seen, "ip": ip, }, - lock=False, ) except Exception as e: # Failed to upsert, log and continue diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index fc279340d4..bf674dd184 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(12345678) user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.reactor.advance(10) result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, @@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) - # Force persisting to disk self.reactor.advance(200) @@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.get_success( self.store.db.simple_update( table="devices", - keyvalues={"user_id": user_id, "device_id": "device_id"}, + keyvalues={"user_id": user_id, "device_id": device_id}, updatevalues={"last_seen": None, "ip": None, "user_agent": None}, desc="test_devices_last_seen_bg_update", ) @@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get nulls when querying result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, @@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # We should now get the correct result again result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, @@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): self.store.db.updates.do_next_background_update(100), by=0.1 ) - # Insert a user IP user_id = "@user:id" + device_id = "MY_DEVICE" + + # Insert a user IP + self.get_success(self.store.store_device(user_id, device_id, "display name",)) self.get_success( self.store.insert_client_ip( - user_id, "access_token", "ip", "user_agent", "device_id" + user_id, "access_token", "ip", "user_agent", device_id ) ) @@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): "access_token": "access_token", "ip": "ip", "user_agent": "user_agent", - "device_id": "device_id", + "device_id": device_id, "last_seen": 0, } ], @@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): # But we should still get the correct values for the device result = self.get_success( - self.store.get_last_client_ip_by_device(user_id, "device_id") + self.store.get_last_client_ip_by_device(user_id, device_id) ) - r = result[(user_id, "device_id")] + r = result[(user_id, device_id)] self.assertDictContainsSubset( { "user_id": user_id, - "device_id": "device_id", + "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, From 4947de5a147d4f6a4e60aecac1284714fb64df8a Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Tue, 10 Dec 2019 17:30:16 +0000 Subject: [PATCH 29/82] Allow SAML username provider plugins (#6411) --- changelog.d/6411.feature | 1 + docs/saml_mapping_providers.md | 77 ++++++++++++ docs/sample_config.yaml | 59 ++++++--- synapse/config/saml2_config.py | 180 +++++++++++++++++++--------- synapse/handlers/saml_handler.py | 200 ++++++++++++++++++++++++++----- 5 files changed, 414 insertions(+), 103 deletions(-) create mode 100644 changelog.d/6411.feature create mode 100644 docs/saml_mapping_providers.md diff --git a/changelog.d/6411.feature b/changelog.d/6411.feature new file mode 100644 index 0000000000..ebea4a208d --- /dev/null +++ b/changelog.d/6411.feature @@ -0,0 +1 @@ +Allow custom SAML username mapping functinality through an external provider plugin. \ No newline at end of file diff --git a/docs/saml_mapping_providers.md b/docs/saml_mapping_providers.md new file mode 100644 index 0000000000..92f2380488 --- /dev/null +++ b/docs/saml_mapping_providers.md @@ -0,0 +1,77 @@ +# SAML Mapping Providers + +A SAML mapping provider is a Python class (loaded via a Python module) that +works out how to map attributes of a SAML response object to Matrix-specific +user attributes. Details such as user ID localpart, displayname, and even avatar +URLs are all things that can be mapped from talking to a SSO service. + +As an example, a SSO service may return the email address +"john.smith@example.com" for a user, whereas Synapse will need to figure out how +to turn that into a displayname when creating a Matrix user for this individual. +It may choose `John Smith`, or `Smith, John [Example.com]` or any number of +variations. As each Synapse configuration may want something different, this is +where SAML mapping providers come into play. + +## Enabling Providers + +External mapping providers are provided to Synapse in the form of an external +Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere, +then tell Synapse where to look for the handler class by editing the +`saml2_config.user_mapping_provider.module` config option. + +`saml2_config.user_mapping_provider.config` allows you to provide custom +configuration options to the module. Check with the module's documentation for +what options it provides (if any). The options listed by default are for the +user mapping provider built in to Synapse. If using a custom module, you should +comment these options out and use those specified by the module instead. + +## Building a Custom Mapping Provider + +A custom mapping provider must specify the following methods: + +* `__init__(self, parsed_config)` + - Arguments: + - `parsed_config` - A configuration object that is the return value of the + `parse_config` method. You should set any configuration options needed by + the module here. +* `saml_response_to_user_attributes(self, saml_response, failures)` + - Arguments: + - `saml_response` - A `saml2.response.AuthnResponse` object to extract user + information from. + - `failures` - An `int` that represents the amount of times the returned + mxid localpart mapping has failed. This should be used + to create a deduplicated mxid localpart which should be + returned instead. For example, if this method returns + `john.doe` as the value of `mxid_localpart` in the returned + dict, and that is already taken on the homeserver, this + method will be called again with the same parameters but + with failures=1. The method should then return a different + `mxid_localpart` value, such as `john.doe1`. + - This method must return a dictionary, which will then be used by Synapse + to build a new user. The following keys are allowed: + * `mxid_localpart` - Required. The mxid localpart of the new user. + * `displayname` - The displayname of the new user. If not provided, will default to + the value of `mxid_localpart`. +* `parse_config(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A `dict` representing the parsed content of the + `saml2_config.user_mapping_provider.config` homeserver config option. + Runs on homeserver startup. Providers should extract any option values + they need here. + - Whatever is returned will be passed back to the user mapping provider module's + `__init__` method during construction. +* `get_saml_attributes(config)` + - This method should have the `@staticmethod` decoration. + - Arguments: + - `config` - A object resulting from a call to `parse_config`. + - Returns a tuple of two sets. The first set equates to the saml auth + response attributes that are required for the module to function, whereas + the second set consists of those attributes which can be used if available, + but are not necessary. + +## Synapse's Default Provider + +Synapse has a built-in SAML mapping provider if a custom provider isn't +specified in the config. It is located at +[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py). diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 10664ae8f7..4d44e631d1 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1250,33 +1250,58 @@ saml2_config: # #config_path: "CONFDIR/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c5ea2d43a1..b91414aa35 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -14,17 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re +import logging from synapse.python_dependencies import DependencyException, check_requirements -from synapse.types import ( - map_username_to_mxid_localpart, - mxid_localpart_allowed_characters, -) -from synapse.util.module_loader import load_python_module +from synapse.util.module_loader import load_module, load_python_module from ._base import Config, ConfigError +logger = logging.getLogger(__name__) + +DEFAULT_USER_MAPPING_PROVIDER = ( + "synapse.handlers.saml_handler.DefaultSamlMappingProvider" +) + def _dict_merge(merge_dict, into_dict): """Do a deep merge of two dicts @@ -75,15 +77,69 @@ class SAML2Config(Config): self.saml2_enabled = True - self.saml2_mxid_source_attribute = saml2_config.get( - "mxid_source_attribute", "uid" - ) - self.saml2_grandfathered_mxid_source_attribute = saml2_config.get( "grandfathered_mxid_source_attribute", "uid" ) - saml2_config_dict = self._default_saml_config_dict() + # user_mapping_provider may be None if the key is present but has no value + ump_dict = saml2_config.get("user_mapping_provider") or {} + + # Use the default user mapping provider if not set + ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + + # Ensure a config is present + ump_dict["config"] = ump_dict.get("config") or {} + + if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER: + # Load deprecated options for use by the default module + old_mxid_source_attribute = saml2_config.get("mxid_source_attribute") + if old_mxid_source_attribute: + logger.warning( + "The config option saml2_config.mxid_source_attribute is deprecated. " + "Please use saml2_config.user_mapping_provider.config" + ".mxid_source_attribute instead." + ) + ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute + + old_mxid_mapping = saml2_config.get("mxid_mapping") + if old_mxid_mapping: + logger.warning( + "The config option saml2_config.mxid_mapping is deprecated. Please " + "use saml2_config.user_mapping_provider.config.mxid_mapping instead." + ) + ump_dict["config"]["mxid_mapping"] = old_mxid_mapping + + # Retrieve an instance of the module's class + # Pass the config dictionary to the module for processing + ( + self.saml2_user_mapping_provider_class, + self.saml2_user_mapping_provider_config, + ) = load_module(ump_dict) + + # Ensure loaded user mapping module has defined all necessary methods + # Note parse_config() is already checked during the call to load_module + required_methods = [ + "get_saml_attributes", + "saml_response_to_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(self.saml2_user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by saml2_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + # Get the desired saml auth response attributes from the module + saml2_config_dict = self._default_saml_config_dict( + *self.saml2_user_mapping_provider_class.get_saml_attributes( + self.saml2_user_mapping_provider_config + ) + ) _dict_merge( merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict ) @@ -103,22 +159,27 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "5m") ) - mapping = saml2_config.get("mxid_mapping", "hexencode") - try: - self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping] - except KeyError: - raise ConfigError("%s is not a known mxid_mapping" % (mapping,)) + def _default_saml_config_dict( + self, required_attributes: set, optional_attributes: set + ): + """Generate a configuration dictionary with required and optional attributes that + will be needed to process new user registration - def _default_saml_config_dict(self): + Args: + required_attributes: SAML auth response attributes that are + necessary to function + optional_attributes: SAML auth response attributes that can be used to add + additional information to Synapse user accounts, but are not required + + Returns: + dict: A SAML configuration dictionary + """ import saml2 public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") - required_attributes = {"uid", self.saml2_mxid_source_attribute} - - optional_attributes = {"displayName"} if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes @@ -207,33 +268,58 @@ class SAML2Config(Config): # #config_path: "%(config_dir_path)s/sp_conf.py" - # the lifetime of a SAML session. This defines how long a user has to + # The lifetime of a SAML session. This defines how long a user has to # complete the authentication process, if allow_unsolicited is unset. # The default is 5 minutes. # #saml_session_lifetime: 5m - # The SAML attribute (after mapping via the attribute maps) to use to derive - # the Matrix ID from. 'uid' by default. + # An external module can be provided here as a custom solution to + # mapping attributes returned from a saml provider onto a matrix user. # - #mxid_source_attribute: displayName + user_mapping_provider: + # The custom module's class. Uncomment to use a custom module. + # + #module: mapping_provider.SamlMappingProvider - # The mapping system to use for mapping the saml attribute onto a matrix ID. - # Options include: - # * 'hexencode' (which maps unpermitted characters to '=xx') - # * 'dotreplace' (which replaces unpermitted characters with '.'). - # The default is 'hexencode'. - # - #mxid_mapping: dotreplace + # Custom configuration values for the module. Below options are + # intended for the built-in provider, they should be changed if + # using a custom module. This section will be passed as a Python + # dictionary to the module's `parse_config` method. + # + config: + # The SAML attribute (after mapping via the attribute maps) to use + # to derive the Matrix ID from. 'uid' by default. + # + # Note: This used to be configured by the + # saml2_config.mxid_source_attribute option. If that is still + # defined, its value will be used instead. + # + #mxid_source_attribute: displayName - # In previous versions of synapse, the mapping from SAML attribute to MXID was - # always calculated dynamically rather than stored in a table. For backwards- - # compatibility, we will look for user_ids matching such a pattern before - # creating a new account. + # The mapping system to use for mapping the saml attribute onto a + # matrix ID. + # + # Options include: + # * 'hexencode' (which maps unpermitted characters to '=xx') + # * 'dotreplace' (which replaces unpermitted characters with + # '.'). + # The default is 'hexencode'. + # + # Note: This used to be configured by the + # saml2_config.mxid_mapping option. If that is still defined, its + # value will be used instead. + # + #mxid_mapping: dotreplace + + # In previous versions of synapse, the mapping from SAML attribute to + # MXID was always calculated dynamically rather than stored in a + # table. For backwards- compatibility, we will look for user_ids + # matching such a pattern before creating a new account. # # This setting controls the SAML attribute which will be used for this - # backwards-compatibility lookup. Typically it should be 'uid', but if the - # attribute maps are changed, it may be necessary to change it. + # backwards-compatibility lookup. Typically it should be 'uid', but if + # the attribute maps are changed, it may be necessary to change it. # # The default is 'uid'. # @@ -241,23 +327,3 @@ class SAML2Config(Config): """ % { "config_dir_path": config_dir_path } - - -DOT_REPLACE_PATTERN = re.compile( - ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) -) - - -def dot_replace_for_mxid(username: str) -> str: - username = username.lower() - username = DOT_REPLACE_PATTERN.sub(".", username) - - # regular mxids aren't allowed to start with an underscore either - username = re.sub("^_", "", username) - return username - - -MXID_MAPPER_MAP = { - "hexencode": map_username_to_mxid_localpart, - "dotreplace": dot_replace_for_mxid, -} diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index cc9e6b9bd0..0082f85c26 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re +from typing import Tuple import attr import saml2 +import saml2.response from saml2.client import Saml2Client from synapse.api.errors import SynapseError +from synapse.config import ConfigError from synapse.http.servlet import parse_string from synapse.rest.client.v1.login import SSOAuthHandler -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import ( + UserID, + map_username_to_mxid_localpart, + mxid_localpart_allowed_characters, +) from synapse.util.async_helpers import Linearizer logger = logging.getLogger(__name__) +@attr.s +class Saml2SessionData: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib() + + class SamlHandler: def __init__(self, hs): self._saml_client = Saml2Client(hs.config.saml2_sp_config) @@ -37,11 +53,14 @@ class SamlHandler: self._datastore = hs.get_datastore() self._hostname = hs.hostname self._saml2_session_lifetime = hs.config.saml2_session_lifetime - self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute self._grandfathered_mxid_source_attribute = ( hs.config.saml2_grandfathered_mxid_source_attribute ) - self._mxid_mapper = hs.config.saml2_mxid_mapper + + # plugin to do custom mapping from saml response to mxid + self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( + hs.config.saml2_user_mapping_provider_config + ) # identifier for the external_ids table self._auth_provider_id = "saml" @@ -118,22 +137,10 @@ class SamlHandler: remote_user_id = saml2_auth.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "uid not in SAML2 response") - - try: - mxid_source = saml2_auth.ava[self._mxid_source_attribute][0] - except KeyError: - logger.warning( - "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute - ) - raise SynapseError( - 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) - ) + raise SynapseError(400, "'uid' not in SAML2 response") self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None) - displayName = saml2_auth.ava.get("displayName", [None])[0] - with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user logger.info( @@ -173,22 +180,46 @@ class SamlHandler: ) return registered_user_id - # figure out a new mxid for this user - base_mxid_localpart = self._mxid_mapper(mxid_source) + # Map saml response to user attributes using the configured mapping provider + for i in range(1000): + attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes( + saml2_auth, i + ) - suffix = 0 - while True: - localpart = base_mxid_localpart + (str(suffix) if suffix else "") + logger.debug( + "Retrieved SAML attributes from user mapping provider: %s " + "(attempt %d)", + attribute_dict, + i, + ) + + localpart = attribute_dict.get("mxid_localpart") + if not localpart: + logger.error( + "SAML mapping provider plugin did not return a " + "mxid_localpart object" + ) + raise SynapseError(500, "Error parsing SAML2 response") + + displayname = attribute_dict.get("displayname") + + # Check if this mxid already exists if not await self._datastore.get_users_by_id_case_insensitive( UserID(localpart, self._hostname).to_string() ): + # This mxid is free break - suffix += 1 - logger.info("Allocating mxid for new user with localpart %s", localpart) + else: + # Unable to generate a username in 1000 iterations + # Break and return error to the user + raise SynapseError( + 500, "Unable to generate a Matrix ID from the SAML response" + ) registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=displayName + localpart=localpart, default_display_name=displayname ) + await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) @@ -205,9 +236,120 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] -@attr.s -class Saml2SessionData: - """Data we track about SAML2 sessions""" +DOT_REPLACE_PATTERN = re.compile( + ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),)) +) - # time the session was created, in milliseconds - creation_time = attr.ib() + +def dot_replace_for_mxid(username: str) -> str: + username = username.lower() + username = DOT_REPLACE_PATTERN.sub(".", username) + + # regular mxids aren't allowed to start with an underscore either + username = re.sub("^_", "", username) + return username + + +MXID_MAPPER_MAP = { + "hexencode": map_username_to_mxid_localpart, + "dotreplace": dot_replace_for_mxid, +} + + +@attr.s +class SamlConfig(object): + mxid_source_attribute = attr.ib() + mxid_mapper = attr.ib() + + +class DefaultSamlMappingProvider(object): + __version__ = "0.0.1" + + def __init__(self, parsed_config: SamlConfig): + """The default SAML user mapping provider + + Args: + parsed_config: Module configuration + """ + self._mxid_source_attribute = parsed_config.mxid_source_attribute + self._mxid_mapper = parsed_config.mxid_mapper + + def saml_response_to_user_attributes( + self, saml_response: saml2.response.AuthnResponse, failures: int = 0, + ) -> dict: + """Maps some text from a SAML response to attributes of a new user + + Args: + saml_response: A SAML auth response object + + failures: How many times a call to this function with this + saml_response has resulted in a failure + + Returns: + dict: A dict containing new user attributes. Possible keys: + * mxid_localpart (str): Required. The localpart of the user's mxid + * displayname (str): The displayname of the user + """ + try: + mxid_source = saml_response.ava[self._mxid_source_attribute][0] + except KeyError: + logger.warning( + "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute, + ) + raise SynapseError( + 400, "%s not in SAML2 response" % (self._mxid_source_attribute,) + ) + + # Use the configured mapper for this mxid_source + base_mxid_localpart = self._mxid_mapper(mxid_source) + + # Append suffix integer if last call to this function failed to produce + # a usable mxid + localpart = base_mxid_localpart + (str(failures) if failures else "") + + # Retrieve the display name from the saml response + # If displayname is None, the mxid_localpart will be used instead + displayname = saml_response.ava.get("displayName", [None])[0] + + return { + "mxid_localpart": localpart, + "displayname": displayname, + } + + @staticmethod + def parse_config(config: dict) -> SamlConfig: + """Parse the dict provided by the homeserver's config + Args: + config: A dictionary containing configuration options for this provider + Returns: + SamlConfig: A custom config object for this module + """ + # Parse config options and use defaults where necessary + mxid_source_attribute = config.get("mxid_source_attribute", "uid") + mapping_type = config.get("mxid_mapping", "hexencode") + + # Retrieve the associating mapping function + try: + mxid_mapper = MXID_MAPPER_MAP[mapping_type] + except KeyError: + raise ConfigError( + "saml2_config.user_mapping_provider.config: '%s' is not a valid " + "mxid_mapping value" % (mapping_type,) + ) + + return SamlConfig(mxid_source_attribute, mxid_mapper) + + @staticmethod + def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]: + """Returns the required attributes of a SAML + + Args: + config: A SamlConfig object containing configuration params for this provider + + Returns: + tuple[set,set]: The first set equates to the saml auth response + attributes that are required for the module to function, whereas the + second set consists of those attributes which can be used if + available, but are not necessary + """ + return {"uid", config.mxid_source_attribute}, {"displayName"} From f8bc2ae8830615698ae683cafe4fdddb9a05a1f9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 10 Dec 2019 17:42:46 +0000 Subject: [PATCH 30/82] Move get_state methods into FederationHandler (#6503) This is a non-functional refactor as a precursor to some other work. --- changelog.d/6503.misc | 1 + synapse/federation/federation_client.py | 91 +++------------------ synapse/handlers/federation.py | 101 ++++++++++++++++++++++-- 3 files changed, 107 insertions(+), 86 deletions(-) create mode 100644 changelog.d/6503.misc diff --git a/changelog.d/6503.misc b/changelog.d/6503.misc new file mode 100644 index 0000000000..e4e9a5a3d4 --- /dev/null +++ b/changelog.d/6503.misc @@ -0,0 +1 @@ +Move get_state methods into FederationHandler. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 73e1dda6a3..d396e6564f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -37,9 +37,9 @@ from synapse.api.room_versions import ( ) from synapse.events import builder, room_version_to_event_format from synapse.federation.federation_base import FederationBase, event_from_pdu_json -from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.logging.context import make_deferred_yieldable from synapse.logging.utils import log_function -from synapse.util import batch_iter, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -308,19 +308,12 @@ class FederationClient(FederationBase): return signed_pdu @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the room state at a given event from a remote homeserver. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + def get_room_state_ids(self, destination: str, room_id: str, event_id: str): + """Calls the /state_ids endpoint to fetch the state at a particular point + in the room, and the auth events for the given event Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids) """ result = yield self.transport_layer.get_room_state_ids( destination, room_id, event_id=event_id @@ -329,74 +322,12 @@ class FederationClient(FederationBase): state_event_ids = result["pdu_ids"] auth_event_ids = result.get("auth_chain_ids", []) - desired_events = set(state_event_ids + auth_event_ids) - event_map = yield self.get_events_from_store_or_dest( - destination, room_id, desired_events - ) + if not isinstance(state_event_ids, list) or not isinstance( + auth_event_ids, list + ): + raise Exception("invalid response from /state_ids") - failed_to_fetch = desired_events - event_map.keys() - if failed_to_fetch: - logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, - failed_to_fetch, - ) - - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] - - auth_chain.sort(key=lambda e: e.depth) - - return pdus, auth_chain - - @defer.inlineCallbacks - def get_events_from_store_or_dest(self, destination, room_id, event_ids): - """Fetch events from a remote destination, checking if we already have them. - - Args: - destination (str) - room_id (str) - event_ids (Iterable[str]) - - Returns: - Deferred[dict[str, EventBase]]: A deferred resolving to a map - from event_id to event - """ - fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) - - missing_events = set(event_ids) - fetched_events.keys() - - if not missing_events: - return fetched_events - - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, - ) - - room_version = yield self.store.get_room_version(room_id) - - # XXX 20 requests at once? really? - for batch in batch_iter(missing_events, 20): - deferreds = [ - run_in_background( - self.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = yield make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - for success, result in res: - if success and result: - fetched_events[result.event_id] = result - - return fetched_events + return state_event_ids, auth_event_ids @defer.inlineCallbacks @log_function diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bc26921768..c0dcf9abf8 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -64,7 +64,7 @@ from synapse.replication.http.federation import ( from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.types import UserID, get_domain_from_id -from synapse.util import unwrapFirstError +from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination @@ -379,11 +379,9 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self.federation_client.get_state_for_room( - origin, room_id, p - ) + ) = yield self._get_state_for_room(origin, room_id, p) - # we want the state *after* p; get_state_for_room returns the + # we want the state *after* p; _get_state_for_room returns the # state *before* p. remote_event = yield self.federation_client.get_pdu( [origin], p, room_version, outlier=True @@ -583,6 +581,97 @@ class FederationHandler(BaseHandler): else: raise + @defer.inlineCallbacks + @log_function + def _get_state_for_room(self, destination, room_id, event_id): + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. + """ + ( + state_event_ids, + auth_event_ids, + ) = yield self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + desired_events = set(state_event_ids + auth_event_ids) + event_map = yield self._get_events_from_store_or_dest( + destination, room_id, desired_events + ) + + failed_to_fetch = desired_events - event_map.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state/auth events for %s: %s", + room_id, + failed_to_fetch, + ) + + pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + + auth_chain.sort(key=lambda e: e.depth) + + return pdus, auth_chain + + @defer.inlineCallbacks + def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + """Fetch events from a remote destination, checking if we already have them. + + Args: + destination (str) + room_id (str) + event_ids (Iterable[str]) + + Returns: + Deferred[dict[str, EventBase]]: A deferred resolving to a map + from event_id to event + """ + fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + + missing_events = set(event_ids) - fetched_events.keys() + + if not missing_events: + return fetched_events + + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + event_ids, + ) + + room_version = yield self.store.get_room_version(room_id) + + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] + + res = yield make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) + ) + for success, result in res: + if success and result: + fetched_events[result.event_id] = result + + return fetched_events + @defer.inlineCallbacks def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it @@ -723,7 +812,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self.federation_client.get_state_for_room( + state, auth = yield self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) From ea0f0ad4144e3ce0cf10f3ec461ecd8f654955a2 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 11 Dec 2019 13:07:25 +0000 Subject: [PATCH 31/82] Prevent message search in upgraded rooms we're not in (#6385) --- changelog.d/6385.bugfix | 1 + synapse/handlers/federation.py | 4 +-- synapse/handlers/search.py | 34 +++++++++++++++++------ synapse/storage/data_stores/main/state.py | 18 ++++++++---- 4 files changed, 41 insertions(+), 16 deletions(-) create mode 100644 changelog.d/6385.bugfix diff --git a/changelog.d/6385.bugfix b/changelog.d/6385.bugfix new file mode 100644 index 0000000000..7a2bc02170 --- /dev/null +++ b/changelog.d/6385.bugfix @@ -0,0 +1 @@ +Prevent error on trying to search a upgraded room when the server is not in the predecessor room. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c0dcf9abf8..13865c470c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1299,7 +1299,7 @@ class FederationHandler(BaseHandler): # Check whether this room is the result of an upgrade of a room we already know # about. If so, migrate over user information predecessor = yield self.store.get_room_predecessor(room_id) - if not predecessor: + if not predecessor or not isinstance(predecessor.get("room_id"), str): return old_room_id = predecessor["room_id"] logger.debug( @@ -1542,7 +1542,7 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content): origin, event, event_format_version = yield self._make_and_verify_event( - target_hosts, room_id, user_id, "leave", content=content, + target_hosts, room_id, user_id, "leave", content=content ) # Mark as outlier as we don't have any state for this event; we're not # even in the room. diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 56ed262a1f..ef750d1497 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64 from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.api.errors import SynapseError +from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -37,6 +37,7 @@ class SearchHandler(BaseHandler): self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state + self.auth = hs.get_auth() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -53,23 +54,38 @@ class SearchHandler(BaseHandler): room_id (str): id of the room to search through. Returns: - Deferred[iterable[unicode]]: predecessor room ids + Deferred[iterable[str]]: predecessor room ids """ historical_room_ids = [] - while True: - predecessor = yield self.store.get_room_predecessor(room_id) + # The initial room must have been known for us to get this far + predecessor = yield self.store.get_room_predecessor(room_id) - # If no predecessor, assume we've hit a dead end + while True: if not predecessor: + # We have reached the end of the chain of predecessors break - # Add predecessor's room ID - historical_room_ids.append(predecessor["room_id"]) + if not isinstance(predecessor.get("room_id"), str): + # This predecessor object is malformed. Exit here + break - # Scan through the old room for further predecessors - room_id = predecessor["room_id"] + predecessor_room_id = predecessor["room_id"] + + # Don't add it to the list until we have checked that we are in the room + try: + next_predecessor_room = yield self.store.get_room_predecessor( + predecessor_room_id + ) + except NotFoundError: + # The predecessor is not a known room, so we are done here + break + + historical_room_ids.append(predecessor_room_id) + + # And repeat + predecessor = next_predecessor_room return historical_room_ids diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 9ef7b48c74..dcc6b43cdf 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -278,7 +278,7 @@ class StateGroupWorkerStore( @defer.inlineCallbacks def get_room_predecessor(self, room_id): - """Get the predecessor room of an upgraded room if one exists. + """Get the predecessor of an upgraded room if it exists. Otherwise return None. Args: @@ -291,14 +291,22 @@ class StateGroupWorkerStore( * room_id (str): The room ID of the predecessor room * event_id (str): The ID of the tombstone event in the predecessor room + None if a predecessor key is not found, or is not a dictionary. + Raises: - NotFoundError if the room is unknown + NotFoundError if the given room is unknown """ # Retrieve the room's create event create_event = yield self.get_create_event_for_room(room_id) - # Return predecessor if present - return create_event.content.get("predecessor", None) + # Retrieve the predecessor key of the create event + predecessor = create_event.content.get("predecessor", None) + + # Ensure the key is a dictionary + if not isinstance(predecessor, dict): + return None + + return predecessor @defer.inlineCallbacks def get_create_event_for_room(self, room_id): @@ -318,7 +326,7 @@ class StateGroupWorkerStore( # If we can't find the create event, assume we've hit a dead end if not create_id: - raise NotFoundError("Unknown room %s" % (room_id)) + raise NotFoundError("Unknown room %s" % (room_id,)) # Retrieve the room's create event and return create_event = yield self.get_event(create_id) From 6676ee9c4a74e15afdd752e05ca38d82da94c2c1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 11 Dec 2019 13:16:01 +0000 Subject: [PATCH 32/82] Add dev script to generate full SQL schema files (#6394) --- changelog.d/6394.feature | 1 + scripts-dev/make_full_schema.sh | 184 ++++++++++++++++++ .../main/schema/full_schemas/README.md | 13 ++ .../main/schema/full_schemas/README.txt | 19 -- 4 files changed, 198 insertions(+), 19 deletions(-) create mode 100644 changelog.d/6394.feature create mode 100755 scripts-dev/make_full_schema.sh create mode 100644 synapse/storage/data_stores/main/schema/full_schemas/README.md delete mode 100644 synapse/storage/data_stores/main/schema/full_schemas/README.txt diff --git a/changelog.d/6394.feature b/changelog.d/6394.feature new file mode 100644 index 0000000000..1a0e8845ad --- /dev/null +++ b/changelog.d/6394.feature @@ -0,0 +1 @@ +Add a develop script to generate full SQL schemas. \ No newline at end of file diff --git a/scripts-dev/make_full_schema.sh b/scripts-dev/make_full_schema.sh new file mode 100755 index 0000000000..60e8970a35 --- /dev/null +++ b/scripts-dev/make_full_schema.sh @@ -0,0 +1,184 @@ +#!/bin/bash +# +# This script generates SQL files for creating a brand new Synapse DB with the latest +# schema, on both SQLite3 and Postgres. +# +# It does so by having Synapse generate an up-to-date SQLite DB, then running +# synapse_port_db to convert it to Postgres. It then dumps the contents of both. + +POSTGRES_HOST="localhost" +POSTGRES_DB_NAME="synapse_full_schema.$$" + +SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite" +POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres" + +REQUIRED_DEPS=("matrix-synapse" "psycopg2") + +usage() { + echo + echo "Usage: $0 -p -o [-c] [-n] [-h]" + echo + echo "-p " + echo " Username to connect to local postgres instance. The password will be requested" + echo " during script execution." + echo "-c" + echo " CI mode. Enables coverage tracking and prints every command that the script runs." + echo "-o " + echo " Directory to output full schema files to." + echo "-h" + echo " Display this help text." +} + +while getopts "p:co:h" opt; do + case $opt in + p) + POSTGRES_USERNAME=$OPTARG + ;; + c) + # Print all commands that are being executed + set -x + + # Modify required dependencies for coverage + REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess") + + COVERAGE=1 + ;; + o) + command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1) + OUTPUT_DIR="$(realpath "$OPTARG")" + ;; + h) + usage + exit + ;; + \?) + echo "ERROR: Invalid option: -$OPTARG" >&2 + usage + exit + ;; + esac +done + +# Check that required dependencies are installed +unsatisfied_requirements=() +for dep in "${REQUIRED_DEPS[@]}"; do + pip show "$dep" --quiet || unsatisfied_requirements+=("$dep") +done +if [ ${#unsatisfied_requirements} -ne 0 ]; then + echo "Please install the following python packages: ${unsatisfied_requirements[*]}" + exit 1 +fi + +if [ -z "$POSTGRES_USERNAME" ]; then + echo "No postgres username supplied" + usage + exit 1 +fi + +if [ -z "$OUTPUT_DIR" ]; then + echo "No output directory supplied" + usage + exit 1 +fi + +# Create the output directory if it doesn't exist +mkdir -p "$OUTPUT_DIR" + +read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD +echo "" + +# Exit immediately if a command fails +set -e + +# cd to root of the synapse directory +cd "$(dirname "$0")/.." + +# Create temporary SQLite and Postgres homeserver db configs and key file +TMPDIR=$(mktemp -d) +KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path +SQLITE_CONFIG=$TMPDIR/sqlite.conf +SQLITE_DB=$TMPDIR/homeserver.db +POSTGRES_CONFIG=$TMPDIR/postgres.conf + +# Ensure these files are delete on script exit +trap 'rm -rf $TMPDIR' EXIT + +cat > "$SQLITE_CONFIG" < "$POSTGRES_CONFIG" < "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE" + +echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..." +pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE" + +echo "Cleaning up temporary Postgres database..." +dropdb $POSTGRES_DB_NAME + +echo "Done! Files dumped to: $OUTPUT_DIR" diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md new file mode 100644 index 0000000000..bbd3f18604 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md @@ -0,0 +1,13 @@ +# Building full schema dumps + +These schemas need to be made from a database that has had all background updates run. + +To do so, use `scripts-dev/make_full_schema.sh`. This will produce +`full.sql.postgres ` and `full.sql.sqlite` files. + +Ensure postgres is installed and your user has the ability to run bash commands +such as `createdb`. + +``` +./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/ +``` diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt deleted file mode 100644 index d3f6401344..0000000000 --- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt +++ /dev/null @@ -1,19 +0,0 @@ -Building full schema dumps -========================== - -These schemas need to be made from a database that has had all background updates run. - -Postgres --------- - -$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres - -SQLite ------- - -$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite - -After ------ - -Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas". \ No newline at end of file From fc316a4894912f49f5d0321e533aabca5624b0ba Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 11 Dec 2019 13:39:47 +0000 Subject: [PATCH 33/82] Prevent redacted events from appearing in message search (#6377) --- changelog.d/6377.bugfix | 1 + synapse/handlers/federation.py | 7 +- synapse/handlers/message.py | 5 +- synapse/state/__init__.py | 3 +- .../storage/data_stores/main/events_worker.py | 97 ++++++++++++------- synapse/storage/data_stores/main/search.py | 8 +- 6 files changed, 78 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6377.bugfix diff --git a/changelog.d/6377.bugfix b/changelog.d/6377.bugfix new file mode 100644 index 0000000000..ccda96962f --- /dev/null +++ b/changelog.d/6377.bugfix @@ -0,0 +1 @@ +Prevent redacted events from being returned during message search. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 13865c470c..8f3c9d7702 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -63,6 +63,7 @@ from synapse.replication.http.federation import ( ) from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.state import StateResolutionStore, resolve_events_with_store +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer @@ -423,7 +424,7 @@ class FederationHandler(BaseHandler): evs = yield self.store.get_events( list(state_map.values()), get_prev_content=False, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, ) event_map.update(evs) @@ -1000,7 +1001,9 @@ class FederationHandler(BaseHandler): forward_events = yield self.store.get_successor_events(list(extremities)) extremities_events = yield self.store.get_events( - forward_events, check_redacted=False, get_prev_content=False + forward_events, + redact_behaviour=EventRedactBehaviour.AS_IS, + get_prev_content=False, ) # We set `check_history_visibility_only` as we might otherwise get false diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 54fa216d83..bf9add7fe2 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator from synapse.logging.context import run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester from synapse.util.async_helpers import Linearizer @@ -875,7 +876,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, @@ -952,7 +953,7 @@ class EventCreationHandler(object): if event.type == EventTypes.Redaction: original_event = yield self.store.get_event( event.redacts, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=False, allow_none=True, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 139beef8ed..3e6d62eef1 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,6 +32,7 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.logging.utils import log_function from synapse.state import v1, v2 +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.util.async_helpers import Linearizer from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache @@ -645,7 +646,7 @@ class StateResolutionStore(object): return self.store.get_events( event_ids, - check_redacted=False, + redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, allow_rejected=allow_rejected, ) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 9ee117ce0f..2c9142814c 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -19,8 +19,10 @@ import itertools import logging import threading from collections import namedtuple +from typing import List, Optional from canonicaljson import json +from constantly import NamedConstant, Names from twisted.internet import defer @@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) +class EventRedactBehaviour(Names): + """ + What to do when retrieving a redacted event from the database. + """ + + AS_IS = NamedConstant() + REDACT = NamedConstant() + BLOCK = NamedConstant() + + class EventsWorkerStore(SQLBaseStore): def __init__(self, database: Database, db_conn, hs): super(EventsWorkerStore, self).__init__(database, db_conn, hs) @@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_event( self, - event_id, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, - allow_none=False, - check_room_id=None, + event_id: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, + allow_none: bool = False, + check_room_id: Optional[str] = None, ): """Get an event from the database by event_id. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_id: The event_id of the event to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if + allow_rejected: If True return rejected events. + allow_none: If True, return None if no event found, if False throw a NotFoundError - check_room_id (str|None): if not None, check the room of the found event. + check_room_id: if not None, check the room of the found event. If there is a mismatch, behave as per allow_none. Returns: @@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore): events = yield self.get_events_as_list( [event_id], - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible + values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True return rejected events. Returns: Deferred : Dict from event_id to event. """ events = yield self.get_events_as_list( event_ids, - check_redacted=check_redacted, + redact_behaviour=redact_behaviour, get_prev_content=get_prev_content, allow_rejected=allow_rejected, ) @@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_events_as_list( self, - event_ids, - check_redacted=True, - get_prev_content=False, - allow_rejected=False, + event_ids: List[str], + redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT, + get_prev_content: bool = False, + allow_rejected: bool = False, ): """Get events from the database and return in a list in the same order as given by `event_ids` arg. Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, + event_ids: The event_ids of the events to fetch + redact_behaviour: Determine what to do with a redacted event. Possible values: + * AS_IS - Return the full event body with no redacted content + * REDACT - Return the event but with a redacted body + * DISALLOW - Do not return redacted events + get_prev_content: If True and event is a state event, include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + allow_rejected: If True, return rejected events. Returns: Deferred[list[EventBase]]: List of events fetched from the database. The @@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore): # Update the cache to save doing the checks again. entry.event.internal_metadata.recheck_redaction = False - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event + event = entry.event + + if entry.redacted_event: + if redact_behaviour == EventRedactBehaviour.BLOCK: + # Skip this event + continue + elif redact_behaviour == EventRedactBehaviour.REDACT: + event = entry.redacted_event events.append(event) diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index 4eec2fae5e..dfb46ee0f8 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -25,6 +25,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.storage.database import Database from synapse.storage.engines import PostgresEngine, Sqlite3Engine @@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} From 7c429f92d6935a3e9e0140fdd82801edc43b66b8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 11 Dec 2019 14:32:25 +0000 Subject: [PATCH 34/82] Clean up some logging (#6515) This just makes some of the logging easier to follow when things start going wrong. --- changelog.d/6515.misc | 1 + synapse/handlers/federation.py | 37 +++++++++++++++++----------------- 2 files changed, 20 insertions(+), 18 deletions(-) create mode 100644 changelog.d/6515.misc diff --git a/changelog.d/6515.misc b/changelog.d/6515.misc new file mode 100644 index 0000000000..a9c303ed1c --- /dev/null +++ b/changelog.d/6515.misc @@ -0,0 +1 @@ +Clean up some logging when handling incoming events over federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 8f3c9d7702..cf9c46d027 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -183,7 +183,7 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu) + logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers existing = yield self.store.get_event( @@ -279,9 +279,15 @@ class FederationHandler(BaseHandler): len(missing_prevs), ) - yield self._get_missing_events_for_pdu( - origin, pdu, prevs, min_depth - ) + try: + yield self._get_missing_events_for_pdu( + origin, pdu, prevs, min_depth + ) + except Exception as e: + raise Exception( + "Error fetching missing prev_events for %s: %s" + % (event_id, e) + ) # Update the set of things we've seen after trying to # fetch the missing stuff @@ -293,14 +299,6 @@ class FederationHandler(BaseHandler): room_id, event_id, ) - elif missing_prevs: - logger.info( - "[%s %s] Not recursively fetching %d missing prev_events: %s", - room_id, - event_id, - len(missing_prevs), - shortstr(missing_prevs), - ) if prevs - seen: # We've still not been able to get all of the prev_events for this event. @@ -345,6 +343,12 @@ class FederationHandler(BaseHandler): affected=pdu.event_id, ) + logger.info( + "Event %s is missing prev_events: calculating state for a " + "backwards extremity", + event_id, + ) + # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. auth_chains = set() @@ -365,10 +369,7 @@ class FederationHandler(BaseHandler): # know about for p in prevs - seen: logger.info( - "[%s %s] Requesting state at missing prev_event %s", - room_id, - event_id, - p, + "Requesting state at missing prev_event %s", event_id, ) room_version = yield self.store.get_room_version(room_id) @@ -612,8 +613,8 @@ class FederationHandler(BaseHandler): failed_to_fetch = desired_events - event_map.keys() if failed_to_fetch: logger.warning( - "Failed to fetch missing state/auth events for %s: %s", - room_id, + "Failed to fetch missing state/auth events for %s %s", + event_id, failed_to_fetch, ) From 7712e751b87b83086a8cb0a1cda1eef40e177d07 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 16:54:34 +0000 Subject: [PATCH 35/82] Convert federation backfill to async PaginationHandler.get_messages is only called by RoomMessageListRestServlet, which is async. Chase the code path down from there: - FederationHandler.maybe_backfill (and nested try_backfill) - FederationHandler.backfill --- synapse/handlers/federation.py | 47 ++++++++++++++++------------------ synapse/handlers/pagination.py | 27 ++++++++++--------- 2 files changed, 35 insertions(+), 39 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index cf9c46d027..e54d509b62 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -756,8 +756,7 @@ class FederationHandler(BaseHandler): yield self.user_joined_room(user, room_id) @log_function - @defer.inlineCallbacks - def backfill(self, dest, room_id, limit, extremities): + async def backfill(self, dest, room_id, limit, extremities): """ Trigger a backfill request to `dest` for the given `room_id` This will attempt to get more events from the remote. If the other side @@ -774,9 +773,9 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) - events = yield self.federation_client.backfill( + events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -791,7 +790,7 @@ class FederationHandler(BaseHandler): # self._sanity_check_event(ev) # Don't bother processing events we already have. - seen_events = yield self.store.have_events_in_timeline( + seen_events = await self.store.have_events_in_timeline( set(e.event_id for e in events) ) @@ -814,7 +813,7 @@ class FederationHandler(BaseHandler): state_events = {} events_to_state = {} for e_id in edges: - state, auth = yield self._get_state_for_room( + state, auth = await self._get_state_for_room( destination=dest, room_id=room_id, event_id=e_id ) auth_events.update({a.event_id: a for a in auth}) @@ -839,7 +838,7 @@ class FederationHandler(BaseHandler): # We repeatedly do this until we stop finding new auth events. while missing_auth - failed_to_fetch: logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) + ret_events = await self.store.get_events(missing_auth - failed_to_fetch) auth_events.update(ret_events) required_auth.update( @@ -853,7 +852,7 @@ class FederationHandler(BaseHandler): missing_auth - failed_to_fetch, ) - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -880,7 +879,7 @@ class FederationHandler(BaseHandler): failed_to_fetch = missing_auth - set(auth_events) - seen_events = yield self.store.have_seen_events( + seen_events = await self.store.have_seen_events( set(auth_events.keys()) | set(state_events.keys()) ) @@ -942,7 +941,7 @@ class FederationHandler(BaseHandler): ) ) - yield self._handle_new_events(dest, ev_infos, backfilled=True) + await self._handle_new_events(dest, ev_infos, backfilled=True) # Step 2: Persist the rest of the events in the chunk one by one events.sort(key=lambda e: e.depth) @@ -958,16 +957,15 @@ class FederationHandler(BaseHandler): # We store these one at a time since each event depends on the # previous to work out the state. # TODO: We can probably do something more clever here. - yield self._handle_new_event(dest, event, backfilled=True) + await self._handle_new_event(dest, event, backfilled=True) return events - @defer.inlineCallbacks - def maybe_backfill(self, room_id, current_depth): + async def maybe_backfill(self, room_id, current_depth): """Checks the database to see if we should backfill before paginating, and if so do. """ - extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id) + extremities = await self.store.get_oldest_events_with_depth_in_room(room_id) if not extremities: logger.debug("Not backfilling as no extremeties found.") @@ -999,9 +997,9 @@ class FederationHandler(BaseHandler): # state *before* the event, ignoring the special casing certain event # types have. - forward_events = yield self.store.get_successor_events(list(extremities)) + forward_events = await self.store.get_successor_events(list(extremities)) - extremities_events = yield self.store.get_events( + extremities_events = await self.store.get_events( forward_events, redact_behaviour=EventRedactBehaviour.AS_IS, get_prev_content=False, @@ -1009,7 +1007,7 @@ class FederationHandler(BaseHandler): # We set `check_history_visibility_only` as we might otherwise get false # positives from users having been erased. - filtered_extremities = yield filter_events_for_server( + filtered_extremities = await filter_events_for_server( self.storage, self.server_name, list(extremities_events.values()), @@ -1039,7 +1037,7 @@ class FederationHandler(BaseHandler): # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = yield self.state_handler.get_current_state(room_id) + curr_state = await self.state_handler.get_current_state(room_id) def get_domains_from_state(state): """Get joined domains from state @@ -1078,12 +1076,11 @@ class FederationHandler(BaseHandler): domain for domain, depth in curr_domains if domain != self.server_name ] - @defer.inlineCallbacks - def try_backfill(domains): + async def try_backfill(domains): # TODO: Should we try multiple of these at a time? for dom in domains: try: - yield self.backfill( + await self.backfill( dom, room_id, limit=100, extremities=extremities ) # If this succeeded then we probably already have the @@ -1114,7 +1111,7 @@ class FederationHandler(BaseHandler): return False - success = yield try_backfill(likely_domains) + success = await try_backfill(likely_domains) if success: return True @@ -1128,7 +1125,7 @@ class FederationHandler(BaseHandler): logger.debug("calling resolve_state_groups in _maybe_backfill") resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events) - states = yield make_deferred_yieldable( + states = await make_deferred_yieldable( defer.gatherResults( [resolve(room_id, [e]) for e in event_ids], consumeErrors=True ) @@ -1138,7 +1135,7 @@ class FederationHandler(BaseHandler): # event_ids. states = dict(zip(event_ids, [s.state for s in states])) - state_map = yield self.store.get_events( + state_map = await self.store.get_events( [e_id for ids in itervalues(states) for e_id in itervalues(ids)], get_prev_content=False, ) @@ -1154,7 +1151,7 @@ class FederationHandler(BaseHandler): for e_id, _ in sorted_extremeties_tuple: likely_domains = get_domains_from_state(states[e_id]) - success = yield try_backfill( + success = await try_backfill( [dom for dom, _ in likely_domains if dom not in tried_domains] ) if success: diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 8514ddc600..00a6afc963 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -280,8 +280,7 @@ class PaginationHandler(object): await self.storage.purge_events.purge_room(room_id) - @defer.inlineCallbacks - def get_messages( + async def get_messages( self, requester, room_id=None, @@ -307,7 +306,7 @@ class PaginationHandler(object): room_token = pagin_config.from_token.room_key else: pagin_config.from_token = ( - yield self.hs.get_event_sources().get_current_token_for_pagination() + await self.hs.get_event_sources().get_current_token_for_pagination() ) room_token = pagin_config.from_token.room_key @@ -319,11 +318,11 @@ class PaginationHandler(object): source_config = pagin_config.get_source_config("room") - with (yield self.pagination_lock.read(room_id)): + with (await self.pagination_lock.read(room_id)): ( membership, member_event_id, - ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id) + ) = await self.auth.check_in_room_or_world_readable(room_id, user_id) if source_config.direction == "b": # if we're going backwards, we might need to backfill. This @@ -331,7 +330,7 @@ class PaginationHandler(object): if room_token.topological: max_topo = room_token.topological else: - max_topo = yield self.store.get_max_topological_token( + max_topo = await self.store.get_max_topological_token( room_id, room_token.stream ) @@ -339,18 +338,18 @@ class PaginationHandler(object): # If they have left the room then clamp the token to be before # they left the room, to save the effort of loading from the # database. - leave_token = yield self.store.get_topological_token_for_event( + leave_token = await self.store.get_topological_token_for_event( member_event_id ) leave_token = RoomStreamToken.parse(leave_token) if leave_token.topological < max_topo: source_config.from_key = str(leave_token) - yield self.hs.get_handlers().federation_handler.maybe_backfill( + await self.hs.get_handlers().federation_handler.maybe_backfill( room_id, max_topo ) - events, next_key = yield self.store.paginate_room_events( + events, next_key = await self.store.paginate_room_events( room_id=room_id, from_key=source_config.from_key, to_key=source_config.to_key, @@ -365,7 +364,7 @@ class PaginationHandler(object): if event_filter: events = event_filter.filter(events) - events = yield filter_events_for_client( + events = await filter_events_for_client( self.storage, user_id, events, is_peeking=(member_event_id is None) ) @@ -385,19 +384,19 @@ class PaginationHandler(object): (EventTypes.Member, event.sender) for event in events ) - state_ids = yield self.state_store.get_state_ids_for_event( + state_ids = await self.state_store.get_state_ids_for_event( events[0].event_id, state_filter=state_filter ) if state_ids: - state = yield self.store.get_events(list(state_ids.values())) + state = await self.store.get_events(list(state_ids.values())) state = state.values() time_now = self.clock.time_msec() chunk = { "chunk": ( - yield self._event_serializer.serialize_events( + await self._event_serializer.serialize_events( events, time_now, as_client_event=as_client_event ) ), @@ -406,7 +405,7 @@ class PaginationHandler(object): } if state: - chunk["state"] = yield self._event_serializer.serialize_events( + chunk["state"] = await self._event_serializer.serialize_events( state, time_now, as_client_event=as_client_event ) From e77237b9350a7054657b1a641883a09c6b5d44f3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:01:37 +0000 Subject: [PATCH 36/82] convert to async: FederationHandler.on_receive_pdu and associated functions: * on_receive_pdu * handle_queued_pdus * get_missing_events_for_pdu --- synapse/handlers/federation.py | 49 +++++++++++++++------------------- tests/test_federation.py | 14 ++++++---- 2 files changed, 31 insertions(+), 32 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index e54d509b62..01d9c5120e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -165,8 +165,7 @@ class FederationHandler(BaseHandler): self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages - @defer.inlineCallbacks - def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False): + async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None: """ Process a PDU received via a federation /send/ transaction, or via backfill of missing prev_events @@ -176,8 +175,6 @@ class FederationHandler(BaseHandler): pdu (FrozenEvent): received PDU sent_to_us_directly (bool): True if this event was pushed to us; False if we pulled it as the result of a missing prev_event. - - Returns (Deferred): completes with None """ room_id = pdu.room_id @@ -186,7 +183,7 @@ class FederationHandler(BaseHandler): logger.info("handling received PDU: %s", pdu) # We reprocess pdus when we have seen them only as outliers - existing = yield self.store.get_event( + existing = await self.store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -230,7 +227,7 @@ class FederationHandler(BaseHandler): # # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. - is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name) + is_in_room = await self.auth.check_host_in_room(room_id, self.server_name) if not is_in_room: logger.info( "[%s %s] Ignoring PDU from %s as we're not in the room", @@ -246,12 +243,12 @@ class FederationHandler(BaseHandler): # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): # We only backfill backwards to the min depth. - min_depth = yield self.get_min_depth_for_context(pdu.room_id) + min_depth = await self.get_min_depth_for_context(pdu.room_id) logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth) prevs = set(pdu.prev_event_ids()) - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this @@ -271,7 +268,7 @@ class FederationHandler(BaseHandler): len(missing_prevs), shortstr(missing_prevs), ) - with (yield self._room_pdu_linearizer.queue(pdu.room_id)): + with (await self._room_pdu_linearizer.queue(pdu.room_id)): logger.info( "[%s %s] Acquired room lock to fetch %d missing prev_events", room_id, @@ -280,7 +277,7 @@ class FederationHandler(BaseHandler): ) try: - yield self._get_missing_events_for_pdu( + await self._get_missing_events_for_pdu( origin, pdu, prevs, min_depth ) except Exception as e: @@ -291,7 +288,7 @@ class FederationHandler(BaseHandler): # Update the set of things we've seen after trying to # fetch the missing stuff - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: logger.info( @@ -355,7 +352,7 @@ class FederationHandler(BaseHandler): event_map = {event_id: pdu} try: # Get the state of the events we know about - ours = yield self.state_store.get_state_groups_ids(room_id, seen) + ours = await self.state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps = list( @@ -372,7 +369,7 @@ class FederationHandler(BaseHandler): "Requesting state at missing prev_event %s", event_id, ) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) with nested_logging_context(p): # note that if any of the missing prevs share missing state or @@ -381,11 +378,11 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = yield self._get_state_for_room(origin, room_id, p) + ) = await self._get_state_for_room(origin, room_id, p) # we want the state *after* p; _get_state_for_room returns the # state *before* p. - remote_event = yield self.federation_client.get_pdu( + remote_event = await self.federation_client.get_pdu( [origin], p, room_version, outlier=True ) @@ -410,7 +407,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x - state_map = yield resolve_events_with_store( + state_map = await resolve_events_with_store( room_version, state_maps, event_map, @@ -422,7 +419,7 @@ class FederationHandler(BaseHandler): # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = yield self.store.get_events( + evs = await self.store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -446,12 +443,11 @@ class FederationHandler(BaseHandler): affected=event_id, ) - yield self._process_received_pdu( + await self._process_received_pdu( origin, pdu, state=state, auth_chain=auth_chain ) - @defer.inlineCallbacks - def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): + async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ Args: origin (str): Origin of the pdu. Will be called to get the missing events @@ -463,12 +459,12 @@ class FederationHandler(BaseHandler): room_id = pdu.room_id event_id = pdu.event_id - seen = yield self.store.have_seen_events(prevs) + seen = await self.store.have_seen_events(prevs) if not prevs - seen: return - latest = yield self.store.get_latest_event_ids_in_room(room_id) + latest = await self.store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -532,7 +528,7 @@ class FederationHandler(BaseHandler): # All that said: Let's try increasing the timout to 60s and see what happens. try: - missing_events = yield self.federation_client.get_missing_events( + missing_events = await self.federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -571,7 +567,7 @@ class FederationHandler(BaseHandler): ) with nested_logging_context(ev.event_id): try: - yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False) + await self.on_receive_pdu(origin, ev, sent_to_us_directly=False) except FederationError as e: if e.code == 403: logger.warning( @@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler): return True - @defer.inlineCallbacks - def _handle_queued_pdus(self, room_queue): + async def _handle_queued_pdus(self, room_queue): """Process PDUs which got queued up while we were busy send_joining. Args: @@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler): p.room_id, ) with nested_logging_context(p.event_id): - yield self.on_receive_pdu(origin, p, sent_to_us_directly=True) + await self.on_receive_pdu(origin, p, sent_to_us_directly=True) except Exception as e: logger.warning( "Error handling queued PDU %s from %s: %s", p.event_id, origin, e diff --git a/tests/test_federation.py b/tests/test_federation.py index ad165d7295..68684460c6 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -1,6 +1,6 @@ from mock import Mock -from twisted.internet.defer import maybeDeferred, succeed +from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed from synapse.events import FrozenEvent from synapse.logging.context import LoggingContext @@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase): ) # Send the join, it should return None (which is not an error) - d = self.handler.on_receive_pdu( - "test.serv", join_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) ) self.reactor.advance(1) self.assertEqual(self.successResultOf(d), None) @@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase): ) with LoggingContext(request="lying_event"): - d = self.handler.on_receive_pdu( - "test.serv", lying_event, sent_to_us_directly=True + d = ensureDeferred( + self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) ) # Step the reactor, so the database fetches come back From 4db394a4b3e9c823655cdd3e716a5f234107d337 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:25:18 +0000 Subject: [PATCH 37/82] convert to async: FederationHandler._get_state_for_room ... and _get_events_from_store_or_dest --- synapse/handlers/federation.py | 42 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 01d9c5120e..724cae9647 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -19,7 +19,7 @@ import itertools import logging -from typing import Dict, Iterable, Optional, Sequence, Tuple +from typing import Dict, Iterable, List, Optional, Sequence, Tuple import six from six import iteritems, itervalues @@ -579,30 +579,30 @@ class FederationHandler(BaseHandler): else: raise - @defer.inlineCallbacks @log_function - def _get_state_for_room(self, destination, room_id, event_id): + async def _get_state_for_room( + self, destination: str, room_id: str, event_id: str + ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. + destination:: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. Returns: - Deferred[Tuple[List[EventBase], List[EventBase]]]: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, and a list of events in the auth chain + for the given event. """ ( state_event_ids, auth_event_ids, - ) = yield self.federation_client.get_room_state_ids( + ) = await self.federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) desired_events = set(state_event_ids + auth_event_ids) - event_map = yield self._get_events_from_store_or_dest( + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -621,20 +621,20 @@ class FederationHandler(BaseHandler): return pdus, auth_chain - @defer.inlineCallbacks - def _get_events_from_store_or_dest(self, destination, room_id, event_ids): + async def _get_events_from_store_or_dest( + self, destination: str, room_id: str, event_ids: Iterable[str] + ) -> Dict[str, EventBase]: """Fetch events from a remote destination, checking if we already have them. Args: - destination (str) - room_id (str) - event_ids (Iterable[str]) + destination + room_id + event_ids Returns: - Deferred[dict[str, EventBase]]: A deferred resolving to a map - from event_id to event + map from event_id to event """ - fetched_events = yield self.store.get_events(event_ids, allow_rejected=True) + fetched_events = await self.store.get_events(event_ids, allow_rejected=True) missing_events = set(event_ids) - fetched_events.keys() @@ -647,7 +647,7 @@ class FederationHandler(BaseHandler): event_ids, ) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) # XXX 20 requests at once? really? for batch in batch_iter(missing_events, 20): @@ -661,7 +661,7 @@ class FederationHandler(BaseHandler): for e_id in batch ] - res = yield make_deferred_yieldable( + res = await make_deferred_yieldable( defer.DeferredList(deferreds, consumeErrors=True) ) for success, result in res: From 6637d90d778f5604b4827ab4f7d9a4cf05802466 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:27:13 +0000 Subject: [PATCH 38/82] convert to async: FederationHandler._process_received_pdu also fix user_joined_room to consistently return deferreds --- synapse/handlers/federation.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 724cae9647..bcd3b422aa 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -670,8 +670,7 @@ class FederationHandler(BaseHandler): return fetched_events - @defer.inlineCallbacks - def _process_received_pdu(self, origin, event, state, auth_chain): + async def _process_received_pdu(self, origin, event, state, auth_chain): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. """ @@ -686,7 +685,7 @@ class FederationHandler(BaseHandler): if auth_chain: event_ids |= {e.event_id for e in auth_chain} - seen_ids = yield self.store.have_seen_events(event_ids) + seen_ids = await self.store.have_seen_events(event_ids) if state and auth_chain is not None: # If we have any state or auth_chain given to us by the replication @@ -713,18 +712,18 @@ class FederationHandler(BaseHandler): event_id, [e.event.event_id for e in event_infos], ) - yield self._handle_new_events(origin, event_infos) + await self._handle_new_events(origin, event_infos) try: - context = yield self._handle_new_event(origin, event, state=state) + context = await self._handle_new_event(origin, event, state=state) except AuthError as e: raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) - room = yield self.store.get_room(room_id) + room = await self.store.get_room(room_id) if not room: try: - yield self.store.store_room( + await self.store.store_room( room_id=room_id, room_creator_user_id="", is_public=False ) except StoreError: @@ -737,11 +736,11 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids(self.store) prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: - prev_state = yield self.store.get_event( + prev_state = await self.store.get_event( prev_state_id, allow_none=True ) if prev_state and prev_state.membership == Membership.JOIN: @@ -749,7 +748,7 @@ class FederationHandler(BaseHandler): if newly_joined: user = UserID.from_string(event.state_key) - yield self.user_joined_room(user, room_id) + await self.user_joined_room(user, room_id) @log_function async def backfill(self, dest, room_id, limit, extremities): @@ -2899,7 +2898,7 @@ class FederationHandler(BaseHandler): room_id=room_id, user_id=user.to_string(), change="joined" ) else: - return user_joined_room(self.distributor, user, room_id) + return defer.succeed(user_joined_room(self.distributor, user, room_id)) @defer.inlineCallbacks def get_room_complexity(self, remote_room_hosts, room_id): From 5324bc20a628116aecf4781744cadd611f51d7a6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 10 Dec 2019 17:59:46 +0000 Subject: [PATCH 39/82] changelog --- changelog.d/6517.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6517.misc diff --git a/changelog.d/6517.misc b/changelog.d/6517.misc new file mode 100644 index 0000000000..c6ffed9952 --- /dev/null +++ b/changelog.d/6517.misc @@ -0,0 +1 @@ +Port some of FederationHandler to async/await. \ No newline at end of file From 20453565176cfd358212a23cf89dfd2deab1d690 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 11 Dec 2019 16:37:51 +0000 Subject: [PATCH 40/82] Add `include_event_in_state` to _get_state_for_room (#6521) Make it return the state *after* the requested event, rather than the one before it. This is a bit easier and requires fewer calls to get_events_from_store_or_dest. --- changelog.d/6521.misc | 1 + synapse/handlers/federation.py | 50 +++++++++++++++++++--------------- 2 files changed, 29 insertions(+), 22 deletions(-) create mode 100644 changelog.d/6521.misc diff --git a/changelog.d/6521.misc b/changelog.d/6521.misc new file mode 100644 index 0000000000..d9a44389b9 --- /dev/null +++ b/changelog.d/6521.misc @@ -0,0 +1 @@ +Refactor some code in the event authentication path for clarity. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bcd3b422aa..62985bab9f 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -378,22 +378,10 @@ class FederationHandler(BaseHandler): ( remote_state, got_auth_chain, - ) = await self._get_state_for_room(origin, room_id, p) - - # we want the state *after* p; _get_state_for_room returns the - # state *before* p. - remote_event = await self.federation_client.get_pdu( - [origin], p, room_version, outlier=True + ) = await self._get_state_for_room( + origin, room_id, p, include_event_in_state=True ) - if remote_event is None: - raise Exception( - "Unable to get missing prev_event %s" % (p,) - ) - - if remote_event.is_state(): - remote_state.append(remote_event) - # XXX hrm I'm not convinced that duplicate events will compare # for equality, so I'm not sure this does what the author # hoped. @@ -579,20 +567,25 @@ class FederationHandler(BaseHandler): else: raise - @log_function async def _get_state_for_room( - self, destination: str, room_id: str, event_id: str + self, + destination: str, + room_id: str, + event_id: str, + include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. Args: - destination:: The remote homeserver to query for the state. + destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. + include_event_in_state: if true, the event itself will be included in the + returned state event list. Returns: - A list of events in the state, and a list of events in the auth chain - for the given event. + A list of events in the state, possibly including the event itself, and + a list of events in the auth chain for the given event. """ ( state_event_ids, @@ -602,6 +595,10 @@ class FederationHandler(BaseHandler): ) desired_events = set(state_event_ids + auth_event_ids) + + if include_event_in_state: + desired_events.add(event_id) + event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -614,12 +611,21 @@ class FederationHandler(BaseHandler): failed_to_fetch, ) - pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] + remote_state = [ + event_map[e_id] for e_id in state_event_ids if e_id in event_map + ] + if include_event_in_state: + remote_event = event_map.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + if remote_event.is_state(): + remote_state.append(remote_event) + + auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) - return pdus, auth_chain + return remote_state, auth_chain async def _get_events_from_store_or_dest( self, destination: str, room_id: str, event_ids: Iterable[str] From cfcfb57e581bc1577a5fba3b34a1b1af7ccb6c0d Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 11 Dec 2019 17:27:46 +0000 Subject: [PATCH 41/82] Add new config param to docstring and add types --- synapse/server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/synapse/server.py b/synapse/server.py index a8b5459ff3..5021068ce0 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -34,6 +34,7 @@ from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler +from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory @@ -210,10 +211,11 @@ class HomeServer(object): # instantiated during setup() for future return by get_datastore() DATASTORE_CLASS = abc.abstractproperty() - def __init__(self, hostname, config, reactor=None, **kwargs): + def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs): """ Args: hostname : The hostname for the server. + config: The full config for the homeserver. """ if not reactor: from twisted.internet import reactor From 25f12443298f4995613a475ac304e84d50317e18 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 12 Dec 2019 12:57:45 +0000 Subject: [PATCH 42/82] Check the room_id of events when fetching room state/auth (#6524) When we request the state/auth_events to populate a backwards extremity (on backfill or in the case of missing events in a transaction push), we should check that the returned events are in the right room rather than blindly using them in the room state or auth chain. Given that _get_events_from_store_or_dest takes a room_id, it seems clear that it should be sanity-checking the room_id of the requested events, so let's do it there. --- changelog.d/6524.misc | 2 + synapse/handlers/federation.py | 78 +++++++++++++++++++++++----------- 2 files changed, 55 insertions(+), 25 deletions(-) create mode 100644 changelog.d/6524.misc diff --git a/changelog.d/6524.misc b/changelog.d/6524.misc new file mode 100644 index 0000000000..f885597426 --- /dev/null +++ b/changelog.d/6524.misc @@ -0,0 +1,2 @@ +Improve sanity-checking when receiving events over federation. + diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 62985bab9f..2ea69c5468 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -637,6 +637,10 @@ class FederationHandler(BaseHandler): room_id event_ids + If we fail to fetch any of the events, a warning will be logged, and the event + will be omitted from the result. Likewise, any events which turn out not to + be in the given room. + Returns: map from event_id to event """ @@ -644,35 +648,59 @@ class FederationHandler(BaseHandler): missing_events = set(event_ids) - fetched_events.keys() - if not missing_events: - return fetched_events + if missing_events: + logger.debug( + "Fetching unknown state/auth events %s for room %s", + missing_events, + room_id, + ) - logger.debug( - "Fetching unknown state/auth events %s for room %s", - missing_events, - event_ids, + room_version = await self.store.get_room_version(room_id) + + # XXX 20 requests at once? really? + for batch in batch_iter(missing_events, 20): + deferreds = [ + run_in_background( + self.federation_client.get_pdu, + destinations=[destination], + event_id=e_id, + room_version=room_version, + ) + for e_id in batch + ] + + res = await make_deferred_yieldable( + defer.DeferredList(deferreds, consumeErrors=True) + ) + + for success, result in res: + if success and result: + fetched_events[result.event_id] = result + + # check for events which were in the wrong room. + # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = list( + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id ) - room_version = await self.store.get_room_version(room_id) - - # XXX 20 requests at once? really? - for batch in batch_iter(missing_events, 20): - deferreds = [ - run_in_background( - self.federation_client.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = await make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned auth/state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, ) - for success, result in res: - if success and result: - fetched_events[result.event_id] = result + del fetched_events[bad_event_id] return fetched_events From 324d4f61b89d1fe16833480d5bd5b47d3b4ef0ba Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Dec 2019 14:52:11 +0000 Subject: [PATCH 43/82] Include more folders in mypy --- tox.ini | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 903a245fb0..2ae82d674f 100644 --- a/tox.ini +++ b/tox.ini @@ -177,5 +177,17 @@ env = MYPYPATH = stubs/ extras = all commands = mypy \ + synapse/config/ \ + synapse/handlers/ui_auth \ synapse/logging/ \ - synapse/config/ + synapse/module_api \ + synapse/rest/consent \ + synapse/rest/media/v0 \ + synapse/rest/saml2 \ + synapse/spam_checker_api \ + synapse/storage/engines \ + synapse/streams + +# To find all folders that pass mypy you run: +# +# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print From c965253e4b38b9d941793469ef64b014e5a25947 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Dec 2019 14:54:03 +0000 Subject: [PATCH 44/82] Newsfile --- changelog.d/6534.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6534.misc diff --git a/changelog.d/6534.misc b/changelog.d/6534.misc new file mode 100644 index 0000000000..7df6bb442a --- /dev/null +++ b/changelog.d/6534.misc @@ -0,0 +1 @@ +Test more folders against mypy. From 495005360cd37009818ad4214a5462c3dd0b15a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Dec 2019 15:21:12 +0000 Subject: [PATCH 45/82] Bump version of mypy --- synapse/config/emailconfig.py | 3 ++- synapse/config/ratelimiting.py | 3 +-- synapse/config/server.py | 2 +- synapse/logging/_terse_json.py | 2 +- synapse/logging/context.py | 3 +++ synapse/streams/events.py | 4 +++- tox.ini | 2 +- 7 files changed, 12 insertions(+), 7 deletions(-) diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 18f42a87f9..35756bed87 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -21,6 +21,7 @@ from __future__ import print_function import email.utils import os from enum import Enum +from typing import Optional import pkg_resources @@ -101,7 +102,7 @@ class EmailConfig(Config): # both in RegistrationConfig and here. We should factor this bit out self.account_threepid_delegate_email = self.trusted_third_party_id_servers[ 0 - ] + ] # type: Optional[str] self.using_identity_server_from_trusted_list = True else: raise ConfigError( diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 947f653e03..4a3bfc4354 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -83,10 +83,9 @@ class RatelimitConfig(Config): ) rc_admin_redaction = config.get("rc_admin_redaction") + self.rc_admin_redaction = None if rc_admin_redaction: self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction) - else: - self.rc_admin_redaction = None def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/server.py b/synapse/config/server.py index a4bef00936..50af858c76 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -200,7 +200,7 @@ class ServerConfig(Config): self.admin_contact = config.get("admin_contact", None) # FIXME: federation_domain_whitelist needs sytests - self.federation_domain_whitelist = None + self.federation_domain_whitelist = None # type: Optional[dict] federation_domain_whitelist = config.get("federation_domain_whitelist", None) if federation_domain_whitelist is not None: diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py index 03934956f4..c0b9384189 100644 --- a/synapse/logging/_terse_json.py +++ b/synapse/logging/_terse_json.py @@ -171,7 +171,7 @@ class LogProducer(object): def stopProducing(self): self._paused = True - self._buffer = None + self._buffer = deque() def resumeProducing(self): self._paused = False diff --git a/synapse/logging/context.py b/synapse/logging/context.py index 6747f29e6a..33b322209d 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -405,6 +405,9 @@ class LoggingContext(object): """ current = get_thread_resource_usage() + # Indicate to mypy that we know that self.usage_start is None. + assert self.usage_start is not None + utime_delta = current.ru_utime - self.usage_start.ru_utime stime_delta = current.ru_stime - self.usage_start.ru_stime diff --git a/synapse/streams/events.py b/synapse/streams/events.py index b91fb2db7b..fcd2aaa9c9 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict + from twisted.internet import defer from synapse.handlers.account_data import AccountDataEventSource @@ -35,7 +37,7 @@ class EventSources(object): def __init__(self, hs): self.sources = { name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items() - } + } # type: Dict[str, Any] self.store = hs.get_datastore() @defer.inlineCallbacks diff --git a/tox.ini b/tox.ini index 2ae82d674f..1d6428f64f 100644 --- a/tox.ini +++ b/tox.ini @@ -171,7 +171,7 @@ basepython = python3.7 skip_install = True deps = {[base]deps} - mypy==0.730 + mypy==0.750 mypy-zope env = MYPYPATH = stubs/ From 5056d6d90a58edc2a70d19233082f24fded4c73f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 12 Dec 2019 15:22:46 +0000 Subject: [PATCH 46/82] Newsfile --- changelog.d/6537.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6537.misc diff --git a/changelog.d/6537.misc b/changelog.d/6537.misc new file mode 100644 index 0000000000..3543153584 --- /dev/null +++ b/changelog.d/6537.misc @@ -0,0 +1 @@ +Update `mypy` to new version. From 5bfd8855d6b9ed8bcf28a107e6654c7cd7d3da2b Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 12 Dec 2019 15:53:49 +0000 Subject: [PATCH 47/82] Fix redacted events being returned in search results ordered by "recent" (#6522) --- changelog.d/6522.bugfix | 1 + synapse/storage/data_stores/main/search.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6522.bugfix diff --git a/changelog.d/6522.bugfix b/changelog.d/6522.bugfix new file mode 100644 index 0000000000..ccda96962f --- /dev/null +++ b/changelog.d/6522.bugfix @@ -0,0 +1 @@ +Prevent redacted events from being returned during message search. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py index dfb46ee0f8..47ebb8a214 100644 --- a/synapse/storage/data_stores/main/search.py +++ b/synapse/storage/data_stores/main/search.py @@ -385,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -501,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore): """ clauses = [] - search_query = search_query = _parse_query(self.database_engine, search_term) + search_query = _parse_query(self.database_engine, search_term) args = [] @@ -606,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self.get_events_as_list([r["event_id"] for r in results]) + # We set redact_behaviour to BLOCK here to prevent redacted events being returned in + # search results (which is a data leak) + events = yield self.get_events_as_list( + [r["event_id"] for r in results], + redact_behaviour=EventRedactBehaviour.BLOCK, + ) event_map = {ev.event_id: ev for ev in events} From cb2db179945f567410b565f29725dff28449f013 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 12 Dec 2019 12:03:28 -0500 Subject: [PATCH 48/82] look up cross-signing keys from the DB in bulk (#6486) --- changelog.d/6486.bugfix | 1 + synapse/handlers/e2e_keys.py | 35 ++- .../data_stores/main/end_to_end_keys.py | 217 +++++++++++++++++- synapse/util/caches/descriptors.py | 2 +- tests/handlers/test_e2e_keys.py | 8 - 5 files changed, 242 insertions(+), 21 deletions(-) create mode 100644 changelog.d/6486.bugfix diff --git a/changelog.d/6486.bugfix b/changelog.d/6486.bugfix new file mode 100644 index 0000000000..b98c5a9ae5 --- /dev/null +++ b/changelog.d/6486.bugfix @@ -0,0 +1 @@ +Improve performance of looking up cross-signing keys. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 57a10daefd..2d889364d4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -264,6 +264,7 @@ class E2eKeysHandler(object): return ret + @defer.inlineCallbacks def get_cross_signing_keys_from_cache(self, query, from_user_id): """Get cross-signing keys for users from the database @@ -283,14 +284,32 @@ class E2eKeysHandler(object): self_signing_keys = {} user_signing_keys = {} - # Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486 - return defer.succeed( - { - "master_keys": master_keys, - "self_signing_keys": self_signing_keys, - "user_signing_keys": user_signing_keys, - } - ) + user_ids = list(query) + + keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id) + + for user_id, user_info in keys.items(): + if user_info is None: + continue + if "master" in user_info: + master_keys[user_id] = user_info["master"] + if "self_signing" in user_info: + self_signing_keys[user_id] = user_info["self_signing"] + + if ( + from_user_id in keys + and keys[from_user_id] is not None + and "user_signing" in keys[from_user_id] + ): + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @trace @defer.inlineCallbacks diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py index 38cd0ca9b8..e551606f9d 100644 --- a/synapse/storage/data_stores/main/end_to_end_keys.py +++ b/synapse/storage/data_stores/main/end_to_end_keys.py @@ -14,15 +14,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, List + from six import iteritems from canonicaljson import encode_canonical_json, json +from twisted.enterprise.adbapi import Connection from twisted.internet import defer from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList class EndToEndKeyWorkerStore(SQLBaseStore): @@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): Args: txn (twisted.enterprise.adbapi.Connection): db connection user_id (str): the user whose key is being requested - key_type (str): the type of key that is being set: either 'master' + key_type (str): the type of key that is being requested: either 'master' for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on @@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): """Returns a user's cross-signing key. Args: - user_id (str): the user whose self-signing key is being requested - key_type (str): the type of cross-signing key to get + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being requested: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key from_user_id (str): if specified, signatures made by this user on the self-signing key will be included in the result @@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore): from_user_id, ) + @cached(num_args=1) + def _get_bare_e2e_cross_signing_keys(self, user_id): + """Dummy function. Only used to make a cache for + _get_bare_e2e_cross_signing_keys_bulk. + """ + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_bare_e2e_cross_signing_keys", + list_name="user_ids", + num_args=1, + ) + def _get_bare_e2e_cross_signing_keys_bulk( + self, user_ids: List[str] + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + + """ + return self.db.runInteraction( + "get_bare_e2e_cross_signing_keys_bulk", + self._get_bare_e2e_cross_signing_keys_bulk_txn, + user_ids, + ) + + def _get_bare_e2e_cross_signing_keys_bulk_txn( + self, txn: Connection, user_ids: List[str], + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing keys for a set of users. The output of this + function should be passed to _get_e2e_cross_signing_signatures_txn if + the signatures for the calling user need to be fetched. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_ids (list[str]): the users whose keys are being requested + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. If a user's cross-signing keys were not found, their user + ID will not be in the dict. + + """ + result = {} + + batch_size = 100 + chunks = [ + user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT k.user_id, k.keytype, k.keydata, k.stream_id + FROM e2e_cross_signing_keys k + INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id + FROM e2e_cross_signing_keys + GROUP BY user_id, keytype) s + USING (user_id, stream_id, keytype) + WHERE k.user_id IN (%s) + """ % ( + ",".join("?" for u in user_chunk), + ) + query_params = [] + query_params.extend(user_chunk) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + for row in rows: + user_id = row["user_id"] + key_type = row["keytype"] + key = json.loads(row["keydata"]) + user_info = result.setdefault(user_id, {}) + user_info[key_type] = key + + return result + + def _get_e2e_cross_signing_signatures_txn( + self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str, + ) -> Dict[str, Dict[str, dict]]: + """Returns the cross-signing signatures made by a user on a set of keys. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + keys (dict[str, dict[str, dict]]): a map of user ID to key type to + key data. This dict will be modified to add signatures. + from_user_id (str): fetch the signatures made by this user + + Returns: + dict[str, dict[str, dict]]: mapping from user ID to key type to key + data. The return value will be the same as the keys argument, + with the modifications included. + """ + + # find out what cross-signing keys (a.k.a. devices) we need to get + # signatures for. This is a map of (user_id, device_id) to key type + # (device_id is the key's public part). + devices = {} + + for user_id, user_info in keys.items(): + if user_info is None: + continue + for key_type, key in user_info.items(): + device_id = None + for k in key["keys"].values(): + device_id = k + devices[(user_id, device_id)] = key_type + + device_list = list(devices) + + # split into batches + batch_size = 100 + chunks = [ + device_list[i : i + batch_size] + for i in range(0, len(device_list), batch_size) + ] + for user_chunk in chunks: + sql = """ + SELECT target_user_id, target_device_id, key_id, signature + FROM e2e_cross_signing_signatures + WHERE user_id = ? + AND (%s) + """ % ( + " OR ".join( + "(target_user_id = ? AND target_device_id = ?)" for d in devices + ) + ) + query_params = [from_user_id] + for item in devices: + # item is a (user_id, device_id) tuple + query_params.extend(item) + + txn.execute(sql, query_params) + rows = self.db.cursor_to_dict(txn) + + # and add the signatures to the appropriate keys + for row in rows: + key_id = row["key_id"] + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + key_type = devices[(target_user_id, target_device_id)] + # We need to copy everything, because the result may have come + # from the cache. dict.copy only does a shallow copy, so we + # need to recursively copy the dicts that will be modified. + user_info = keys[target_user_id] = keys[target_user_id].copy() + target_user_key = user_info[key_type] = user_info[key_type].copy() + if "signatures" in target_user_key: + signatures = target_user_key["signatures"] = target_user_key[ + "signatures" + ].copy() + if from_user_id in signatures: + user_sigs = signatures[from_user_id] = signatures[from_user_id] + user_sigs[key_id] = row["signature"] + else: + signatures[from_user_id] = {key_id: row["signature"]} + else: + target_user_key["signatures"] = { + from_user_id: {key_id: row["signature"]} + } + + return keys + + @defer.inlineCallbacks + def get_e2e_cross_signing_keys_bulk( + self, user_ids: List[str], from_user_id: str = None + ) -> defer.Deferred: + """Returns the cross-signing keys for a set of users. + + Args: + user_ids (list[str]): the users whose keys are being requested + from_user_id (str): if specified, signatures made by this user on + the self-signing keys will be included in the result + + Returns: + Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to + key data. If a user's cross-signing keys were not found, either + their user ID will not be in the dict, or their user ID will map + to None. + """ + + result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids) + + if from_user_id: + result = yield self.db.runInteraction( + "get_e2e_cross_signing_signatures", + self._get_e2e_cross_signing_signatures_txn, + result, + from_user_id, + ) + + return result + def get_all_user_signature_changes_for_remotes(self, from_key, to_key): """Return a list of changes from the user signature stream to notify remotes. Note that the user signature stream represents when a user signs their @@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): }, ) + self._invalidate_cache_and_stream( + txn, self._get_bare_e2e_cross_signing_keys, (user_id,) + ) + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 84f5ae22c3..2e8f6543e5 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -271,7 +271,7 @@ class _CacheDescriptorBase(object): else: self.function_to_call = orig - arg_spec = inspect.getargspec(orig) + arg_spec = inspect.getfullargspec(orig) all_args = arg_spec.args if "cache_context" in all_args: diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index fdfa2cbbc4..854eb6c024 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - test_replace_master_key.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) - @defer.inlineCallbacks def test_reupload_signatures(self): """re-uploading a signature should not fail""" @@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ], other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - - test_upload_signatures.skip = ( - "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486" - ) From 4ce05ec1716f757eb15c02e615ea9c84cb289b77 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 13 Dec 2019 10:15:20 +0000 Subject: [PATCH 49/82] Adjust the sytest blacklist for worker mode (#6538) Remove tests that got blacklisted while torturing was enabled, and add one that fails. --- .buildkite/worker-blacklist | 31 +++---------------------------- changelog.d/6538.misc | 1 + 2 files changed, 4 insertions(+), 28 deletions(-) create mode 100644 changelog.d/6538.misc diff --git a/.buildkite/worker-blacklist b/.buildkite/worker-blacklist index 7950d19db3..158ab79154 100644 --- a/.buildkite/worker-blacklist +++ b/.buildkite/worker-blacklist @@ -34,33 +34,8 @@ Device list doesn't change if remote server is down Remote servers cannot set power levels in rooms without existing powerlevels Remote servers should reject attempts by non-creators to set the power levels -# new failures as of https://github.com/matrix-org/sytest/pull/753 -GET /rooms/:room_id/messages returns a message -GET /rooms/:room_id/messages lazy loads members correctly -Read receipts are sent as events -Only original members of the room can see messages from erased users -Device deletion propagates over federation -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes -Changing user-signing key notifies local users -Newly updated tags appear in an incremental v2 /sync +# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5 Server correctly handles incoming m.device_list_update -Local device key changes get to remote servers with correct prev_id -AS-ghosted users can use rooms via AS -Ghost user must register before joining room -Test that a message is pushed -Invites are pushed -Rooms with aliases are correctly named in pushed -Rooms with names are correctly named in pushed -Rooms with canonical alias are correctly named in pushed -Rooms with many users are correctly pushed -Don't get pushed for rooms you've muted -Rejected events are not pushed -Test that rejected pushers are removed. -Events come down the correct room -# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603 -Presence changes to UNAVAILABLE are reported to remote room members -If remote user leaves room, changes device and rejoins we see update in sync -uploading self-signing key notifies over federation -Inbound federation can receive redacted events -Outbound federation can request missing events +# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536 +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state diff --git a/changelog.d/6538.misc b/changelog.d/6538.misc new file mode 100644 index 0000000000..cb4fd56948 --- /dev/null +++ b/changelog.d/6538.misc @@ -0,0 +1 @@ +Adjust the sytest blacklist for worker mode. From 971a0702b5fce743c8bb61424a5f1002d3eb63ff Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 13 Dec 2019 11:44:41 +0000 Subject: [PATCH 50/82] Sanity-check room ids in event auth (#6530) When we do an event auth operation, check that all of the events involved are in the right room. --- changelog.d/6530.misc | 2 ++ synapse/event_auth.py | 12 ++++++++++++ 2 files changed, 14 insertions(+) create mode 100644 changelog.d/6530.misc diff --git a/changelog.d/6530.misc b/changelog.d/6530.misc new file mode 100644 index 0000000000..f885597426 --- /dev/null +++ b/changelog.d/6530.misc @@ -0,0 +1,2 @@ +Improve sanity-checking when receiving events over federation. + diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c940b84470..80ec911b3d 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -50,6 +50,18 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru if not hasattr(event, "room_id"): raise AuthError(500, "Event has no room_id: %s" % event) + room_id = event.room_id + + # I'm not really expecting to get auth events in the wrong room, but let's + # sanity-check it + for auth_event in auth_events.values(): + if auth_event.room_id != room_id: + raise Exception( + "During auth for event %s in room %s, found event %s in the state " + "which is in room %s" + % (event.event_id, room_id, auth_event.event_id, auth_event.room_id) + ) + if do_sig_check: sender_domain = get_domain_from_id(event.sender) From 1da15f05f5c9c1e47c9fd1323caff869c2e55aa3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 13 Dec 2019 12:55:32 +0000 Subject: [PATCH 51/82] sanity-checking for events used in state res (#6531) When we perform state resolution, check that all of the events involved are in the right room. --- changelog.d/6531.misc | 1 + synapse/handlers/federation.py | 1 + synapse/state/__init__.py | 32 +++++++---- synapse/state/v1.py | 34 +++++++++-- synapse/state/v2.py | 100 ++++++++++++++++++++++++--------- tests/state/test_v2.py | 3 + 6 files changed, 128 insertions(+), 43 deletions(-) create mode 100644 changelog.d/6531.misc diff --git a/changelog.d/6531.misc b/changelog.d/6531.misc new file mode 100644 index 0000000000..598efb79fc --- /dev/null +++ b/changelog.d/6531.misc @@ -0,0 +1 @@ +Improve sanity-checking when receiving events over federation. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ea69c5468..1d39a9a4f5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -396,6 +396,7 @@ class FederationHandler(BaseHandler): event_map[x.event_id] = x state_map = await resolve_events_with_store( + room_id, room_version, state_maps, event_map, diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 3e6d62eef1..5accc071ab 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Iterable, Optional +from typing import Dict, Iterable, List, Optional, Tuple from six import iteritems, itervalues @@ -417,6 +417,7 @@ class StateHandler(object): with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + event.room_id, room_version, state_set_ids, event_map=state_map, @@ -462,7 +463,7 @@ class StateResolutionHandler(object): not be called for a single state group Args: - room_id (str): room we are resolving for (used for logging) + room_id (str): room we are resolving for (used for logging and sanity checks) room_version (str): version of the room state_groups_ids (dict[int, dict[(str, str), str]]): map from state group id to the state in that state group @@ -518,6 +519,7 @@ class StateResolutionHandler(object): logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): new_state = yield resolve_events_with_store( + room_id, room_version, list(itervalues(state_groups_ids)), event_map=event_map, @@ -589,36 +591,44 @@ def _make_state_cache_entry(new_state, state_groups_ids): ) -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "StateResolutionStore", +): """ Args: - room_version(str): Version of the room + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: Version of the room + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing events will be requested via state_map_factory. - If None, all events will be fetched via state_map_factory. + If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: a place to fetch events from - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ v = KNOWN_ROOM_VERSIONS[room_version] if v.state_res == StateResolutionVersions.V1: return v1.resolve_events_with_store( - state_sets, event_map, state_res_store.get_events + room_id, state_sets, event_map, state_res_store.get_events ) else: return v2.resolve_events_with_store( - room_version, state_sets, event_map, state_res_store + room_id, room_version, state_sets, event_map, state_res_store ) diff --git a/synapse/state/v1.py b/synapse/state/v1.py index a2f92d9ff9..b2f9865f39 100644 --- a/synapse/state/v1.py +++ b/synapse/state/v1.py @@ -15,6 +15,7 @@ import hashlib import logging +from typing import Callable, Dict, List, Optional, Tuple from six import iteritems, iterkeys, itervalues @@ -24,6 +25,7 @@ from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions +from synapse.events import EventBase logger = logging.getLogger(__name__) @@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "") @defer.inlineCallbacks -def resolve_events_with_store(state_sets, event_map, state_map_factory): +def resolve_events_with_store( + room_id: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_map_factory: Callable, +): """ Args: - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_id: the room we are working in + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): If None, all events will be fetched via state_map_factory. - state_map_factory(func): will be called + state_map_factory: will be called with a list of event_ids that are needed, and should return with a Deferred of dict of event_id to event. - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): if event_map is not None: state_map.update(event_map) + # everything in the state map should be in the right room + for event in state_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + # get the ids of the auth events which allow us to authenticate the # conflicted state, picking only from the unconflicting state. # @@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory): ) state_map_new = yield state_map_factory(new_needed_events) + for event in state_map_new.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + state_map.update(state_map_new) return _resolve_with_state( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index b327c86f40..cb77ed5b78 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -16,29 +16,40 @@ import heapq import itertools import logging +from typing import Dict, List, Optional, Tuple from six import iteritems, itervalues from twisted.internet import defer +import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError +from synapse.events import EventBase logger = logging.getLogger(__name__) @defer.inlineCallbacks -def resolve_events_with_store(room_version, state_sets, event_map, state_res_store): +def resolve_events_with_store( + room_id: str, + room_version: str, + state_sets: List[Dict[Tuple[str, str], str]], + event_map: Optional[Dict[str, EventBase]], + state_res_store: "synapse.state.StateResolutionStore", +): """Resolves the state using the v2 state resolution algorithm Args: - room_version (str): The room version + room_id: the room we are working in - state_sets(list): List of dicts of (type, state_key) -> event_id, + room_version: The room version + + state_sets: List of dicts of (type, state_key) -> event_id, which are the different state groups to resolve. - event_map(dict[str,FrozenEvent]|None): + event_map: a dict from event_id to event, for any events that we happen to have in flight (eg, those currently being persisted). This will be used as a starting point fof finding the state we need; any missing @@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto If None, all events will be fetched via state_res_store. - state_res_store (StateResolutionStore) + state_res_store: - Returns + Returns: Deferred[dict[(str, str), str]]: a map from (type, state_key) to event_id. """ @@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) event_map.update(events) + # everything in the event map should be in the right room + for event in event_map.values(): + if event.room_id != room_id: + raise Exception( + "Attempting to state-resolve for room %s with event %s which is in %s" + % (room_id, event.event_id, event.room_id,) + ) + full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map) logger.debug("%d full_conflicted_set entries", len(full_conflicted_set)) @@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto ) sorted_power_events = yield _reverse_topological_power_sort( - power_events, event_map, state_res_store, full_conflicted_set + room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one resolved_state = yield _iterative_auth_checks( + room_id, room_version, sorted_power_events, unconflicted_state, @@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto pl = resolved_state.get((EventTypes.PowerLevels, ""), None) leftover_events = yield _mainline_sort( - leftover_events, pl, event_map, state_res_store + room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") resolved_state = yield _iterative_auth_checks( - room_version, leftover_events, resolved_state, event_map, state_res_store + room_id, + room_version, + leftover_events, + resolved_state, + event_map, + state_res_store, ) logger.debug("resolved") @@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto @defer.inlineCallbacks -def _get_power_level_for_sender(event_id, event_map, state_res_store): +def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(event_id, event_map, state_res_store) + event = yield _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): pl = aev break @@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.Create, ""): if aev.content.get("creator") == event.sender: return 100 @@ -279,7 +305,7 @@ def _is_power_event(event): @defer.inlineCallbacks def _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event and its auth chain (that is in the auth diff) to the graph @@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph( Args: graph (dict[str, set[str]]): A map from event ID to the events auth event IDs + room_id (str): the room we are working in event_id (str): Event to add to the graph event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(eid, event_map, state_res_store) + event = yield _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph( @defer.inlineCallbacks -def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff): +def _reverse_topological_power_sort( + room_id, event_ids, event_map, state_res_store, auth_diff +): """Returns a list of the event_ids sorted by reverse topological ordering, and then by power level and origin_server_ts Args: + room_id (str): the room we are working in event_ids (list[str]): The events to sort event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ graph = {} for event_id in event_ids: yield _add_event_and_auth_chain_to_graph( - graph, event_id, event_map, state_res_store, auth_diff + graph, room_id, event_id, event_map, state_res_store, auth_diff ) event_to_pl = {} for event_id in graph: - pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store) + pl = yield _get_power_level_for_sender( + room_id, event_id, event_map, state_res_store + ) event_to_pl[event_id] = pl def _get_power_order(event_id): @@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_ @defer.inlineCallbacks def _iterative_auth_checks( - room_version, event_ids, base_state, event_map, state_res_store + room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the state as it goes along. Args: + room_id (str) room_version (str) event_ids (list[str]): Ordered list of events to apply auth checks to base_state (dict[tuple[str, str], str]): The set of state to start with @@ -370,7 +403,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if ev.rejected_reason is None: auth_events[(ev.type, ev.state_key)] = ev @@ -378,7 +411,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(ev_id, event_map, state_res_store) + ev = yield _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -400,11 +433,14 @@ def _iterative_auth_checks( @defer.inlineCallbacks -def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store): +def _mainline_sort( + room_id, event_ids, resolved_power_event_id, event_map, state_res_store +): """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id Args: + room_id (str): room we're working in event_ids (list[str]): Events to sort resolved_power_event_id (str): The final resolved power level event ID event_map (dict[str,FrozenEvent]) @@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor pl = resolved_power_event_id while pl: mainline.append(pl) - pl_ev = yield _get_event(pl, event_map, state_res_store) + pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event(aid, event_map, state_res_store) + ev = yield _get_event(room_id, aid, event_map, state_res_store) if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break @@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor Deferred[int] """ + room_id = event.room_id + # We do an iterative search, replacing `event with the power level in its # auth events (if any) while event: @@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event(aid, event_map, state_res_store) + aev = yield _get_event(room_id, aid, event_map, state_res_store) if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): event = aev break @@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor @defer.inlineCallbacks -def _get_event(event_id, event_map, state_res_store): +def _get_event(room_id, event_id, event_map, state_res_store): """Helper function to look up event in event_map, falling back to looking it up in the store Args: + room_id (str) event_id (str) event_map (dict[str,FrozenEvent]) state_res_store (StateResolutionStore) @@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store): if event_id not in event_map: events = yield state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) - return event_map[event_id] + event = event_map[event_id] + assert event is not None + if event.room_id != room_id: + raise Exception( + "In state res for room %s, event %s is in %s" + % (room_id, event_id, event.room_id) + ) + return event def lexicographical_topological_sort(graph, key): diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..0f341d3ac3 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, From 0b90fc6ed22e6ebb137041a1f5006f52cea081e4 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Fri, 13 Dec 2019 15:28:48 +0000 Subject: [PATCH 52/82] Document Shutdown Room admin API (#6541) --- changelog.d/6541.doc | 1 + docs/admin_api/shutdown_room.md | 72 +++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) create mode 100644 changelog.d/6541.doc create mode 100644 docs/admin_api/shutdown_room.md diff --git a/changelog.d/6541.doc b/changelog.d/6541.doc new file mode 100644 index 0000000000..c20029edc0 --- /dev/null +++ b/changelog.d/6541.doc @@ -0,0 +1 @@ +Document the Room Shutdown Admin API. \ No newline at end of file diff --git a/docs/admin_api/shutdown_room.md b/docs/admin_api/shutdown_room.md new file mode 100644 index 0000000000..54ce1cd234 --- /dev/null +++ b/docs/admin_api/shutdown_room.md @@ -0,0 +1,72 @@ +# Shutdown room API + +Shuts down a room, preventing new joins and moves local users and room aliases automatically +to a new room. The new room will be created with the user specified by the +`new_room_user_id` parameter as room administrator and will contain a message +explaining what happened. Users invited to the new room will have power level +-10 by default, and thus be unable to speak. The old room's power levels will be changed to +disallow any further invites or joins. + +The local server will only have the power to move local user and room aliases to +the new room. Users on other servers will be unaffected. + +## API + +You will need to authenticate with an access token for an admin user. + +### URL + +`POST /_synapse/admin/v1/shutdown_room/{room_id}` + +### URL Parameters + +* `room_id` - The ID of the room (e.g `!someroom:example.com`) + +### JSON Body Parameters + +* `new_room_user_id` - Required. A string representing the user ID of the user that will admin + the new room that all users in the old room will be moved to. +* `room_name` - Optional. A string representing the name of the room that new users will be + invited to. +* `message` - Optional. A string containing the first message that will be sent as + `new_room_user_id` in the new room. Ideally this will clearly convey why the + original room was shut down. + +If not specified, the default value of `room_name` is "Content Violation +Notification". The default value of `message` is "Sharing illegal content on +othis server is not permitted and rooms in violation will be blocked." + +### Response Parameters + +* `kicked_users` - An integer number representing the number of users that + were kicked. +* `failed_to_kick_users` - An integer number representing the number of users + that were not kicked. +* `local_aliases` - An array of strings representing the local aliases that were migrated from + the old room to the new. +* `new_room_id` - A string representing the room ID of the new room. + +## Example + +Request: + +``` +POST /_synapse/admin/v1/shutdown_room/!somebadroom%3Aexample.com + +{ + "new_room_user_id": "@someuser:example.com", + "room_name": "Content Violation Notification", + "message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service." +} +``` + +Response: + +``` +{ + "kicked_users": 5, + "failed_to_kick_users": 0, + "local_aliases": ["#badroom:example.com", "#evilsaloon:example.com], + "new_room_id": "!newroomid:example.com", +}, +``` From 9d173b312cc3ab170de3a7c58ac24778eddf93f6 Mon Sep 17 00:00:00 2001 From: Werner Sembach Date: Mon, 16 Dec 2019 13:12:40 +0100 Subject: [PATCH 53/82] Automatically delete empty groups/communities (#6453) Signed-off-by: Werner Sembach --- AUTHORS.rst | 3 ++ changelog.d/6453.feature | 1 + synapse/groups/groups_server.py | 5 ++++ .../56/nuke_empty_communities_from_db.sql | 29 +++++++++++++++++++ 4 files changed, 38 insertions(+) create mode 100644 changelog.d/6453.feature create mode 100644 synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql diff --git a/AUTHORS.rst b/AUTHORS.rst index b8b31a5b47..014f16d4a2 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -46,3 +46,6 @@ Joseph Weston Benjamin Saunders * Documentation improvements + +Werner Sembach + * Automatically remove a group/community when it is empty diff --git a/changelog.d/6453.feature b/changelog.d/6453.feature new file mode 100644 index 0000000000..e7bb801c6a --- /dev/null +++ b/changelog.d/6453.feature @@ -0,0 +1 @@ +Automatically delete empty groups/communities. diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 29e8ffc295..0ec9be3cb5 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -773,6 +773,11 @@ class GroupsServerHandler(object): if not self.hs.is_mine_id(user_id): yield self.store.maybe_delete_remote_profile_cache(user_id) + # Delete group if the last user has left + users = yield self.store.get_users_in_group(group_id, include_private=True) + if not users: + yield self.store.delete_group(group_id) + return {} @defer.inlineCallbacks diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql new file mode 100644 index 0000000000..4f24c1405d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql @@ -0,0 +1,29 @@ +/* Copyright 2019 Werner Sembach + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made. +DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); +DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id)); From bc7de87650c2646cdc388f06a2a5bb6f94fc5c5c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Dec 2019 12:26:28 +0000 Subject: [PATCH 54/82] Persist auth/state events at backwards extremities when we fetch them (#6526) The main point here is to make sure that the state returned by _get_state_in_room has been authed before we try to use it as state in the room. --- changelog.d/6526.bugfix | 1 + synapse/handlers/federation.py | 247 +++++++++++---------------------- synapse/util/async_helpers.py | 4 +- 3 files changed, 83 insertions(+), 169 deletions(-) create mode 100644 changelog.d/6526.bugfix diff --git a/changelog.d/6526.bugfix b/changelog.d/6526.bugfix new file mode 100644 index 0000000000..53214b0748 --- /dev/null +++ b/changelog.d/6526.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 1d39a9a4f5..3f480f2056 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -65,8 +65,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRes from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import UserID, get_domain_from_id -from synapse.util import batch_iter, unwrapFirstError -from synapse.util.async_helpers import Linearizer +from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.distributor import user_joined_room from synapse.util.retryutils import NotRetryingDestination from synapse.visibility import filter_events_for_server @@ -238,7 +237,6 @@ class FederationHandler(BaseHandler): return None state = None - auth_chain = [] # Get missing pdus if necessary. if not pdu.internal_metadata.is_outlier(): @@ -348,7 +346,6 @@ class FederationHandler(BaseHandler): # Calculate the state after each of the previous events, and # resolve them to find the correct state at the current event. - auth_chains = set() event_map = {event_id: pdu} try: # Get the state of the events we know about @@ -369,24 +366,14 @@ class FederationHandler(BaseHandler): "Requesting state at missing prev_event %s", event_id, ) - room_version = await self.store.get_room_version(room_id) - with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - ( - remote_state, - got_auth_chain, - ) = await self._get_state_for_room( + (remote_state, _,) = await self._get_state_for_room( origin, room_id, p, include_event_in_state=True ) - # XXX hrm I'm not convinced that duplicate events will compare - # for equality, so I'm not sure this does what the author - # hoped. - auth_chains.update(got_auth_chain) - remote_state_map = { (x.type, x.state_key): x.event_id for x in remote_state } @@ -395,6 +382,7 @@ class FederationHandler(BaseHandler): for x in remote_state: event_map[x.event_id] = x + room_version = await self.store.get_room_version(room_id) state_map = await resolve_events_with_store( room_id, room_version, @@ -416,7 +404,6 @@ class FederationHandler(BaseHandler): event_map.update(evs) state = [event_map[e] for e in six.itervalues(state_map)] - auth_chain = list(auth_chains) except Exception: logger.warning( "[%s %s] Error attempting to resolve state at missing " @@ -432,9 +419,7 @@ class FederationHandler(BaseHandler): affected=event_id, ) - await self._process_received_pdu( - origin, pdu, state=state, auth_chain=auth_chain - ) + await self._process_received_pdu(origin, pdu, state=state) async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): """ @@ -633,10 +618,7 @@ class FederationHandler(BaseHandler): ) -> Dict[str, EventBase]: """Fetch events from a remote destination, checking if we already have them. - Args: - destination - room_id - event_ids + Persists any events we don't already have as outliers. If we fail to fetch any of the events, a warning will be logged, and the event will be omitted from the result. Likewise, any events which turn out not to @@ -656,27 +638,15 @@ class FederationHandler(BaseHandler): room_id, ) - room_version = await self.store.get_room_version(room_id) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) - # XXX 20 requests at once? really? - for batch in batch_iter(missing_events, 20): - deferreds = [ - run_in_background( - self.federation_client.get_pdu, - destinations=[destination], - event_id=e_id, - room_version=room_version, - ) - for e_id in batch - ] - - res = await make_deferred_yieldable( - defer.DeferredList(deferreds, consumeErrors=True) - ) - - for success, result in res: - if success and result: - fetched_events[result.event_id] = result + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_events, allow_rejected=True)) + ) # check for events which were in the wrong room. # @@ -705,50 +675,26 @@ class FederationHandler(BaseHandler): return fetched_events - async def _process_received_pdu(self, origin, event, state, auth_chain): + async def _process_received_pdu( + self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]], + ): """ Called when we have a new pdu. We need to do auth checks and put it through the StateHandler. + + Args: + origin: server sending the event + + event: event to be persisted + + state: Normally None, but if we are handling a gap in the graph + (ie, we are missing one or more prev_events), the resolved state at the + event """ room_id = event.room_id event_id = event.event_id logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) - event_ids = set() - if state: - event_ids |= {e.event_id for e in state} - if auth_chain: - event_ids |= {e.event_id for e in auth_chain} - - seen_ids = await self.store.have_seen_events(event_ids) - - if state and auth_chain is not None: - # If we have any state or auth_chain given to us by the replication - # layer, then we should handle them (if we haven't before.) - - event_infos = [] - - for e in itertools.chain(auth_chain, state): - if e.event_id in seen_ids: - continue - e.internal_metadata.outlier = True - auth_ids = e.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in auth_chain - if e.event_id in auth_ids or e.type == EventTypes.Create - } - event_infos.append(_NewEventInfo(event=e, auth_events=auth)) - seen_ids.add(e.event_id) - - logger.info( - "[%s %s] persisting newly-received auth/state events %s", - room_id, - event_id, - [e.event.event_id for e in event_infos], - ) - await self._handle_new_events(origin, event_infos) - try: context = await self._handle_new_event(origin, event, state=state) except AuthError as e: @@ -803,8 +749,6 @@ class FederationHandler(BaseHandler): if dest == self.server_name: raise SynapseError(400, "Can't backfill from self.") - room_version = await self.store.get_room_version(room_id) - events = await self.federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -833,6 +777,9 @@ class FederationHandler(BaseHandler): event_ids = set(e.event_id for e in events) + # build a list of events whose prev_events weren't in the batch. + # (XXX: this will include events whose prev_events we already have; that doesn't + # sound right?) edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) @@ -861,95 +808,11 @@ class FederationHandler(BaseHandler): auth_events.update( {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} ) - missing_auth = required_auth - set(auth_events) - failed_to_fetch = set() - # Try and fetch any missing auth events from both DB and remote servers. - # We repeatedly do this until we stop finding new auth events. - while missing_auth - failed_to_fetch: - logger.info("Missing auth for backfill: %r", missing_auth) - ret_events = await self.store.get_events(missing_auth - failed_to_fetch) - auth_events.update(ret_events) - - required_auth.update( - a_id for event in ret_events.values() for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - if missing_auth - failed_to_fetch: - logger.info( - "Fetching missing auth for backfill: %r", - missing_auth - failed_to_fetch, - ) - - results = await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.federation_client.get_pdu, - [dest], - event_id, - room_version=room_version, - outlier=True, - timeout=10000, - ) - for event_id in missing_auth - failed_to_fetch - ], - consumeErrors=True, - ) - ).addErrback(unwrapFirstError) - auth_events.update({a.event_id: a for a in results if a}) - required_auth.update( - a_id - for event in results - if event - for a_id in event.auth_event_ids() - ) - missing_auth = required_auth - set(auth_events) - - failed_to_fetch = missing_auth - set(auth_events) - - seen_events = await self.store.have_seen_events( - set(auth_events.keys()) | set(state_events.keys()) - ) - - # We now have a chunk of events plus associated state and auth chain to - # persist. We do the persistence in two steps: - # 1. Auth events and state get persisted as outliers, plus the - # backward extremities get persisted (as non-outliers). - # 2. The rest of the events in the chunk get persisted one by one, as - # each one depends on the previous event for its state. - # - # The important thing is that events in the chunk get persisted as - # non-outliers, including when those events are also in the state or - # auth chain. Caution must therefore be taken to ensure that they are - # not accidentally marked as outliers. - - # Step 1a: persist auth events that *don't* appear in the chunk ev_infos = [] - for a in auth_events.values(): - # We only want to persist auth events as outliers that we haven't - # seen and aren't about to persist as part of the backfilled chunk. - if a.event_id in seen_events or a.event_id in event_map: - continue - a.internal_metadata.outlier = True - ev_infos.append( - _NewEventInfo( - event=a, - auth_events={ - ( - auth_events[a_id].type, - auth_events[a_id].state_key, - ): auth_events[a_id] - for a_id in a.auth_event_ids() - if a_id in auth_events - }, - ) - ) - - # Step 1b: persist the events in the chunk we fetched state for (i.e. - # the backwards extremities) as non-outliers. + # Step 1: persist the events in the chunk we fetched state for (i.e. + # the backwards extremities), with custom auth events and state for e_id in events_to_state: # For paranoia we ensure that these events are marked as # non-outliers @@ -1191,6 +1054,56 @@ class FederationHandler(BaseHandler): return False + async def _get_events_and_persist( + self, destination: str, room_id: str, events: Iterable[str] + ): + """Fetch the given events from a server, and persist them as outliers. + + Logs a warning if we can't find the given event. + """ + + room_version = await self.store.get_room_version(room_id) + + event_infos = [] + + async def get_event(event_id: str): + with nested_logging_context(event_id): + try: + event = await self.federation_client.get_pdu( + [destination], event_id, room_version, outlier=True, + ) + if event is None: + logger.warning( + "Server %s didn't return event %s", destination, event_id, + ) + return + + # recursively fetch the auth events for this event + auth_events = await self._get_events_from_store_or_dest( + destination, room_id, event.auth_event_ids() + ) + auth = {} + for auth_event_id in event.auth_event_ids(): + ae = auth_events.get(auth_event_id) + if ae: + auth[(ae.type, ae.state_key)] = ae + + event_infos.append(_NewEventInfo(event, None, auth)) + + except Exception as e: + logger.warning( + "Error fetching missing state/auth event %s: %s %s", + event_id, + type(e), + e, + ) + + await concurrently_execute(get_event, events, 5) + + await self._handle_new_events( + destination, event_infos, + ) + def _sanity_check_event(self, ev): """ Do some early sanity checks of a received event diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 5c4de2e69f..04b6abdc24 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit): Args: func (func): Function to execute, should return a deferred or coroutine. - args (list): List of arguments to pass to func, each invocation of func - gets a signle argument. + args (Iterable): List of arguments to pass to func, each invocation of func + gets a single argument. limit (int): Maximum number of conccurent executions. Returns: From 6920d88892e77aec787b6afc0e01e6e09dc36216 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 16 Dec 2019 13:14:37 +0000 Subject: [PATCH 55/82] Exclude rejected state events when calculating state at backwards extrems (#6527) This fixes a weird bug where, if you were determined enough, you could end up with a rejected event forming part of the state at a backwards-extremity. Authing that backwards extrem would then lead to us trying to pull the rejected event from the db (with allow_rejected=False), which would fail with a 404. --- changelog.d/6527.bugfix | 1 + synapse/handlers/federation.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6527.bugfix diff --git a/changelog.d/6527.bugfix b/changelog.d/6527.bugfix new file mode 100644 index 0000000000..53214b0748 --- /dev/null +++ b/changelog.d/6527.bugfix @@ -0,0 +1 @@ +Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. \ No newline at end of file diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3f480f2056..3fccccfecd 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -605,7 +605,7 @@ class FederationHandler(BaseHandler): remote_event = event_map.get(event_id) if not remote_event: raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state(): + if remote_event.is_state() and remote_event.rejected_reason is None: remote_state.append(remote_event) auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] From bfb95654c97a8d3aa164eff96ecc13755c1c326d Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Mon, 16 Dec 2019 16:11:55 +0000 Subject: [PATCH 56/82] Add option to allow profile queries without sharing a room (#6523) --- changelog.d/6523.feature | 1 + docs/sample_config.yaml | 7 +++++++ synapse/config/server.py | 13 +++++++++++++ synapse/handlers/profile.py | 6 +++++- tests/rest/client/v1/test_profile.py | 2 ++ 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6523.feature diff --git a/changelog.d/6523.feature b/changelog.d/6523.feature new file mode 100644 index 0000000000..798fa143df --- /dev/null +++ b/changelog.d/6523.feature @@ -0,0 +1 @@ +Add option `limit_profile_requests_to_users_who_share_rooms` to prevent requirement of a local user sharing a room with another user to query their profile information. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 4d44e631d1..1787248f53 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -54,6 +54,13 @@ pid_file: DATADIR/homeserver.pid # #require_auth_for_profile_requests: true +# Uncomment to require a user to share a room with another user in order +# to retrieve their profile information. Only checked on Client-Server +# requests. Profile requests from other servers should be checked by the +# requesting server. Defaults to 'false'. +# +#limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/config/server.py b/synapse/config/server.py index 50af858c76..38f6ff9edc 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -102,6 +102,12 @@ class ServerConfig(Config): "require_auth_for_profile_requests", False ) + # Whether to require sharing a room with a user to retrieve their + # profile data + self.limit_profile_requests_to_users_who_share_rooms = config.get( + "limit_profile_requests_to_users_who_share_rooms", False, + ) + if "restrict_public_rooms_to_local_users" in config and ( "allow_public_rooms_without_auth" in config or "allow_public_rooms_over_federation" in config @@ -621,6 +627,13 @@ class ServerConfig(Config): # #require_auth_for_profile_requests: true + # Uncomment to require a user to share a room with another user in order + # to retrieve their profile information. Only checked on Client-Server + # requests. Profile requests from other servers should be checked by the + # requesting server. Defaults to 'false'. + # + #limit_profile_requests_to_users_who_share_rooms: true + # If set to 'true', removes the need for authentication to access the server's # public rooms directory through the client API, meaning that anyone can # query the room directory. Defaults to 'false'. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 1e5a4613c9..f9579d69ee 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -295,12 +295,16 @@ class BaseProfileHandler(BaseHandler): be found to be in any room the server is in, and therefore the query is denied. """ + # Implementation of MSC1301: don't allow looking up profiles if the # requester isn't in the same room as the target. We expect requester to # be None when this function is called outside of a profile query, e.g. # when building a membership event. In this case, we must allow the # lookup. - if not self.hs.config.require_auth_for_profile_requests or not requester: + if ( + not self.hs.config.limit_profile_requests_to_users_who_share_rooms + or not requester + ): return # Always allow the user to query their own profile. diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 12c5e95cb5..8df58b4a63 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -237,6 +237,7 @@ class ProfilesRestrictedTestCase(unittest.HomeserverTestCase): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs @@ -309,6 +310,7 @@ class OwnProfileUnrestrictedTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): config = self.default_config() config["require_auth_for_profile_requests"] = True + config["limit_profile_requests_to_users_who_share_rooms"] = True self.hs = self.setup_test_homeserver(config=config) return self.hs From 3fbe5b7ec3abd2864d8a64893fa494e9651c430a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 16 Dec 2019 16:59:32 +0000 Subject: [PATCH 57/82] Add auth events as per spec. (#6556) Previously we tried to be clever and filter out some unnecessary event IDs to keep the auth chain small, but that had some annoying interactions with state res v2 so we stop doing that for now. --- changelog.d/6556.bugfix | 1 + synapse/api/auth.py | 103 ++++++++++++++-------------------------- 2 files changed, 36 insertions(+), 68 deletions(-) create mode 100644 changelog.d/6556.bugfix diff --git a/changelog.d/6556.bugfix b/changelog.d/6556.bugfix new file mode 100644 index 0000000000..e75639f5b4 --- /dev/null +++ b/changelog.d/6556.bugfix @@ -0,0 +1 @@ +Fix a cause of state resets in room versions 2 onwards. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 5d0b7d2801..9fd52a8c77 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import Dict, Tuple from six import itervalues @@ -25,13 +26,7 @@ from twisted.internet import defer import synapse.logging.opentracing as opentracing import synapse.types from synapse import event_auth -from synapse.api.constants import ( - EventTypes, - JoinRules, - LimitBlockingTypes, - Membership, - UserTypes, -) +from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes from synapse.api.errors import ( AuthError, Codes, @@ -513,71 +508,43 @@ class Auth(object): """ return self.store.is_server_admin(user) - @defer.inlineCallbacks - def compute_auth_events(self, event, current_state_ids, for_verification=False): + def compute_auth_events( + self, + event, + current_state_ids: Dict[Tuple[str, str], str], + for_verification: bool = False, + ): + """Given an event and current state return the list of event IDs used + to auth an event. + + If `for_verification` is False then only return auth events that + should be added to the event's `auth_events`. + + Returns: + defer.Deferred(list[str]): List of event IDs. + """ + if event.type == EventTypes.Create: - return [] + return defer.succeed([]) + + # Currently we ignore the `for_verification` flag even though there are + # some situations where we can drop particular auth events when adding + # to the event's `auth_events` (e.g. joins pointing to previous joins + # when room is publically joinable). Dropping event IDs has the + # advantage that the auth chain for the room grows slower, but we use + # the auth chain in state resolution v2 to order events, which means + # care must be taken if dropping events to ensure that it doesn't + # introduce undesirable "state reset" behaviour. + # + # All of which sounds a bit tricky so we don't bother for now. auth_ids = [] + for etype, state_key in event_auth.auth_types_for_event(event): + auth_ev_id = current_state_ids.get((etype, state_key)) + if auth_ev_id: + auth_ids.append(auth_ev_id) - key = (EventTypes.PowerLevels, "") - power_level_event_id = current_state_ids.get(key) - - if power_level_event_id: - auth_ids.append(power_level_event_id) - - key = (EventTypes.JoinRules, "") - join_rule_event_id = current_state_ids.get(key) - - key = (EventTypes.Member, event.sender) - member_event_id = current_state_ids.get(key) - - key = (EventTypes.Create, "") - create_event_id = current_state_ids.get(key) - if create_event_id: - auth_ids.append(create_event_id) - - if join_rule_event_id: - join_rule_event = yield self.store.get_event(join_rule_event_id) - join_rule = join_rule_event.content.get("join_rule") - is_public = join_rule == JoinRules.PUBLIC if join_rule else False - else: - is_public = False - - if event.type == EventTypes.Member: - e_type = event.content["membership"] - if e_type in [Membership.JOIN, Membership.INVITE]: - if join_rule_event_id: - auth_ids.append(join_rule_event_id) - - if e_type == Membership.JOIN: - if member_event_id and not is_public: - auth_ids.append(member_event_id) - else: - if member_event_id: - auth_ids.append(member_event_id) - - if for_verification: - key = (EventTypes.Member, event.state_key) - existing_event_id = current_state_ids.get(key) - if existing_event_id: - auth_ids.append(existing_event_id) - - if e_type == Membership.INVITE: - if "third_party_invite" in event.content: - key = ( - EventTypes.ThirdPartyInvite, - event.content["third_party_invite"]["signed"]["token"], - ) - third_party_invite_id = current_state_ids.get(key) - if third_party_invite_id: - auth_ids.append(third_party_invite_id) - elif member_event_id: - member_event = yield self.store.get_event(member_event_id) - if member_event.content["membership"] == Membership.JOIN: - auth_ids.append(member_event.event_id) - - return auth_ids + return defer.succeed(auth_ids) @defer.inlineCallbacks def check_can_change_room_list(self, room_id, user): From 02553901ce94461c6f140efc804443069b97f401 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 17 Dec 2019 11:44:32 +0000 Subject: [PATCH 58/82] Remove unused `get_pagination_rows` methods. (#6557) Remove unused get_pagination_rows methods --- changelog.d/6557.misc | 1 + synapse/handlers/account_data.py | 3 --- synapse/handlers/room.py | 12 ------------ synapse/handlers/typing.py | 3 --- 4 files changed, 1 insertion(+), 18 deletions(-) create mode 100644 changelog.d/6557.misc diff --git a/changelog.d/6557.misc b/changelog.d/6557.misc new file mode 100644 index 0000000000..80e7eaedb8 --- /dev/null +++ b/changelog.d/6557.misc @@ -0,0 +1 @@ +Remove unused `get_pagination_rows` methods from `EventSource` classes. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 20ec1ca01b..a8d3fbc6de 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -50,6 +50,3 @@ class AccountDataEventSource(object): ) return results, current_stream_id - - async def get_pagination_rows(self, user, config, key): - return [], config.to_id diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 22768e97ff..2d7925547d 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1008,15 +1008,3 @@ class RoomEventSource(object): def get_current_key_for_room(self, room_id): return self.store.get_room_events_max_id(room_id) - - @defer.inlineCallbacks - def get_pagination_rows(self, user, config, key): - events, next_key = yield self.store.paginate_room_events( - room_id=key, - from_key=config.from_key, - to_key=config.to_key, - direction=config.direction, - limit=config.limit, - ) - - return (events, next_key) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6f78454322..b635c339ed 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -317,6 +317,3 @@ class TypingNotificationEventSource(object): def get_current_key(self): return self.get_typing_handler()._latest_room_serial - - def get_pagination_rows(self, user, pagination_config, key): - return [], pagination_config.from_key From 2284eb3a533a2df04784df08da28e67d6588a5ea Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 18 Dec 2019 10:45:12 +0000 Subject: [PATCH 59/82] Add database config class (#6513) This encapsulates config for a given database and is the way to get new connections. --- changelog.d/6513.misc | 1 + scripts-dev/update_database | 9 +-- scripts/synapse_port_db | 58 ++++++-------- synapse/config/database.py | 78 +++++++++++++++---- synapse/handlers/presence.py | 2 +- synapse/server.py | 41 +--------- synapse/storage/_base.py | 2 +- synapse/storage/data_stores/__init__.py | 40 ++++++++-- .../storage/data_stores/main/client_ips.py | 2 +- synapse/storage/database.py | 45 ++++++++++- synapse/storage/engines/sqlite.py | 16 +++- synapse/storage/prepare_database.py | 7 +- tests/handlers/test_typing.py | 41 +++++----- tests/replication/slave/storage/_base.py | 6 +- tests/server.py | 51 ++++++------ tests/storage/test_appservice.py | 37 ++++++--- tests/storage/test_base.py | 14 ++-- tests/storage/test_registration.py | 1 - tests/utils.py | 43 ++++------ 19 files changed, 286 insertions(+), 208 deletions(-) create mode 100644 changelog.d/6513.misc diff --git a/changelog.d/6513.misc b/changelog.d/6513.misc new file mode 100644 index 0000000000..36700f5657 --- /dev/null +++ b/changelog.d/6513.misc @@ -0,0 +1 @@ +Remove all assumptions of there being a single phyiscal DB apart from the `synapse.config`. diff --git a/scripts-dev/update_database b/scripts-dev/update_database index 23017c21f8..1d62f0403a 100755 --- a/scripts-dev/update_database +++ b/scripts-dev/update_database @@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore -from synapse.storage.prepare_database import prepare_database logger = logging.getLogger("update_database") @@ -77,12 +76,8 @@ if __name__ == "__main__": # Instantiate and initialise the homeserver object. hs = MockHomeserver(config) - db_conn = hs.get_db_conn() - # Update the database to the latest schema. - prepare_database(db_conn, hs.database_engine, config=config) - db_conn.commit() - - # setup instantiates the store within the homeserver object. + # Setup instantiates the store within the homeserver object and updates the + # DB. hs.setup() store = hs.get_datastore() diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index e393a9b2f7..5b5368988c 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -30,6 +30,7 @@ import yaml from twisted.enterprise import adbapi from twisted.internet import defer, reactor +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import PreserveLoggingContext from synapse.storage._base import LoggingTransaction @@ -55,7 +56,7 @@ from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database from synapse.util import Clock @@ -165,23 +166,17 @@ class Store( class MockHomeserver: - def __init__(self, config, database_engine, db_conn, db_pool): - self.database_engine = database_engine - self.db_conn = db_conn - self.db_pool = db_pool + def __init__(self, config): self.clock = Clock(reactor) self.config = config self.hostname = config.server_name - def get_db_conn(self): - return self.db_conn - - def get_db_pool(self): - return self.db_pool - def get_clock(self): return self.clock + def get_reactor(self): + return reactor + class Porter(object): def __init__(self, **kwargs): @@ -445,45 +440,36 @@ class Porter(object): else: return - def setup_db(self, db_config, database_engine): - db_conn = database_engine.module.connect( - **{ - k: v - for k, v in db_config.get("args", {}).items() - if not k.startswith("cp_") - } - ) - - prepare_database(db_conn, database_engine, config=None) + def setup_db(self, db_config: DatabaseConnectionConfig, engine): + db_conn = make_conn(db_config, engine) + prepare_database(db_conn, engine, config=None) db_conn.commit() return db_conn @defer.inlineCallbacks - def build_db_store(self, config): + def build_db_store(self, db_config: DatabaseConnectionConfig): """Builds and returns a database store using the provided configuration. Args: - config: The database configuration, i.e. a dict following the structure of - the "database" section of Synapse's configuration file. + config: The database configuration Returns: The built Store object. """ - engine = create_engine(config) + self.progress.set_state("Preparing %s" % db_config.config["name"]) - self.progress.set_state("Preparing %s" % config["name"]) - conn = self.setup_db(config, engine) + engine = create_engine(db_config.config) + conn = self.setup_db(db_config, engine) - db_pool = adbapi.ConnectionPool(config["name"], **config["args"]) + hs = MockHomeserver(self.hs_config) - hs = MockHomeserver(self.hs_config, engine, conn, db_pool) - - store = Store(Database(hs), conn, hs) + store = Store(Database(hs, db_config, engine), conn, hs) yield store.db.runInteraction( - "%s_engine.check_database" % config["name"], engine.check_database, + "%s_engine.check_database" % db_config.config["name"], + engine.check_database, ) return store @@ -509,7 +495,11 @@ class Porter(object): @defer.inlineCallbacks def run(self): try: - self.sqlite_store = yield self.build_db_store(self.sqlite_config) + self.sqlite_store = yield self.build_db_store( + DatabaseConnectionConfig( + "master", self.sqlite_config, data_stores=["main"] + ) + ) # Check if all background updates are done, abort if not. updates_complete = ( @@ -524,7 +514,7 @@ class Porter(object): defer.returnValue(None) self.postgres_store = yield self.build_db_store( - self.hs_config.database_config + self.hs_config.get_single_database() ) yield self.run_background_updates_on_postgres() diff --git a/synapse/config/database.py b/synapse/config/database.py index 0e2509f0b1..5f2f3c7cfd 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -12,12 +12,43 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from textwrap import indent +from typing import List import yaml -from ._base import Config +from synapse.config._base import Config, ConfigError + +logger = logging.getLogger(__name__) + + +class DatabaseConnectionConfig: + """Contains the connection config for a particular database. + + Args: + name: A label for the database, used for logging. + db_config: The config for a particular database, as per `database` + section of main config. Has two fields: `name` for database + module name, and `args` for the args to give to the database + connector. + data_stores: The list of data stores that should be provisioned on the + database. + """ + + def __init__(self, name: str, db_config: dict, data_stores: List[str]): + if db_config["name"] not in ("sqlite3", "psycopg2"): + raise ConfigError("Unsupported database type %r" % (db_config["name"],)) + + if db_config["name"] == "sqlite3": + db_config.setdefault("args", {}).update( + {"cp_min": 1, "cp_max": 1, "check_same_thread": False} + ) + + self.name = name + self.config = db_config + self.data_stores = data_stores class DatabaseConfig(Config): @@ -26,20 +57,14 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - self.database_config = config.get("database") + database_config = config.get("database") - if self.database_config is None: - self.database_config = {"name": "sqlite3", "args": {}} + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} - name = self.database_config.get("name", None) - if name == "psycopg2": - pass - elif name == "sqlite3": - self.database_config.setdefault("args", {}).update( - {"cp_min": 1, "cp_max": 1, "check_same_thread": False} - ) - else: - raise RuntimeError("Unsupported database type '%s'" % (name,)) + self.databases = [ + DatabaseConnectionConfig("master", database_config, data_stores=["main"]) + ] self.set_databasepath(config.get("database_path")) @@ -76,11 +101,24 @@ class DatabaseConfig(Config): self.set_databasepath(args.database_path) def set_databasepath(self, database_path): + if database_path is None: + return + if database_path != ":memory:": database_path = self.abspath(database_path) - if self.database_config.get("name", None) == "sqlite3": - if database_path is not None: - self.database_config["args"]["database"] = database_path + + # We only support setting a database path if we have a single sqlite3 + # database. + if len(self.databases) != 1: + raise ConfigError("Cannot specify 'database_path' with multiple databases") + + database = self.get_single_database() + if database.config["name"] != "sqlite3": + # We don't raise here as we haven't done so before for this case. + logger.warn("Ignoring 'database_path' for non-sqlite3 database") + return + + database.config["args"]["database"] = database_path @staticmethod def add_arguments(parser): @@ -91,3 +129,11 @@ class DatabaseConfig(Config): metavar="SQLITE_DATABASE_PATH", help="The path to a sqlite database to use.", ) + + def get_single_database(self) -> DatabaseConnectionConfig: + """Returns the database if there is only one, useful for e.g. tests + """ + if len(self.databases) != 1: + raise Exception("More than one database exists") + + return self.databases[0] diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index eda15bc623..240c4add12 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -230,7 +230,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.store.database.is_running(): return logger.info( diff --git a/synapse/server.py b/synapse/server.py index 5021068ce0..7926867b77 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -25,7 +25,6 @@ import abc import logging import os -from twisted.enterprise import adbapi from twisted.mail.smtp import sendmail from twisted.web.client import BrowserLikePolicyForHTTPS @@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import ( ) from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import DataStores, Storage -from synapse.storage.engines import create_engine from synapse.streams.events import EventSources from synapse.util import Clock from synapse.util.distributor import Distributor @@ -134,7 +132,6 @@ class HomeServer(object): DEPENDENCIES = [ "http_client", - "db_pool", "federation_client", "federation_server", "handlers", @@ -233,12 +230,6 @@ class HomeServer(object): self.admin_redaction_ratelimiter = Ratelimiter() self.registration_ratelimiter = Ratelimiter() - self.database_engine = create_engine(config.database_config) - config.database_config.setdefault("args", {})[ - "cp_openfun" - ] = self.database_engine.on_new_connection - self.db_config = config.database_config - self.datastores = None # Other kwargs are explicit dependencies @@ -247,10 +238,8 @@ class HomeServer(object): def setup(self): logger.info("Setting up.") - with self.get_db_conn() as conn: - self.datastores = DataStores(self.DATASTORE_CLASS, conn, self) - conn.commit() self.start_time = int(self.get_clock().time()) + self.datastores = DataStores(self.DATASTORE_CLASS, self) logger.info("Finished setting up.") def setup_master(self): @@ -284,6 +273,9 @@ class HomeServer(object): def get_datastore(self): return self.datastores.main + def get_datastores(self): + return self.datastores + def get_config(self): return self.config @@ -433,31 +425,6 @@ class HomeServer(object): ) return MatrixFederationHttpClient(self, tls_client_options_factory) - def build_db_pool(self): - name = self.db_config["name"] - - return adbapi.ConnectionPool( - name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {}) - ) - - def get_db_conn(self, run_new_connection=True): - """Makes a new connection to the database, skipping the db pool - - Returns: - Connection: a connection object implementing the PEP-249 spec - """ - # Any param beginning with cp_ is a parameter for adbapi, and should - # not be passed to the database engine. - db_params = { - k: v - for k, v in self.db_config.get("args", {}).items() - if not k.startswith("cp_") - } - db_conn = self.database_engine.module.connect(**db_params) - if run_new_connection: - self.database_engine.on_new_connection(db_conn) - return db_conn - def build_media_repository_resource(self): # build the media repo resource. This indirects through the HomeServer # to ensure that we only have a single instance of diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b7637b5dc0..88546ad614 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -40,7 +40,7 @@ class SQLBaseStore(object): def __init__(self, database: Database, db_conn, hs): self.hs = hs self._clock = hs.get_clock() - self.database_engine = hs.database_engine + self.database_engine = database.engine self.db = database self.rand = random.SystemRandom() diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index cafedd5c0d..0983e059c0 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -13,24 +13,50 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import Database +import logging + +from synapse.storage.database import Database, make_conn +from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database +logger = logging.getLogger(__name__) + class DataStores(object): """The various data stores. These are low level interfaces to physical databases. + + Attributes: + main (DataStore) """ - def __init__(self, main_store_class, db_conn, hs): + def __init__(self, main_store_class, hs): # Note we pass in the main store class here as workers use a different main # store. - database = Database(hs) - # Check that db is correctly configured. - database.engine.check_database(db_conn.cursor()) + self.databases = [] - prepare_database(db_conn, database.engine, config=hs.config) + for database_config in hs.config.database.databases: + db_name = database_config.name + engine = create_engine(database_config.config) - self.main = main_store_class(database, db_conn, hs) + with make_conn(database_config, engine) as db_conn: + logger.info("Preparing database %r...", db_name) + + engine.check_database(db_conn.cursor()) + prepare_database( + db_conn, engine, hs.config, data_stores=database_config.data_stores, + ) + + database = Database(hs, database_config, engine) + + if "main" in database_config.data_stores: + logger.info("Starting 'main' data store") + self.main = main_store_class(database, db_conn, hs) + + db_conn.commit() + + self.databases.append(database) + + logger.info("Database %r prepared", db_name) diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index add3037b69..13f4c9c72e 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -412,7 +412,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def _update_client_ips_batch(self): # If the DB pool has already terminated, don't try updating - if not self.hs.get_db_pool().running: + if not self.db.is_running(): return to_update = self._batch_row_update diff --git a/synapse/storage/database.py b/synapse/storage/database.py index ec19ae1d9d..1003dd84a5 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -24,9 +24,11 @@ from six.moves import intern, range from prometheus_client import Histogram +from twisted.enterprise import adbapi from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.config.database import DatabaseConnectionConfig from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater @@ -74,6 +76,37 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = { } +def make_pool( + reactor, db_config: DatabaseConnectionConfig, engine +) -> adbapi.ConnectionPool: + """Get the connection pool for the database. + """ + + return adbapi.ConnectionPool( + db_config.config["name"], + cp_reactor=reactor, + cp_openfun=engine.on_new_connection, + **db_config.config.get("args", {}) + ) + + +def make_conn(db_config: DatabaseConnectionConfig, engine): + """Make a new connection to the database and return it. + + Returns: + Connection + """ + + db_params = { + k: v + for k, v in db_config.config.get("args", {}).items() + if not k.startswith("cp_") + } + db_conn = engine.module.connect(**db_params) + engine.on_new_connection(db_conn) + return db_conn + + class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() @@ -218,10 +251,11 @@ class Database(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, hs, database_config: DatabaseConnectionConfig, engine): self.hs = hs self._clock = hs.get_clock() - self._db_pool = hs.get_db_pool() + self._database_config = database_config + self._db_pool = make_pool(hs.get_reactor(), database_config, engine) self.updates = BackgroundUpdater(hs, self) @@ -234,7 +268,7 @@ class Database(object): # to watch it self._txn_perf_counters = PerformanceCounters() - self.engine = hs.database_engine + self.engine = engine # A set of tables that are not safe to use native upserts in. self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys()) @@ -255,6 +289,11 @@ class Database(object): self._check_safe_to_upsert, ) + def is_running(self): + """Is the database pool currently running + """ + return self._db_pool.running + @defer.inlineCallbacks def _check_safe_to_upsert(self): """ diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index ddad17dc5a..df039a072d 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -16,8 +16,6 @@ import struct import threading -from synapse.storage.prepare_database import prepare_database - class Sqlite3Engine(object): single_threaded = True @@ -25,6 +23,9 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) + # The current max state_group, or None if we haven't looked # in the DB yet. self._current_state_group_id = None @@ -59,7 +60,16 @@ class Sqlite3Engine(object): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) def is_deadlock(self, error): diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 731e1c9d9c..b4194b44ee 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -41,7 +41,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config): +def prepare_database(db_conn, database_engine, config, data_stores=["main"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. @@ -54,11 +54,10 @@ def prepare_database(db_conn, database_engine, config): config (synapse.config.homeserver.HomeServerConfig|None): application config, or None if we are connecting to an existing database which we expect to be configured already + data_stores (list[str]): The name of the data stores that will be used + with this database. Defaults to all data stores. """ - # For now we only have the one datastore. - data_stores = ["main"] - try: cur = db_conn.cursor() version_info = _get_or_create_schema_state(cur, database_engine) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 92b8726093..596ddc6970 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -64,28 +64,29 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): mock_federation_client = Mock(spec=["put_json"]) mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) - hs = self.setup_test_homeserver( - datastore=( - Mock( - spec=[ - # Bits that Federation needs - "prep_send_transaction", - "delivered_txn", - "get_received_txn_response", - "set_received_txn_response", - "get_destination_retry_timings", - "get_device_updates_by_remote", - # Bits that user_directory needs - "get_user_directory_stream_pos", - "get_current_state_deltas", - ] - ) - ), - notifier=Mock(), - http_client=mock_federation_client, - keyring=mock_keyring, + datastores = Mock() + datastores.main = Mock( + spec=[ + # Bits that Federation needs + "prep_send_transaction", + "delivered_txn", + "get_received_txn_response", + "set_received_txn_response", + "get_destination_retry_timings", + "get_devices_by_remote", + # Bits that user_directory needs + "get_user_directory_stream_pos", + "get_current_state_deltas", + "get_device_updates_by_remote", + ] ) + hs = self.setup_test_homeserver( + notifier=Mock(), http_client=mock_federation_client, keyring=mock_keyring + ) + + hs.datastores = datastores + return hs def prepare(self, reactor, clock, hs): diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 3dae83c543..2a1e7c7166 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -20,7 +20,7 @@ from synapse.replication.tcp.client import ( ReplicationClientHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory -from synapse.storage.database import Database +from synapse.storage.database import make_conn from tests import unittest from tests.server import FakeTransport @@ -41,10 +41,12 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): + db_config = hs.config.database.get_single_database() self.master_store = self.hs.get_datastore() self.storage = hs.get_storage() + database = hs.get_datastores().databases[0] self.slaved_store = self.STORE_TYPE( - Database(hs), self.hs.get_db_conn(), self.hs + database, make_conn(db_config, database.engine), self.hs ) self.event_id = 0 diff --git a/tests/server.py b/tests/server.py index 2b7cf4242e..a554dfdd57 100644 --- a/tests/server.py +++ b/tests/server.py @@ -302,41 +302,42 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs): Set up a synchronous test server, driven by the reactor used by the homeserver. """ - d = _sth(cleanup_func, *args, **kwargs).result + server = _sth(cleanup_func, *args, **kwargs) - if isinstance(d, Failure): - d.raiseException() + database = server.config.database.get_single_database() # Make the thread pool synchronous. - clock = d.get_clock() - pool = d.get_db_pool() + clock = server.get_clock() - def runWithConnection(func, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runWithConnection, - func, - *args, - **kwargs - ) + for database in server.get_datastores().databases: + pool = database._db_pool - def runInteraction(interaction, *args, **kwargs): - return threads.deferToThreadPool( - pool._reactor, - pool.threadpool, - pool._runInteraction, - interaction, - *args, - **kwargs - ) + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) - if pool: pool.runWithConnection = runWithConnection pool.runInteraction = runInteraction pool.threadpool = ThreadPool(clock._reactor) pool.running = True - return d + + return server def get_clock(): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 2e521e9ab7..fd52512696 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -28,7 +28,7 @@ from synapse.storage.data_stores.main.appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore, ) -from synapse.storage.database import Database +from synapse.storage.database import Database, make_conn from tests import unittest from tests.utils import setup_test_homeserver @@ -55,8 +55,10 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob") self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts - database = Database(hs) - self.store = ApplicationServiceStore(database, hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + self.store = ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) def tearDown(self): # TODO: suboptimal that we need to create files for tests! @@ -111,9 +113,6 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - self.db_pool = hs.get_db_pool() - self.engine = hs.database_engine - self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, {"token": "alpha_tok", "url": "https://alpha.com", "id": "id_alpha"}, @@ -125,8 +124,15 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.as_yaml_files = [] - database = Database(hs) - self.store = TestTransactionStore(database, hs.get_db_conn(), hs) + # We assume there is only one database in these tests + database = hs.get_datastores().databases[0] + self.db_pool = database._db_pool + self.engine = database.engine + + db_config = hs.config.get_single_database() + self.store = TestTransactionStore( + database, make_conn(db_config, self.engine), hs + ) def _add_service(self, url, as_token, id): as_yaml = dict( @@ -419,7 +425,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.event_cache_size = 1 hs.config.password_providers = [] - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) @defer.inlineCallbacks def test_duplicate_ids(self): @@ -435,7 +444,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) @@ -456,7 +468,10 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: - ApplicationServiceStore(Database(hs), hs.get_db_conn(), hs) + database = hs.get_datastores().databases[0] + ApplicationServiceStore( + database, make_conn(database._database_config, database.engine), hs + ) e = cm.exception self.assertIn(f1, str(e)) diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 537cfe9f64..cdee0a9e60 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -52,15 +52,17 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True config.event_cache_size = 1 - config.database_config = {"name": "sqlite3"} - engine = create_engine(config.database_config) + hs = TestHomeServer("test", config=config) + + sqlite_config = {"name": "sqlite3"} + engine = create_engine(sqlite_config) fake_engine = Mock(wraps=engine) fake_engine.can_native_upsert = False - hs = TestHomeServer( - "test", db_pool=self.db_pool, config=config, database_engine=fake_engine - ) - self.datastore = SQLBaseStore(Database(hs), None, hs) + db = Database(Mock(), Mock(config=sqlite_config), fake_engine) + db._db_pool = self.db_pool + + self.datastore = SQLBaseStore(db, None, hs) @defer.inlineCallbacks def test_insert_1col(self): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 4578cc3b60..ed5786865a 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -26,7 +26,6 @@ class RegistrationStoreTestCase(unittest.TestCase): @defer.inlineCallbacks def setUp(self): hs = yield setup_test_homeserver(self.addCleanup) - self.db_pool = hs.get_db_pool() self.store = hs.get_datastore() diff --git a/tests/utils.py b/tests/utils.py index 585f305b9a..9f5bf40b4b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,6 +30,7 @@ from twisted.internet import defer, reactor from synapse.api.constants import EventTypes from synapse.api.errors import CodeMessageException, cs_error from synapse.api.room_versions import RoomVersions +from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server @@ -177,7 +178,6 @@ class TestHomeServer(HomeServer): DATASTORE_CLASS = DataStore -@defer.inlineCallbacks def setup_test_homeserver( cleanup_func, name="test", @@ -214,7 +214,7 @@ def setup_test_homeserver( if USE_POSTGRES_FOR_TESTS: test_db = "synapse_test_%s" % uuid.uuid4().hex - config.database_config = { + database_config = { "name": "psycopg2", "args": { "database": test_db, @@ -226,12 +226,15 @@ def setup_test_homeserver( }, } else: - config.database_config = { + database_config = { "name": "sqlite3", "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - db_engine = create_engine(config.database_config) + database = DatabaseConnectionConfig("master", database_config, ["main"]) + config.database.databases = [database] + + db_engine = create_engine(database.config) # Create the database before we actually try and connect to it, based off # the template database we generate in setupdb() @@ -251,11 +254,6 @@ def setup_test_homeserver( cur.close() db_conn.close() - # we need to configure the connection pool to run the on_new_connection - # function, so that we can test code that uses custom sqlite functions - # (like rank). - config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection - if datastore is None: hs = homeserverToUse( name, @@ -267,21 +265,19 @@ def setup_test_homeserver( **kargs ) - # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to - # date db - if not isinstance(db_engine, PostgresEngine): - db_conn = hs.get_db_conn() - yield prepare_database(db_conn, db_engine, config) - db_conn.commit() - db_conn.close() + hs.setup() + if homeserverToUse.__name__ == "TestHomeServer": + hs.setup_master() + + if isinstance(db_engine, PostgresEngine): + database = hs.get_datastores().databases[0] - else: # We need to do cleanup on PostgreSQL def cleanup(): import psycopg2 # Close all the db pools - hs.get_db_pool().close() + database._db_pool.close() dropped = False @@ -320,23 +316,12 @@ def setup_test_homeserver( # Register the cleanup hook cleanup_func(cleanup) - hs.setup() - if homeserverToUse.__name__ == "TestHomeServer": - hs.setup_master() else: - # If we have been given an explicit datastore we probably want to mock - # out the DataStores somehow too. This all feels a bit wrong, but then - # mocking the stores feels wrong too. - datastores = Mock(datastore=datastore) - hs = homeserverToUse( name, - db_pool=None, datastore=datastore, - datastores=datastores, config=config, version_string="Synapse/tests", - database_engine=db_engine, tls_server_context_factory=Mock(), tls_client_options_factory=Mock(), reactor=reactor, From 7963ca83cbefb782a94c47fd65ad6e94d05dc5d1 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 18 Dec 2019 11:13:33 +0000 Subject: [PATCH 60/82] Add delta file to fix missing default table data (#6555) --- changelog.d/6555.bugfix | 1 + .../storage/data_stores/main/deviceinbox.py | 17 ++-------------- .../delta/56/device_stream_id_insert.sql | 20 +++++++++++++++++++ .../full_schemas/54/stream_positions.sql | 1 + 4 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 changelog.d/6555.bugfix create mode 100644 synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql diff --git a/changelog.d/6555.bugfix b/changelog.d/6555.bugfix new file mode 100644 index 0000000000..86a5a56cf6 --- /dev/null +++ b/changelog.d/6555.bugfix @@ -0,0 +1 @@ +Fix missing row in device_max_stream_id that could cause unable to decrypt errors after server restart. \ No newline at end of file diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py index 85cfa16850..0613b49f4a 100644 --- a/synapse/storage/data_stores/main/deviceinbox.py +++ b/synapse/storage/data_stores/main/deviceinbox.py @@ -358,21 +358,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): - # Compatible method of performing an upsert - sql = "SELECT stream_id FROM device_max_stream_id" - - txn.execute(sql) - rows = txn.fetchone() - if rows: - db_stream_id = rows[0] - if db_stream_id < stream_id: - # Insert the new stream_id - sql = "UPDATE device_max_stream_id SET stream_id = ?" - else: - # No rows, perform an insert - sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)" - - txn.execute(sql, (stream_id,)) + sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?" + txn.execute(sql, (stream_id, stream_id)) local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql new file mode 100644 index 0000000000..c2f557fde9 --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql @@ -0,0 +1,20 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- This line already existed in deltas/35/device_stream_id but was not included in the +-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist +INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS ( + SELECT * from device_max_stream_id +); \ No newline at end of file diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql index c265fd20e2..91d21b2921 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql @@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales INSERT INTO user_directory_stream_pos (stream_id) VALUES (0); INSERT INTO stats_stream_pos (stream_id) VALUES (0); INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); +-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql \ No newline at end of file From d6752ce5da38d35857fe324800d76a86ee1e64f1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 18 Dec 2019 14:26:58 +0000 Subject: [PATCH 61/82] Clean up startup for the pusher (#6558) * Remove redundant python2 support code `str.decode()` doesn't exist on python3, so presumably this code was doing nothing * Filter out pushers with corrupt data When we get a row with unparsable json, drop the row, rather than returning a row with null `data`, which will then cause an explosion later on. * Improve logging when we can't start a pusher Log the ID to help us understand the problem * Make email pusher setup more robust We know we'll have a `data` member, since that comes from the database. What we *don't* know is if that is a dict, and if that has a `brand` member, and if that member is a string. --- changelog.d/6558.misc | 1 + synapse/push/pusher.py | 12 +++++---- synapse/push/pusherpool.py | 10 ++++--- synapse/rest/client/v1/pusher.py | 31 +++++++++++----------- synapse/storage/data_stores/main/pusher.py | 25 ++++++----------- tests/push/test_email.py | 3 +++ tests/push/test_http.py | 4 +++ 7 files changed, 44 insertions(+), 42 deletions(-) create mode 100644 changelog.d/6558.misc diff --git a/changelog.d/6558.misc b/changelog.d/6558.misc new file mode 100644 index 0000000000..a7572f1a85 --- /dev/null +++ b/changelog.d/6558.misc @@ -0,0 +1 @@ +Clean up logs from the push notifier at startup. \ No newline at end of file diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index f277aeb131..8ad0bf5936 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -80,9 +80,11 @@ class PusherFactory(object): return EmailPusher(self.hs, pusherdict, mailer) def _app_name_from_pusherdict(self, pusherdict): - if "data" in pusherdict and "brand" in pusherdict["data"]: - app_name = pusherdict["data"]["brand"] - else: - app_name = self.config.email_app_name + data = pusherdict["data"] - return app_name + if isinstance(data, dict): + brand = data.get("brand") + if isinstance(brand, str): + return brand + + return self.config.email_app_name diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 0f6992202d..b9dca5bc63 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -232,7 +232,6 @@ class PusherPool: Deferred """ pushers = yield self.store.get_all_pushers() - logger.info("Starting %d pushers", len(pushers)) # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -245,7 +244,7 @@ class PusherPool: """Start the given pusher Args: - pusherdict (dict): + pusherdict (dict): dict with the values pulled from the db table Returns: Deferred[EmailPusher|HttpPusher] @@ -254,7 +253,8 @@ class PusherPool: p = self.pusher_factory.create_pusher(pusherdict) except PusherConfigException as e: logger.warning( - "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s", + "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s", + pusherdict["id"], pusherdict.get("user_name"), pusherdict.get("app_id"), pusherdict.get("pushkey"), @@ -262,7 +262,9 @@ class PusherPool: ) return except Exception: - logger.exception("Couldn't start a pusher: caught Exception") + logger.exception( + "Couldn't start pusher id %i: caught Exception", pusherdict["id"], + ) return if not p: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 0791866f55..6f6b7aed6e 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -28,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) +ALLOWED_KEYS = { + "app_display_name", + "app_id", + "data", + "device_display_name", + "kind", + "lang", + "profile_tag", + "pushkey", +} + class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) @@ -43,23 +54,11 @@ class PushersRestServlet(RestServlet): pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string()) - allowed_keys = [ - "app_display_name", - "app_id", - "data", - "device_display_name", - "kind", - "lang", - "profile_tag", - "pushkey", - ] + filtered_pushers = list( + {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers + ) - for p in pushers: - for k, v in list(p.items()): - if k not in allowed_keys: - del p[k] - - return 200, {"pushers": pushers} + return 200, {"pushers": filtered_pushers} def on_OPTIONS(self, _): return 200, {} diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py index f07309ef09..6b03233262 100644 --- a/synapse/storage/data_stores/main/pusher.py +++ b/synapse/storage/data_stores/main/pusher.py @@ -15,8 +15,7 @@ # limitations under the License. import logging - -import six +from typing import Iterable, Iterator from canonicaljson import encode_canonical_json, json @@ -27,21 +26,16 @@ from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList logger = logging.getLogger(__name__) -if six.PY2: - db_binary_type = six.moves.builtins.buffer -else: - db_binary_type = memoryview - class PusherWorkerStore(SQLBaseStore): - def _decode_pushers_rows(self, rows): + def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]: + """JSON-decode the data in the rows returned from the `pushers` table + + Drops any rows whose data cannot be decoded + """ for r in rows: dataJson = r["data"] - r["data"] = None try: - if isinstance(dataJson, db_binary_type): - dataJson = str(dataJson).decode("UTF8") - r["data"] = json.loads(dataJson) except Exception as e: logger.warning( @@ -50,12 +44,9 @@ class PusherWorkerStore(SQLBaseStore): dataJson, e.args[0], ) - pass + continue - if isinstance(r["pushkey"], db_binary_type): - r["pushkey"] = str(r["pushkey"]).decode("UTF8") - - return rows + yield r @defer.inlineCallbacks def user_has_pusher(self, user_id): diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 358b593cd4..80187406bc 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -165,6 +165,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -175,6 +176,7 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -192,5 +194,6 @@ class EmailPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=self.user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af2327fb66..fe3441f081 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -104,6 +104,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -114,6 +115,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"]) @@ -132,6 +134,7 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) last_stream_ordering = pushers[0]["last_stream_ordering"] @@ -151,5 +154,6 @@ class HTTPPusherTests(HomeserverTestCase): pushers = self.get_success( self.hs.get_datastore().get_pushers_by(dict(user_name=user_id)) ) + pushers = list(pushers) self.assertEqual(len(pushers), 1) self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering) From b95b762560441b28f06e6458da796327e394953e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Dec 2019 11:11:14 +0000 Subject: [PATCH 62/82] Add an export_signing_key script (#6546) I want to do some key rotation, and it is silly that we don't have a way to do this. --- changelog.d/6546.feature | 1 + docs/code_style.md | 13 +++--- docs/sample_config.yaml | 19 +++++--- scripts/export_signing_key | 94 ++++++++++++++++++++++++++++++++++++++ synapse/config/key.py | 23 ++++++---- 5 files changed, 129 insertions(+), 21 deletions(-) create mode 100644 changelog.d/6546.feature create mode 100755 scripts/export_signing_key diff --git a/changelog.d/6546.feature b/changelog.d/6546.feature new file mode 100644 index 0000000000..954aacb0d0 --- /dev/null +++ b/changelog.d/6546.feature @@ -0,0 +1 @@ +Add an export_signing_key script to extract the public part of signing keys when rotating them. diff --git a/docs/code_style.md b/docs/code_style.md index f983f72d6c..71aecd41f7 100644 --- a/docs/code_style.md +++ b/docs/code_style.md @@ -137,6 +137,7 @@ Some guidelines follow: correctly handles the top-level option being set to `None` (as it will be if no sub-options are enabled). - Lines should be wrapped at 80 characters. +- Use two-space indents. Example: @@ -155,13 +156,13 @@ Example: # Settings for the frobber # frobber: - # frobbing speed. Defaults to 1. - # - #speed: 10 + # frobbing speed. Defaults to 1. + # + #speed: 10 - # frobbing distance. Defaults to 1000. - # - #distance: 100 + # frobbing distance. Defaults to 1000. + # + #distance: 100 Note that the sample configuration is generated from the synapse code and is maintained by a script, `scripts-dev/generate_sample_config`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1787248f53..e3b05423b8 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1122,14 +1122,19 @@ metrics_flags: signing_key_path: "CONFDIR/SERVERNAME.signing.key" # The keys that the server used to sign messages with but won't use -# to sign new messages. E.g. it has lost its private key +# to sign new messages. # -#old_signing_keys: -# "ed25519:auto": -# # Base64 encoded public key -# key: "The public part of your old signing key." -# # Millisecond POSIX timestamp when the key expired. -# expired_ts: 123456789123 +old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. diff --git a/scripts/export_signing_key b/scripts/export_signing_key new file mode 100755 index 0000000000..8aec9d802b --- /dev/null +++ b/scripts/export_signing_key @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import sys +import time +from typing import Optional + +import nacl.signing +from signedjson.key import encode_verify_key_base64, get_verify_key, read_signing_keys + + +def exit(status: int = 0, message: Optional[str] = None): + if message: + print(message, file=sys.stderr) + sys.exit(status) + + +def format_plain(public_key: nacl.signing.VerifyKey): + print( + "%s:%s %s" + % (public_key.alg, public_key.version, encode_verify_key_base64(public_key),) + ) + + +def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int): + print( + ' "%s:%s": { key: "%s", expired_ts: %i }' + % ( + public_key.alg, + public_key.version, + encode_verify_key_base64(public_key), + expiry_ts, + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "key_file", nargs="+", type=argparse.FileType("r"), help="The key file to read", + ) + + parser.add_argument( + "-x", + action="store_true", + dest="for_config", + help="format the output for inclusion in the old_signing_keys config setting", + ) + + parser.add_argument( + "--expiry-ts", + type=int, + default=int(time.time() * 1000) + 6*3600000, + help=( + "The expiry time to use for -x, in milliseconds since 1970. The default " + "is (now+6h)." + ), + ) + + args = parser.parse_args() + + formatter = ( + (lambda k: format_for_config(k, args.expiry_ts)) + if args.for_config + else format_plain + ) + + keys = [] + for file in args.key_file: + try: + res = read_signing_keys(file) + except Exception as e: + exit( + status=1, + message="Error reading key from file %s: %s %s" + % (file.name, type(e), e), + ) + res = [] + for key in res: + formatter(get_verify_key(key)) diff --git a/synapse/config/key.py b/synapse/config/key.py index 52ff1b2621..066e7838c3 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -108,7 +108,7 @@ class KeyConfig(Config): self.signing_key = self.read_signing_keys(signing_key_path, "signing_key") self.old_signing_keys = self.read_old_signing_keys( - config.get("old_signing_keys", {}) + config.get("old_signing_keys") ) self.key_refresh_interval = self.parse_duration( config.get("key_refresh_interval", "1d") @@ -199,14 +199,19 @@ class KeyConfig(Config): signing_key_path: "%(base_key_name)s.signing.key" # The keys that the server used to sign messages with but won't use - # to sign new messages. E.g. it has lost its private key + # to sign new messages. # - #old_signing_keys: - # "ed25519:auto": - # # Base64 encoded public key - # key: "The public part of your old signing key." - # # Millisecond POSIX timestamp when the key expired. - # expired_ts: 123456789123 + old_signing_keys: + # For each key, `key` should be the base64-encoded public key, and + # `expired_ts`should be the time (in milliseconds since the unix epoch) that + # it was last used. + # + # It is possible to build an entry from an old signing.key file using the + # `export_signing_key` script which is provided with synapse. + # + # For example: + # + #"ed25519:id": { key: "base64string", expired_ts: 123456789123 } # How long key response published by this server is valid for. # Used to set the valid_until_ts in /key/v2 APIs. @@ -290,6 +295,8 @@ class KeyConfig(Config): raise ConfigError("Error reading %s: %s" % (name, str(e))) def read_old_signing_keys(self, old_signing_keys): + if old_signing_keys is None: + return {} keys = {} for key_id, key_data in old_signing_keys.items(): if is_signing_algorithm_supported(key_id): From 0b794cbd7b232b42a2d726e6ab6c698d4bf35093 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Dec 2019 14:52:52 +0000 Subject: [PATCH 63/82] Fix sdnotify with acme enabled (#6571) If acme was enabled, the sdnotify startup hook would never be run because we would try to add it to a hook which had already fired. There's no need to delay it: we can sdnotify as soon as we've started the listeners. --- changelog.d/6571.bugfix | 1 + synapse/app/_base.py | 10 +++++++--- 2 files changed, 8 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6571.bugfix diff --git a/changelog.d/6571.bugfix b/changelog.d/6571.bugfix new file mode 100644 index 0000000000..e38ea7b4f7 --- /dev/null +++ b/changelog.d/6571.bugfix @@ -0,0 +1 @@ +Fix a bug which meant that we did not send systemd notifications on startup if acme was enabled. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 9c96816096..0e8b467a3e 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -237,6 +237,12 @@ def start(hs, listeners=None): """ Start a Synapse server or worker. + Should be called once the reactor is running and (if we're using ACME) the + TLS certificates are in place. + + Will start the main HTTP listeners and do some other startup tasks, and then + notify systemd. + Args: hs (synapse.server.HomeServer) listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml) @@ -311,9 +317,7 @@ def setup_sdnotify(hs): # Tell systemd our state, if we're using it. This will silently fail if # we're not using systemd. - hs.get_reactor().addSystemEventTrigger( - "after", "startup", sdnotify, b"READY=1\nMAINPID=%i" % (os.getpid(),) - ) + sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),)) hs.get_reactor().addSystemEventTrigger( "before", "shutdown", sdnotify, b"STOPPING=1" From bca30cefee3849813565dd71e571172818629d85 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 19 Dec 2019 14:53:15 +0000 Subject: [PATCH 64/82] Improve diagnostics on database upgrade failure (#6570) `Failed to upgrade database` is not helpful, and it's unlikely that UPGRADE.rst has anything useful. --- changelog.d/6570.misc | 1 + synapse/app/homeserver.py | 9 ++------- synapse/storage/prepare_database.py | 5 ++++- 3 files changed, 7 insertions(+), 8 deletions(-) create mode 100644 changelog.d/6570.misc diff --git a/changelog.d/6570.misc b/changelog.d/6570.misc new file mode 100644 index 0000000000..e89955a51e --- /dev/null +++ b/changelog.d/6570.misc @@ -0,0 +1 @@ +Improve diagnostics on database upgrade failure. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b8661457e2..0e9bf7f53a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -342,13 +342,8 @@ def setup(config_options): hs.setup() except IncorrectDatabaseSetup as e: quit_with_error(str(e)) - except UpgradeDatabaseException: - sys.stderr.write( - "\nFailed to upgrade database.\n" - "Have you checked for version specific instructions in" - " UPGRADES.rst?\n" - ) - sys.exit(1) + except UpgradeDatabaseException as e: + quit_with_error("Failed to upgrade database: %s" % (e,)) hs.setup_master() diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b4194b44ee..0195edf4ac 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -69,7 +69,10 @@ def prepare_database(db_conn, database_engine, config, data_stores=["main"]): if user_version != SCHEMA_VERSION: # If we don't pass in a config file then we are expecting to # have already upgraded the DB. - raise UpgradeDatabaseException("Database needs to be upgraded") + raise UpgradeDatabaseException( + "Expected database schema version %i but got %i" + % (SCHEMA_VERSION, user_version) + ) else: _upgrade_existing_database( cur, From 3d46124ad01990d37fa54c1599c28314dc5f5d30 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Dec 2019 15:07:28 +0000 Subject: [PATCH 65/82] Port some admin handlers to async/await (#6559) --- changelog.d/6559.misc | 1 + synapse/app/admin_cmd.py | 6 ++- synapse/handlers/admin.py | 41 ++++++++----------- synapse/handlers/deactivate_account.py | 54 ++++++++++++-------------- 4 files changed, 46 insertions(+), 56 deletions(-) create mode 100644 changelog.d/6559.misc diff --git a/changelog.d/6559.misc b/changelog.d/6559.misc new file mode 100644 index 0000000000..8bca37457d --- /dev/null +++ b/changelog.d/6559.misc @@ -0,0 +1 @@ +Port `synapse.handlers.admin` and `synapse.handlers.deactivate_account` to async/await. diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 51a909419f..8e36bc57d3 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -104,8 +104,10 @@ def export_data_command(hs, args): user_id = args.user_id directory = args.output_directory - res = yield hs.get_handlers().admin_handler.export_user_data( - user_id, FileExfiltrationWriter(user_id, directory=directory) + res = yield defer.ensureDeferred( + hs.get_handlers().admin_handler.export_user_data( + user_id, FileExfiltrationWriter(user_id, directory=directory) + ) ) print(res) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index 14449b9a1e..1a4ba12385 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -15,8 +15,6 @@ import logging -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.types import RoomStreamToken from synapse.visibility import filter_events_for_client @@ -33,11 +31,10 @@ class AdminHandler(BaseHandler): self.storage = hs.get_storage() self.state_store = self.storage.state - @defer.inlineCallbacks - def get_whois(self, user): + async def get_whois(self, user): connections = [] - sessions = yield self.store.get_user_ip_and_agents(user) + sessions = await self.store.get_user_ip_and_agents(user) for session in sessions: connections.append( { @@ -54,20 +51,18 @@ class AdminHandler(BaseHandler): return ret - @defer.inlineCallbacks - def get_users(self): + async def get_users(self): """Function to retrieve a list of users in users table. Args: Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.get_users() + ret = await self.store.get_users() return ret - @defer.inlineCallbacks - def get_users_paginate(self, start, limit, name, guests, deactivated): + async def get_users_paginate(self, start, limit, name, guests, deactivated): """Function to retrieve a paginated list of users from users list. This will return a json list of users. @@ -80,14 +75,13 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to json list[dict[str, Any]] """ - ret = yield self.store.get_users_paginate( + ret = await self.store.get_users_paginate( start, limit, name, guests, deactivated ) return ret - @defer.inlineCallbacks - def search_users(self, term): + async def search_users(self, term): """Function to search users list for one or more users with the matched term. @@ -96,7 +90,7 @@ class AdminHandler(BaseHandler): Returns: defer.Deferred: resolves to list[dict[str, Any]] """ - ret = yield self.store.search_users(term) + ret = await self.store.search_users(term) return ret @@ -119,8 +113,7 @@ class AdminHandler(BaseHandler): """ return self.store.set_server_admin(user, admin) - @defer.inlineCallbacks - def export_user_data(self, user_id, writer): + async def export_user_data(self, user_id, writer): """Write all data we have on the user to the given writer. Args: @@ -132,7 +125,7 @@ class AdminHandler(BaseHandler): The returned value is that returned by `writer.finished()`. """ # Get all rooms the user is in or has been in - rooms = yield self.store.get_rooms_for_user_where_membership_is( + rooms = await self.store.get_rooms_for_user_where_membership_is( user_id, membership_list=( Membership.JOIN, @@ -145,7 +138,7 @@ class AdminHandler(BaseHandler): # We only try and fetch events for rooms the user has been in. If # they've been e.g. invited to a room without joining then we handle # those seperately. - rooms_user_has_been_in = yield self.store.get_rooms_user_has_been_in(user_id) + rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id) for index, room in enumerate(rooms): room_id = room.room_id @@ -154,7 +147,7 @@ class AdminHandler(BaseHandler): "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms) ) - forgotten = yield self.store.did_forget(user_id, room_id) + forgotten = await self.store.did_forget(user_id, room_id) if forgotten: logger.info("[%s] User forgot room %d, ignoring", user_id, room_id) continue @@ -166,7 +159,7 @@ class AdminHandler(BaseHandler): if room.membership == Membership.INVITE: event_id = room.event_id - invite = yield self.store.get_event(event_id, allow_none=True) + invite = await self.store.get_event(event_id, allow_none=True) if invite: invited_state = invite.unsigned["invite_room_state"] writer.write_invite(room_id, invite, invited_state) @@ -177,7 +170,7 @@ class AdminHandler(BaseHandler): # were joined. We estimate that point by looking at the # stream_ordering of the last membership if it wasn't a join. if room.membership == Membership.JOIN: - stream_ordering = yield self.store.get_room_max_stream_ordering() + stream_ordering = self.store.get_room_max_stream_ordering() else: stream_ordering = room.stream_ordering @@ -203,7 +196,7 @@ class AdminHandler(BaseHandler): # events that we have and then filtering, this isn't the most # efficient method perhaps but it does guarantee we get everything. while True: - events, _ = yield self.store.paginate_room_events( + events, _ = await self.store.paginate_room_events( room_id, from_key, to_key, limit=100, direction="f" ) if not events: @@ -211,7 +204,7 @@ class AdminHandler(BaseHandler): from_key = events[-1].internal_metadata.after - events = yield filter_events_for_client(self.storage, user_id, events) + events = await filter_events_for_client(self.storage, user_id, events) writer.write_events(room_id, events) @@ -247,7 +240,7 @@ class AdminHandler(BaseHandler): for event_id in extremities: if not event_to_unseen_prevs[event_id]: continue - state = yield self.state_store.get_state_for_event(event_id) + state = await self.state_store.get_state_for_event(event_id) writer.write_state(room_id, event_id, state) return writer.finished() diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 6dedaaff8d..4426967f88 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -15,8 +15,6 @@ # limitations under the License. import logging -from twisted.internet import defer - from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID, create_requester @@ -46,8 +44,7 @@ class DeactivateAccountHandler(BaseHandler): self._account_validity_enabled = hs.config.account_validity.enabled - @defer.inlineCallbacks - def deactivate_account(self, user_id, erase_data, id_server=None): + async def deactivate_account(self, user_id, erase_data, id_server=None): """Deactivate a user's account Args: @@ -74,11 +71,11 @@ class DeactivateAccountHandler(BaseHandler): identity_server_supports_unbinding = True # Retrieve the 3PIDs this user has bound to an identity server - threepids = yield self.store.user_get_bound_threepids(user_id) + threepids = await self.store.user_get_bound_threepids(user_id) for threepid in threepids: try: - result = yield self._identity_handler.try_unbind_threepid( + result = await self._identity_handler.try_unbind_threepid( user_id, { "medium": threepid["medium"], @@ -91,33 +88,33 @@ class DeactivateAccountHandler(BaseHandler): # Do we want this to be a fatal error or should we carry on? logger.exception("Failed to remove threepid from ID server") raise SynapseError(400, "Failed to remove threepid from ID server") - yield self.store.user_delete_threepid( + await self.store.user_delete_threepid( user_id, threepid["medium"], threepid["address"] ) # Remove all 3PIDs this user has bound to the homeserver - yield self.store.user_delete_threepids(user_id) + await self.store.user_delete_threepids(user_id) # delete any devices belonging to the user, which will also # delete corresponding access tokens. - yield self._device_handler.delete_all_devices_for_user(user_id) + await self._device_handler.delete_all_devices_for_user(user_id) # then delete any remaining access tokens which weren't associated with # a device. - yield self._auth_handler.delete_access_tokens_for_user(user_id) + await self._auth_handler.delete_access_tokens_for_user(user_id) - yield self.store.user_set_password_hash(user_id, None) + await self.store.user_set_password_hash(user_id, None) # Add the user to a table of users pending deactivation (ie. # removal from all the rooms they're a member of) - yield self.store.add_user_pending_deactivation(user_id) + await self.store.add_user_pending_deactivation(user_id) # delete from user directory - yield self.user_directory_handler.handle_user_deactivated(user_id) + await self.user_directory_handler.handle_user_deactivated(user_id) # Mark the user as erased, if they asked for that if erase_data: logger.info("Marking %s as erased", user_id) - yield self.store.mark_user_erased(user_id) + await self.store.mark_user_erased(user_id) # Now start the process that goes through that list and # parts users from rooms (if it isn't already running) @@ -125,30 +122,29 @@ class DeactivateAccountHandler(BaseHandler): # Reject all pending invites for the user, so that the user doesn't show up in the # "invited" section of rooms' members list. - yield self._reject_pending_invites_for_user(user_id) + await self._reject_pending_invites_for_user(user_id) # Remove all information on the user from the account_validity table. if self._account_validity_enabled: - yield self.store.delete_account_validity_for_user(user_id) + await self.store.delete_account_validity_for_user(user_id) # Mark the user as deactivated. - yield self.store.set_user_deactivated_status(user_id, True) + await self.store.set_user_deactivated_status(user_id, True) return identity_server_supports_unbinding - @defer.inlineCallbacks - def _reject_pending_invites_for_user(self, user_id): + async def _reject_pending_invites_for_user(self, user_id): """Reject pending invites addressed to a given user ID. Args: user_id (str): The user ID to reject pending invites for. """ user = UserID.from_string(user_id) - pending_invites = yield self.store.get_invited_rooms_for_user(user_id) + pending_invites = await self.store.get_invited_rooms_for_user(user_id) for room in pending_invites: try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room.room_id, @@ -180,8 +176,7 @@ class DeactivateAccountHandler(BaseHandler): if not self._user_parter_running: run_as_background_process("user_parter_loop", self._user_parter_loop) - @defer.inlineCallbacks - def _user_parter_loop(self): + async def _user_parter_loop(self): """Loop that parts deactivated users from rooms Returns: @@ -191,19 +186,18 @@ class DeactivateAccountHandler(BaseHandler): logger.info("Starting user parter") try: while True: - user_id = yield self.store.get_user_pending_deactivation() + user_id = await self.store.get_user_pending_deactivation() if user_id is None: break logger.info("User parter parting %r", user_id) - yield self._part_user(user_id) - yield self.store.del_user_pending_deactivation(user_id) + await self._part_user(user_id) + await self.store.del_user_pending_deactivation(user_id) logger.info("User parter finished parting %r", user_id) logger.info("User parter finished: stopping") finally: self._user_parter_running = False - @defer.inlineCallbacks - def _part_user(self, user_id): + async def _part_user(self, user_id): """Causes the given user_id to leave all the rooms they're joined to Returns: @@ -211,11 +205,11 @@ class DeactivateAccountHandler(BaseHandler): """ user = UserID.from_string(user_id) - rooms_for_user = yield self.store.get_rooms_for_user(user_id) + rooms_for_user = await self.store.get_rooms_for_user(user_id) for room_id in rooms_for_user: logger.info("User parter parting %r from %r", user_id, room_id) try: - yield self._room_member_handler.update_membership( + await self._room_member_handler.update_membership( create_requester(user), user, room_id, From 0b5dbadd9607714c471cbf317a64a96d935898a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 19 Dec 2019 15:07:37 +0000 Subject: [PATCH 66/82] Explode on duplicate delta file names. (#6565) --- changelog.d/6565.misc | 1 + synapse/storage/prepare_database.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 changelog.d/6565.misc diff --git a/changelog.d/6565.misc b/changelog.d/6565.misc new file mode 100644 index 0000000000..e83f245bf0 --- /dev/null +++ b/changelog.d/6565.misc @@ -0,0 +1 @@ +Add assertion that schema delta file names are unique. diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 0195edf4ac..403848ad03 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -18,6 +18,7 @@ import imp import logging import os import re +from collections import Counter import attr @@ -315,6 +316,9 @@ def _upgrade_existing_database( ) ) + # Used to check if we have any duplicate file names + file_name_counter = Counter() + # Now find which directories have anything of interest. directory_entries = [] for directory in directories: @@ -325,6 +329,9 @@ def _upgrade_existing_database( _DirectoryListing(file_name, os.path.join(directory, file_name)) for file_name in file_names ) + + for file_name in file_names: + file_name_counter[file_name] += 1 except FileNotFoundError: # Data stores can have empty entries for a given version delta. pass @@ -333,6 +340,17 @@ def _upgrade_existing_database( "Could not open delta dir for version %d: %s" % (v, directory) ) + duplicates = set( + file_name for file_name, count in file_name_counter.items() if count > 1 + ) + if duplicates: + # We don't support using the same file name in the same delta version. + raise PrepareDatabaseException( + "Found multiple delta files with the same name in v%d: %s", + v, + duplicates, + ) + # We sort to ensure that we apply the delta files in a consistent # order (to avoid bugs caused by inconsistent directory listing order) directory_entries.sort() From fa780e9721c940479a72eed9877ccad4fef78160 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Dec 2019 10:32:02 +0000 Subject: [PATCH 67/82] Change EventContext to use the Storage class (#6564) --- changelog.d/6564.misc | 1 + synapse/api/auth.py | 2 +- synapse/events/snapshot.py | 36 +++++++++++-------- synapse/events/third_party_rules.py | 2 +- synapse/handlers/_base.py | 2 +- synapse/handlers/federation.py | 14 ++++---- synapse/handlers/message.py | 10 +++--- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 4 +-- synapse/push/bulk_push_rule_evaluator.py | 4 +-- synapse/replication/http/federation.py | 5 ++- synapse/replication/http/send_event.py | 3 +- synapse/storage/data_stores/main/push_rule.py | 2 +- .../storage/data_stores/main/roommember.py | 2 +- tests/test_state.py | 28 +++++++-------- 15 files changed, 64 insertions(+), 53 deletions(-) create mode 100644 changelog.d/6564.misc diff --git a/changelog.d/6564.misc b/changelog.d/6564.misc new file mode 100644 index 0000000000..f644f5868b --- /dev/null +++ b/changelog.d/6564.misc @@ -0,0 +1 @@ +Change `EventContext` to use the `Storage` class, in preparation for moving state database queries to a separate data store. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9fd52a8c77..abbc7079a3 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -79,7 +79,7 @@ class Auth(object): @defer.inlineCallbacks def check_from_context(self, room_version, event, context, do_sig_check=True): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.compute_auth_events( event, prev_state_ids, for_verification=True ) diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 64e898f40c..a44baea365 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -149,7 +149,7 @@ class EventContext: # the prev_state_ids, so if we're a state event we include the event # id that we replaced in the state. if event.is_state(): - prev_state_ids = yield self.get_prev_state_ids(store) + prev_state_ids = yield self.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) else: prev_state_id = None @@ -167,12 +167,13 @@ class EventContext: } @staticmethod - def deserialize(store, input): + def deserialize(storage, input): """Converts a dict that was produced by `serialize` back into a EventContext. Args: - store (DataStore): Used to convert AS ID to AS object + storage (Storage): Used to convert AS ID to AS object and fetch + state. input (dict): A dict produced by `serialize` Returns: @@ -181,6 +182,7 @@ class EventContext: context = _AsyncEventContextImpl( # We use the state_group and prev_state_id stuff to pull the # current_state_ids out of the DB and construct prev_state_ids. + storage=storage, prev_state_id=input["prev_state_id"], event_type=input["event_type"], event_state_key=input["event_state_key"], @@ -193,7 +195,7 @@ class EventContext: app_service_id = input["app_service_id"] if app_service_id: - context.app_service = store.get_app_service_by_id(app_service_id) + context.app_service = storage.main.get_app_service_by_id(app_service_id) return context @@ -216,7 +218,7 @@ class EventContext: return self._state_group @defer.inlineCallbacks - def get_current_state_ids(self, store): + def get_current_state_ids(self): """ Gets the room state map, including this event - ie, the state in ``state_group`` @@ -234,11 +236,11 @@ class EventContext: if self.rejected: raise RuntimeError("Attempt to access state_ids of rejected event") - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._current_state_ids @defer.inlineCallbacks - def get_prev_state_ids(self, store): + def get_prev_state_ids(self): """ Gets the room state map, excluding this event. @@ -250,7 +252,7 @@ class EventContext: Maps a (type, state_key) to the event ID of the state event matching this tuple. """ - yield self._ensure_fetched(store) + yield self._ensure_fetched() return self._prev_state_ids def get_cached_current_state_ids(self): @@ -270,7 +272,7 @@ class EventContext: return self._current_state_ids - def _ensure_fetched(self, store): + def _ensure_fetched(self): return defer.succeed(None) @@ -282,6 +284,8 @@ class _AsyncEventContextImpl(EventContext): Attributes: + _storage (Storage) + _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have been calculated. None if we haven't started calculating yet @@ -295,28 +299,30 @@ class _AsyncEventContextImpl(EventContext): that was replaced. """ + # This needs to have a default as we're inheriting + _storage = attr.ib(default=None) _prev_state_id = attr.ib(default=None) _event_type = attr.ib(default=None) _event_state_key = attr.ib(default=None) _fetching_state_deferred = attr.ib(default=None) - def _ensure_fetched(self, store): + def _ensure_fetched(self): if not self._fetching_state_deferred: - self._fetching_state_deferred = run_in_background( - self._fill_out_state, store - ) + self._fetching_state_deferred = run_in_background(self._fill_out_state) return make_deferred_yieldable(self._fetching_state_deferred) @defer.inlineCallbacks - def _fill_out_state(self, store): + def _fill_out_state(self): """Called to populate the _current_state_ids and _prev_state_ids attributes by loading from the database. """ if self.state_group is None: return - self._current_state_ids = yield store.get_state_ids_for_group(self.state_group) + self._current_state_ids = yield self._storage.state.get_state_ids_for_group( + self.state_group + ) if self._prev_state_id and self._event_state_key is not None: self._prev_state_ids = dict(self._current_state_ids) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 714a9b1579..86f7e5f8aa 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -53,7 +53,7 @@ class ThirdPartyEventRules(object): if self.third_party_rules is None: return True - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() # Retrieve the state events from the database. state_events = {} diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index d15c6282fb..51413d910e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -134,7 +134,7 @@ class BaseHandler(object): guest_access = event.content.get("guest_access", "forbidden") if guest_access != "can_join": if context: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state = yield self.store.get_events( list(current_state_ids.values()) ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 60bb00fc6a..05ae40dde7 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -718,7 +718,7 @@ class FederationHandler(BaseHandler): # changing their profile info. newly_joined = True - prev_state_ids = await context.get_prev_state_ids(self.store) + prev_state_ids = await context.get_prev_state_ids() prev_state_id = prev_state_ids.get((event.type, event.state_key)) if prev_state_id: @@ -1418,7 +1418,7 @@ class FederationHandler(BaseHandler): user = UserID.from_string(event.state_key) yield self.user_joined_room(user, event.room_id) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) auth_chain = yield self.store.get_auth_chain(state_ids) @@ -1927,7 +1927,7 @@ class FederationHandler(BaseHandler): context = yield self.state_handler.compute_event_context(event, old_state=state) if not auth_events: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -2336,12 +2336,12 @@ class FederationHandler(BaseHandler): k: a.event_id for k, a in iteritems(auth_events) if k != event_key } - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() current_state_ids = dict(current_state_ids) current_state_ids.update(state_updates) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_state_ids = dict(prev_state_ids) prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) @@ -2625,7 +2625,7 @@ class FederationHandler(BaseHandler): event.content["third_party_invite"]["signed"]["token"], ) original_invite = None - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() original_invite_id = prev_state_ids.get(key) if original_invite_id: original_invite = yield self.store.get_event( @@ -2673,7 +2673,7 @@ class FederationHandler(BaseHandler): signed = event.content["third_party_invite"]["signed"] token = signed["token"] - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token)) invite_event = None diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index bf9add7fe2..4ad752205f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -515,7 +515,7 @@ class EventCreationHandler(object): # federation as well as those created locally. As of room v3, aliases events # can be created by users that are not in the room, therefore we have to # tolerate them in event_auth.check(). - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender)) prev_event = ( yield self.store.get_event(prev_event_id, allow_none=True) @@ -665,7 +665,7 @@ class EventCreationHandler(object): If so, returns the version of the event in context. Otherwise, returns None. """ - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_event_id = prev_state_ids.get((event.type, event.state_key)) if not prev_event_id: return @@ -914,7 +914,7 @@ class EventCreationHandler(object): def is_inviter_member_event(e): return e.type == EventTypes.Member and e.sender == event.sender - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() state_to_include_ids = [ e_id @@ -967,7 +967,7 @@ class EventCreationHandler(object): if original_event.room_id != event.room_id: raise SynapseError(400, "Cannot redact event from a different room") - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() auth_events_ids = yield self.auth.compute_auth_events( event, prev_state_ids, for_verification=True ) @@ -989,7 +989,7 @@ class EventCreationHandler(object): event.internal_metadata.recheck_redaction = False if event.type == EventTypes.Create: - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d3a1a7b4a6..89c9118b26 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -184,7 +184,7 @@ class RoomCreationHandler(BaseHandler): requester, tombstone_event, tombstone_context ) - old_room_state = yield tombstone_context.get_current_state_ids(self.store) + old_room_state = yield tombstone_context.get_current_state_ids() # update any aliases yield self._move_aliases_to_new_room( diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7b7270fc61..44c5e3239c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,7 +193,7 @@ class RoomMemberHandler(object): requester, event, context, extra_users=[target], ratelimit=ratelimit ) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) @@ -601,7 +601,7 @@ class RoomMemberHandler(object): if prev_event is not None: return - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() if event.membership == Membership.JOIN: if requester.is_guest: guest_can_join = yield self._can_guest_join(prev_state_ids) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 7881780760..7d9f5a38d9 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -116,7 +116,7 @@ class BulkPushRuleEvaluator(object): @defer.inlineCallbacks def _get_power_levels_and_sender_level(self, event, context): - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() pl_event_id = prev_state_ids.get(POWER_KEY) if pl_event_id: # fastpath: if there's a power level event, that's all we need, and @@ -304,7 +304,7 @@ class RulesForRoom(object): push_rules_delta_state_cache_metric.inc_hits() else: - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() push_rules_delta_state_cache_metric.inc_misses() push_rules_state_size_counter.inc(len(current_state_ids)) diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 9af4e7e173..49a3251372 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -51,6 +51,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): super(ReplicationFederationSendEventsRestServlet, self).__init__(hs) self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() self.federation_handler = hs.get_handlers().federation_handler @@ -100,7 +101,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): EventType = event_type_from_format_version(format_ver) event = EventType(event_dict, internal_metadata, rejected_reason) - context = EventContext.deserialize(self.store, event_payload["context"]) + context = EventContext.deserialize( + self.storage, event_payload["context"] + ) event_and_contexts.append((event, context)) diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index 9bafd60b14..84b92f16ad 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -54,6 +54,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() + self.storage = hs.get_storage() self.clock = hs.get_clock() @staticmethod @@ -100,7 +101,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): event = EventType(event_dict, internal_metadata, rejected_reason) requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize(self.store, content["context"]) + context = EventContext.deserialize(self.storage, content["context"]) ratelimit = content["ratelimit"] extra_users = [UserID.from_string(u) for u in content["extra_users"]] diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py index 5ba13aa973..e2673ae073 100644 --- a/synapse/storage/data_stores/main/push_rule.py +++ b/synapse/storage/data_stores/main/push_rule.py @@ -244,7 +244,7 @@ class PushRulesWorkerStore( # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._bulk_get_push_rules_for_room( event.room_id, state_group, current_state_ids, event=event ) diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py index 92e3b9c512..70ff5751b6 100644 --- a/synapse/storage/data_stores/main/roommember.py +++ b/synapse/storage/data_stores/main/roommember.py @@ -477,7 +477,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - current_state_ids = yield context.get_current_state_ids(self) + current_state_ids = yield context.get_current_state_ids() result = yield self._get_joined_users_from_context( event.room_id, state_group, current_state_ids, event=event, context=context ) diff --git a/tests/test_state.py b/tests/test_state.py index 176535947a..e0aae06be4 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -209,7 +209,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertEqual(2, len(prev_state_ids)) self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event) @@ -253,7 +253,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()} ) @@ -312,7 +312,7 @@ class StateTestCase(unittest.TestCase): ctx_c = context_store["C"] ctx_e = context_store["E"] - prev_state_ids = yield ctx_e.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_e.get_prev_state_ids() self.assertSetEqual( {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()} ) @@ -387,7 +387,7 @@ class StateTestCase(unittest.TestCase): ctx_b = context_store["B"] ctx_d = context_store["D"] - prev_state_ids = yield ctx_d.get_prev_state_ids(self.store) + prev_state_ids = yield ctx_d.get_prev_state_ids() self.assertSetEqual( {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()} ) @@ -419,10 +419,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state), current_state_ids.values() ) @@ -442,10 +442,10 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event, old_state=old_state) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values()) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertCountEqual( (e.event_id for e in old_state + [event]), current_state_ids.values() ) @@ -479,7 +479,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(current_state_ids.values()) @@ -511,7 +511,7 @@ class StateTestCase(unittest.TestCase): context = yield self.state.compute_event_context(event) - prev_state_ids = yield context.get_prev_state_ids(self.store) + prev_state_ids = yield context.get_prev_state_ids() self.assertEqual( set([e.event_id for e in old_state]), set(prev_state_ids.values()) @@ -552,7 +552,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -594,7 +594,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(len(current_state_ids), 6) @@ -649,7 +649,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")]) @@ -677,7 +677,7 @@ class StateTestCase(unittest.TestCase): event, prev_event_id1, old_state_1, prev_event_id2, old_state_2 ) - current_state_ids = yield context.get_current_state_ids(self.store) + current_state_ids = yield context.get_current_state_ids() self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")]) From 75d8f26ac85efd3816d454927f40b6e4c3032df1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 20 Dec 2019 10:48:24 +0000 Subject: [PATCH 68/82] Split state groups into a separate data store (#6296) --- changelog.d/6245.misc | 1 + scripts/synapse_port_db | 8 +- synapse/config/database.py | 10 +- synapse/storage/data_stores/__init__.py | 5 + synapse/storage/data_stores/main/events.py | 157 --- .../main/schema/delta/32/remove_indices.sql | 1 - .../schema/full_schemas/54/full.sql.postgres | 52 - .../schema/full_schemas/54/full.sql.sqlite | 6 - synapse/storage/data_stores/main/state.py | 937 +----------------- synapse/storage/data_stores/state/__init__.py | 16 + .../storage/data_stores/state/bg_updates.py | 374 +++++++ .../schema/delta/23/drop_state_index.sql | 0 .../schema/delta/30/state_stream.sql | 0 .../schema/delta/32/remove_state_indices.sql | 19 + .../schema/delta/35/add_state_index.sql | 0 .../{main => state}/schema/delta/35/state.sql | 0 .../schema/delta/35/state_dedupe.sql | 0 .../schema/delta/47/state_group_seq.py | 0 .../schema/delta/56/state_group_room_idx.sql | 17 + .../state/schema/full_schemas/54/full.sql | 37 + .../full_schemas/54/sequence.sql.postgres | 21 + synapse/storage/data_stores/state/store.py | 640 ++++++++++++ synapse/storage/persist_events.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/purge_events.py | 4 +- synapse/storage/state.py | 14 +- tests/storage/test_state.py | 2 +- tests/utils.py | 2 +- 28 files changed, 1159 insertions(+), 1168 deletions(-) create mode 100644 changelog.d/6245.misc create mode 100644 synapse/storage/data_stores/state/__init__.py create mode 100644 synapse/storage/data_stores/state/bg_updates.py rename synapse/storage/data_stores/{main => state}/schema/delta/23/drop_state_index.sql (100%) rename synapse/storage/data_stores/{main => state}/schema/delta/30/state_stream.sql (100%) create mode 100644 synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql rename synapse/storage/data_stores/{main => state}/schema/delta/35/add_state_index.sql (100%) rename synapse/storage/data_stores/{main => state}/schema/delta/35/state.sql (100%) rename synapse/storage/data_stores/{main => state}/schema/delta/35/state_dedupe.sql (100%) rename synapse/storage/data_stores/{main => state}/schema/delta/47/state_group_seq.py (100%) create mode 100644 synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql create mode 100644 synapse/storage/data_stores/state/schema/full_schemas/54/full.sql create mode 100644 synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres create mode 100644 synapse/storage/data_stores/state/store.py diff --git a/changelog.d/6245.misc b/changelog.d/6245.misc new file mode 100644 index 0000000000..a3e6b8296e --- /dev/null +++ b/changelog.d/6245.misc @@ -0,0 +1 @@ +Split out state storage into separate data store. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 5b5368988c..eb927f2094 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -51,11 +51,12 @@ from synapse.storage.data_stores.main.registration import ( from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore -from synapse.storage.data_stores.main.state import StateBackgroundUpdateStore +from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore from synapse.storage.data_stores.main.stats import StatsStore from synapse.storage.data_stores.main.user_directory import ( UserDirectoryBackgroundUpdateStore, ) +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -138,6 +139,7 @@ class Store( RoomMemberBackgroundUpdateStore, SearchBackgroundUpdateStore, StateBackgroundUpdateStore, + MainStateBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore, StatsStore, ): @@ -496,9 +498,7 @@ class Porter(object): def run(self): try: self.sqlite_store = yield self.build_db_store( - DatabaseConnectionConfig( - "master", self.sqlite_config, data_stores=["main"] - ) + DatabaseConnectionConfig("master-sqlite", self.sqlite_config) ) # Check if all background updates are done, abort if not. diff --git a/synapse/config/database.py b/synapse/config/database.py index 5f2f3c7cfd..134824789c 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -34,10 +34,12 @@ class DatabaseConnectionConfig: module name, and `args` for the args to give to the database connector. data_stores: The list of data stores that should be provisioned on the - database. + database. Defaults to all data stores. """ - def __init__(self, name: str, db_config: dict, data_stores: List[str]): + def __init__( + self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"] + ): if db_config["name"] not in ("sqlite3", "psycopg2"): raise ConfigError("Unsupported database type %r" % (db_config["name"],)) @@ -62,9 +64,7 @@ class DatabaseConfig(Config): if database_config is None: database_config = {"name": "sqlite3", "args": {}} - self.databases = [ - DatabaseConnectionConfig("master", database_config, data_stores=["main"]) - ] + self.databases = [DatabaseConnectionConfig("master", database_config)] self.set_databasepath(config.get("database_path")) diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index 0983e059c0..d20df5f076 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -15,6 +15,7 @@ import logging +from synapse.storage.data_stores.state import StateGroupDataStore from synapse.storage.database import Database, make_conn from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -55,6 +56,10 @@ class DataStores(object): logger.info("Starting 'main' data store") self.main = main_store_class(database, db_conn, hs) + if "state" in database_config.data_stores: + logger.info("Starting 'state' data store") + self.state = StateGroupDataStore(database, db_conn, hs) + db_conn.commit() self.databases.append(database) diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py index 998bba1aad..58f35d7f56 100644 --- a/synapse/storage/data_stores/main/events.py +++ b/synapse/storage/data_stores/main/events.py @@ -1757,163 +1757,6 @@ class EventsStore( return state_groups - def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete - ) -> defer.Deferred: - """Deletes no longer referenced state groups and de-deltas any state - groups that reference them. - - Args: - room_id: The room the state groups belong to (must all be in the - same room). - state_groups_to_delete (Collection[int]): Set of all state groups - to delete. - """ - - return self.db.runInteraction( - "purge_unreferenced_state_groups", - self._purge_unreferenced_state_groups, - room_id, - state_groups_to_delete, - ) - - def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): - logger.info( - "[purge] found %i state groups to delete", len(state_groups_to_delete) - ) - - rows = self.db.simple_select_many_txn( - txn, - table="state_group_edges", - column="prev_state_group", - iterable=state_groups_to_delete, - keyvalues={}, - retcols=("state_group",), - ) - - remaining_state_groups = set( - row["state_group"] - for row in rows - if row["state_group"] not in state_groups_to_delete - ) - - logger.info( - "[purge] de-delta-ing %i remaining state groups", - len(remaining_state_groups), - ) - - # Now we turn the state groups that reference to-be-deleted state - # groups to non delta versions. - for sg in remaining_state_groups: - logger.info("[purge] de-delta-ing remaining state group %s", sg) - curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) - curr_state = curr_state[sg] - - self.db.simple_delete_txn( - txn, table="state_groups_state", keyvalues={"state_group": sg} - ) - - self.db.simple_delete_txn( - txn, table="state_group_edges", keyvalues={"state_group": sg} - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": sg, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(curr_state) - ], - ) - - logger.info("[purge] removing redundant state groups") - txn.executemany( - "DELETE FROM state_groups_state WHERE state_group = ?", - ((sg,) for sg in state_groups_to_delete), - ) - txn.executemany( - "DELETE FROM state_groups WHERE id = ?", - ((sg,) for sg in state_groups_to_delete), - ) - - @defer.inlineCallbacks - def get_previous_state_groups(self, state_groups): - """Fetch the previous groups of the given state groups. - - Args: - state_groups (Iterable[int]) - - Returns: - Deferred[dict[int, int]]: mapping from state group to previous - state group. - """ - - rows = yield self.db.simple_select_many_batch( - table="state_group_edges", - column="prev_state_group", - iterable=state_groups, - keyvalues={}, - retcols=("prev_state_group", "state_group"), - desc="get_previous_state_groups", - ) - - return {row["state_group"]: row["prev_state_group"] for row in rows} - - def purge_room_state(self, room_id, state_groups_to_delete): - """Deletes all record of a room from state tables - - Args: - room_id (str): - state_groups_to_delete (list[int]): State groups to delete - """ - - return self.db.runInteraction( - "purge_room_state", - self._purge_room_state_txn, - room_id, - state_groups_to_delete, - ) - - def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): - # first we have to delete the state groups states - logger.info("[purge] removing %s from state_groups_state", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups_state", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state group edges - logger.info("[purge] removing %s from state_group_edges", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_group_edges", - column="state_group", - iterable=state_groups_to_delete, - keyvalues={}, - ) - - # ... and the state groups - logger.info("[purge] removing %s from state_groups", room_id) - - self.db.simple_delete_many_txn( - txn, - table="state_groups", - column="id", - iterable=state_groups_to_delete, - keyvalues={}, - ) - async def is_event_after(self, event_id1, event_id2): """Returns True if event_id1 is after event_id2 in the stream """ diff --git a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql index 4219cdd06a..2de50d408c 100644 --- a/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql +++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql @@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY -DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres index 4ad2929f32..889a9a0ce4 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres @@ -975,40 +975,6 @@ CREATE TABLE state_events ( -CREATE TABLE state_group_edges ( - state_group bigint NOT NULL, - prev_state_group bigint NOT NULL -); - - - -CREATE SEQUENCE state_group_id_seq - START WITH 1 - INCREMENT BY 1 - NO MINVALUE - NO MAXVALUE - CACHE 1; - - - -CREATE TABLE state_groups ( - id bigint NOT NULL, - room_id text NOT NULL, - event_id text NOT NULL -); - - - -CREATE TABLE state_groups_state ( - state_group bigint NOT NULL, - room_id text NOT NULL, - type text NOT NULL, - state_key text NOT NULL, - event_id text NOT NULL -); - - - CREATE TABLE stats_stream_pos ( lock character(1) DEFAULT 'X'::bpchar NOT NULL, stream_id bigint, @@ -1482,12 +1448,6 @@ ALTER TABLE ONLY state_events ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id); - -ALTER TABLE ONLY state_groups - ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id); - - - ALTER TABLE ONLY stats_stream_pos ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock); @@ -1928,18 +1888,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts); -CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group); - - - -CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group); - - - -CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key); - - - CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering); diff --git a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite index bad33291e7..a0411ede7e 100644 --- a/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite +++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite @@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id); CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id); CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) ); CREATE INDEX room_depth_room ON room_depth(room_id); -CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); -CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) ); CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) ); CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) ); @@ -120,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL ); CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT); CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id ); CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id ); -CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL ); -CREATE INDEX state_group_edges_idx ON state_group_edges(state_group); -CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group); CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL ); CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering ); CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering ); @@ -254,6 +249,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen); CREATE INDEX users_creation_ts ON users (creation_ts); CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group); CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id); -CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key); CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id); CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip); diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index dcc6b43cdf..0dc39f139c 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -17,8 +17,7 @@ import logging from collections import namedtuple from typing import Iterable, Tuple -from six import iteritems, itervalues -from six.moves import range +from six import iteritems from twisted.internet import defer @@ -29,11 +28,9 @@ from synapse.events.snapshot import EventContext from synapse.storage._base import SQLBaseStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.database import Database -from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter -from synapse.util.caches import get_cache_factor_for, intern_string +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -55,207 +52,14 @@ class _GetStateGroupDelta( return len(self.delta_ids) if self.delta_ids else 0 -class StateGroupBackgroundUpdateStore(SQLBaseStore): - """Defines functions related to state groups needed to run the state backgroud - updates. - """ - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """ - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter=StateFilter.all() - ): - results = {group: {} for group in groups} - - where_clause, where_args = state_filter.make_sql_filter_clause() - - # Unless the filter clause is empty, we're going to append it after an - # existing where clause - if where_clause: - where_clause = " AND (%s)" % (where_clause,) - - if isinstance(self.database_engine, PostgresEngine): - # Temporarily disable sequential scans in this transaction. This is - # a temporary hack until we can add the right indices in - txn.execute("SET LOCAL enable_seqscan=off") - - # The below query walks the state_group tree so that the "state" - # table includes all state_groups in the tree. It then joins - # against `state_groups_state` to fetch the latest state. - # It assumes that previous state groups are always numerically - # lesser. - # The PARTITION is used to get the event_id in the greatest state - # group for the given type, state_key. - # This may return multiple rows per (type, state_key), but last_value - # should be the same. - sql = """ - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT DISTINCT type, state_key, last_value(event_id) OVER ( - PARTITION BY type, state_key ORDER BY state_group ASC - ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING - ) AS event_id FROM state_groups_state - WHERE state_group IN ( - SELECT state_group FROM state - ) - """ - - for group in groups: - args = [group] - args.extend(where_args) - - txn.execute(sql + where_clause, args) - for row in txn: - typ, state_key, event_id = row - key = (typ, state_key) - results[group][key] = event_id - else: - max_entries_returned = state_filter.max_entries_returned() - - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - for group in groups: - next_group = group - - while next_group: - # We did this before by getting the list of group ids, and - # then passing that list to sqlite to get latest event for - # each (type, state_key). However, that was terribly slow - # without the right indices (which we can't add until - # after we finish deduping state, which requires this func) - args = [next_group] - args.extend(where_args) - - txn.execute( - "SELECT type, state_key, event_id FROM state_groups_state" - " WHERE state_group = ? " + where_clause, - args, - ) - results[group].update( - ((typ, state_key), event_id) - for typ, state_key, event_id in txn - if (typ, state_key) not in results[group] - ) - - # If the number of entries in the (type,state_key)->event_id dict - # matches the number of (type,state_keys) types we were searching - # for, then we must have found them all, so no need to go walk - # further down the tree... UNLESS our types filter contained - # wildcards (i.e. Nones) in which case we have to do an exhaustive - # search - if ( - max_entries_returned is not None - and len(results[group]) == max_entries_returned - ): - break - - next_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - - return results - - # this inherits from EventsWorkerStore because it calls self.get_events -class StateGroupWorkerStore( - EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore -): +class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): """The parts of StateGroupStore that can be called from workers. """ - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" - CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, database: Database, db_conn, hs): super(StateGroupWorkerStore, self).__init__(database, db_conn, hs) - # Originally the state store used a single DictionaryCache to cache the - # event IDs for the state types in a given state group to avoid hammering - # on the state_group* tables. - # - # The point of using a DictionaryCache is that it can cache a subset - # of the state events for a given state group (i.e. a subset of the keys for a - # given dict which is an entry in the cache for a given state group ID). - # - # However, this poses problems when performing complicated queries - # on the store - for instance: "give me all the state for this group, but - # limit members to this subset of users", as DictionaryCache's API isn't - # rich enough to say "please cache any of these fields, apart from this subset". - # This is problematic when lazy loading members, which requires this behaviour, - # as without it the cache has no choice but to speculatively load all - # state events for the group, which negates the efficiency being sought. - # - # Rather than overcomplicating DictionaryCache's API, we instead split the - # state_group_cache into two halves - one for tracking non-member events, - # and the other for tracking member_events. This means that lazy loading - # queries can be made in a cache-friendly manner by querying both caches - # separately and then merging the result. So for the example above, you - # would query the members cache for a specific subset of state keys - # (which DictionaryCache will handle efficiently and fine) and the non-members - # cache for all state (which DictionaryCache will similarly handle fine) - # and then just merge the results together. - # - # We size the non-members cache to be smaller than the members cache as the - # vast majority of state in Matrix (today) is member events. - - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", - # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), - ) - self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), - ) - @defer.inlineCallbacks def get_room_version(self, room_id): """Get the room_version of a given room @@ -431,229 +235,6 @@ class StateGroupWorkerStore( return event.content.get("canonical_alias") - @cached(max_entries=10000, iterable=True) - def get_state_group_delta(self, state_group): - """Given a state group try to return a previous group and a delta between - the old and the new. - - Returns: - (prev_group, delta_ids), where both may be None. - """ - - def _get_state_group_delta_txn(txn): - prev_group = self.db.simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - retcol="prev_state_group", - allow_none=True, - ) - - if not prev_group: - return _GetStateGroupDelta(None, None) - - delta_ids = self.db.simple_select_list_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - retcols=("type", "state_key", "event_id"), - ) - - return _GetStateGroupDelta( - prev_group, - {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, - ) - - return self.db.runInteraction( - "get_state_group_delta", _get_state_group_delta_txn - ) - - @defer.inlineCallbacks - def get_state_groups_ids(self, _room_id, event_ids): - """Get the event IDs of all the state for the state groups for the given events - - Args: - _room_id (str): id of the room for these events - event_ids (iterable[str]): ids of the events - - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - if not event_ids: - return {} - - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups) - - return group_to_state - - @defer.inlineCallbacks - def get_state_ids_for_group(self, state_group): - """Get the event IDs of all the state in the given state group - - Args: - state_group (int) - - Returns: - Deferred[dict]: Resolves to a map of (type, state_key) -> event_id - """ - group_to_state = yield self._get_state_for_groups((state_group,)) - - return group_to_state[state_group] - - @defer.inlineCallbacks - def get_state_groups(self, room_id, event_ids): - """ Get the state groups for the given list of event_ids - - Returns: - Deferred[dict[int, list[EventBase]]]: - dict of state_group_id -> list of state events. - """ - if not event_ids: - return {} - - group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) - - state_event_map = yield self.get_events( - [ - ev_id - for group_ids in itervalues(group_to_ids) - for ev_id in itervalues(group_ids) - ], - get_prev_content=False, - ) - - return { - group: [ - state_event_map[v] - for v in itervalues(event_id_map) - if v in state_event_map - ] - for group, event_id_map in iteritems(group_to_ids) - } - - @defer.inlineCallbacks - def _get_state_groups_from_groups(self, groups, state_filter): - """Returns the state groups for a given set of groups, filtering on - types of state events. - - Args: - groups(list[int]): list of state group IDs to query - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - results = {} - - chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] - for chunk in chunks: - res = yield self.db.runInteraction( - "_get_state_groups_from_groups", - self._get_state_groups_from_groups_txn, - chunk, - state_filter, - ) - results.update(res) - - return results - - @defer.inlineCallbacks - def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): - """Given a list of event_ids and type tuples, return a list of state - dicts for each event. - - Args: - event_ids (list[string]) - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - deferred: A dict of (event_id) -> (type, state_key) -> [state_events] - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - state_event_map = yield self.get_events( - [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], - get_prev_content=False, - ) - - event_to_state = { - event_id: { - k: state_event_map[v] - for k, v in iteritems(group_to_state[group]) - if v in state_event_map - } - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): - """ - Get the state dicts corresponding to a list of events, containing the event_ids - of the state events (as opposed to the events themselves) - - Args: - event_ids(list(str)): events whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from event_id -> (type, state_key) -> event_id - """ - event_to_groups = yield self._get_state_group_for_events(event_ids) - - groups = set(itervalues(event_to_groups)) - group_to_state = yield self._get_state_for_groups(groups, state_filter) - - event_to_state = { - event_id: group_to_state[group] - for event_id, group in iteritems(event_to_groups) - } - - return {event: event_to_state[event] for event in event_ids} - - @defer.inlineCallbacks - def get_state_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_for_events([event_id], state_filter) - return state_map[event_id] - - @defer.inlineCallbacks - def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): - """ - Get the state dict corresponding to a particular event - - Args: - event_id(str): event whose state should be returned - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - A deferred dict from (type, state_key) -> state_event - """ - state_map = yield self.get_state_ids_for_events([event_id], state_filter) - return state_map[event_id] - @cached(max_entries=50000) def _get_state_group_for_event(self, event_id): return self.db.simple_select_one_onecol( @@ -684,329 +265,6 @@ class StateGroupWorkerStore( return {row["event_id"]: row["state_group"] for row in rows} - def _get_state_for_group_using_cache(self, cache, group, state_filter): - """Checks if group is in cache. See `_get_state_for_groups` - - Args: - cache(DictionaryCache): the state group cache to use - group(int): The state group to lookup - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns 2-tuple (`state_dict`, `got_all`). - `got_all` is a bool indicating if we successfully retrieved all - requests state from the cache, if False we need to query the DB for the - missing state. - """ - is_all, known_absent, state_dict_ids = cache.get(group) - - if is_all or state_filter.is_full(): - # Either we have everything or want everything, either way - # `is_all` tells us whether we've gotten everything. - return state_filter.filter_state(state_dict_ids), is_all - - # tracks whether any of our requested types are missing from the cache - missing_types = False - - if state_filter.has_wildcards(): - # We don't know if we fetched all the state keys for the types in - # the filter that are wildcards, so we have to assume that we may - # have missed some. - missing_types = True - else: - # There aren't any wild cards, so `concrete_types()` returns the - # complete list of event types we're wanting. - for key in state_filter.concrete_types(): - if key not in state_dict_ids and key not in known_absent: - missing_types = True - break - - return state_filter.filter_state(state_dict_ids), not missing_types - - @defer.inlineCallbacks - def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - state_filter (StateFilter): The state filter used to fetch state - from the database. - Returns: - Deferred[dict[int, dict[tuple[str, str], str]]]: - dict of state_group_id -> (dict of (type, state_key) -> event id) - """ - - member_filter, non_member_filter = state_filter.get_member_split() - - # Now we look them up in the member and non-member caches - ( - non_member_state, - incomplete_groups_nm, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_cache, state_filter=non_member_filter - ) - - ( - member_state, - incomplete_groups_m, - ) = yield self._get_state_for_groups_using_cache( - groups, self._state_group_members_cache, state_filter=member_filter - ) - - state = dict(non_member_state) - for group in groups: - state[group].update(member_state[group]) - - # Now fetch any missing groups from the database - - incomplete_groups = incomplete_groups_m | incomplete_groups_nm - - if not incomplete_groups: - return state - - cache_sequence_nm = self._state_group_cache.sequence - cache_sequence_m = self._state_group_members_cache.sequence - - # Help the cache hit ratio by expanding the filter a bit - db_state_filter = state_filter.return_expanded() - - group_to_state_dict = yield self._get_state_groups_from_groups( - list(incomplete_groups), state_filter=db_state_filter - ) - - # Now lets update the caches - self._insert_into_cache( - group_to_state_dict, - db_state_filter, - cache_seq_num_members=cache_sequence_m, - cache_seq_num_non_members=cache_sequence_nm, - ) - - # And finally update the result dict, by filtering out any extra - # stuff we pulled out of the database. - for group, group_state_dict in iteritems(group_to_state_dict): - # We just replace any existing entries, as we will have loaded - # everything we need from the database anyway. - state[group] = state_filter.filter_state(group_state_dict) - - return state - - def _get_state_for_groups_using_cache(self, groups, cache, state_filter): - """Gets the state at each of a list of state groups, optionally - filtering by type/state_key, querying from a specific cache. - - Args: - groups (iterable[int]): list of state groups for which we want - to get the state. - cache (DictionaryCache): the cache of group ids to state dicts which - we will pass through - either the normal state cache or the specific - members state cache. - state_filter (StateFilter): The state filter used to fetch state - from the database. - - Returns: - tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of - dict of state_group_id -> (dict of (type, state_key) -> event id) - of entries in the cache, and the state group ids either missing - from the cache or incomplete. - """ - results = {} - incomplete_groups = set() - for group in set(groups): - state_dict_ids, got_all = self._get_state_for_group_using_cache( - cache, group, state_filter - ) - results[group] = state_dict_ids - - if not got_all: - incomplete_groups.add(group) - - return results, incomplete_groups - - def _insert_into_cache( - self, - group_to_state_dict, - state_filter, - cache_seq_num_members, - cache_seq_num_non_members, - ): - """Inserts results from querying the database into the relevant cache. - - Args: - group_to_state_dict (dict): The new entries pulled from database. - Map from state group to state dict - state_filter (StateFilter): The state filter used to fetch state - from the database. - cache_seq_num_members (int): Sequence number of member cache since - last lookup in cache - cache_seq_num_non_members (int): Sequence number of member cache since - last lookup in cache - """ - - # We need to work out which types we've fetched from the DB for the - # member vs non-member caches. This should be as accurate as possible, - # but can be an underestimate (e.g. when we have wild cards) - - member_filter, non_member_filter = state_filter.get_member_split() - if member_filter.is_full(): - # We fetched all member events - member_types = None - else: - # `concrete_types()` will only return a subset when there are wild - # cards in the filter, but that's fine. - member_types = member_filter.concrete_types() - - if non_member_filter.is_full(): - # We fetched all non member events - non_member_types = None - else: - non_member_types = non_member_filter.concrete_types() - - for group, group_state_dict in iteritems(group_to_state_dict): - state_dict_members = {} - state_dict_non_members = {} - - for k, v in iteritems(group_state_dict): - if k[0] == EventTypes.Member: - state_dict_members[k] = v - else: - state_dict_non_members[k] = v - - self._state_group_members_cache.update( - cache_seq_num_members, - key=group, - value=state_dict_members, - fetched_keys=member_types, - ) - - self._state_group_cache.update( - cache_seq_num_non_members, - key=group, - value=state_dict_non_members, - fetched_keys=non_member_types, - ) - - def store_state_group( - self, event_id, room_id, prev_group, delta_ids, current_state_ids - ): - """Store a new set of state, returning a newly assigned state group. - - Args: - event_id (str): The event ID for which the state was calculated - room_id (str) - prev_group (int|None): A previous state group for the room, optional. - delta_ids (dict|None): The delta between state at `prev_group` and - `current_state_ids`, if `prev_group` was given. Same format as - `current_state_ids`. - current_state_ids (dict): The state to store. Map of (type, state_key) - to event_id. - - Returns: - Deferred[int]: The state group ID - """ - - def _store_state_group_txn(txn): - if current_state_ids is None: - # AFAIK, this can never happen - raise Exception("current_state_ids cannot be None") - - state_group = self.database_engine.get_next_state_group_id(txn) - - self.db.simple_insert_txn( - txn, - table="state_groups", - values={"id": state_group, "room_id": room_id, "event_id": event_id}, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if prev_group: - is_in_db = self.db.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, - ) - if not is_in_db: - raise Exception( - "Trying to persist state with unpersisted prev_group: %r" - % (prev_group,) - ) - - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={"state_group": state_group, "prev_state_group": prev_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_ids) - ], - ) - else: - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(current_state_ids) - ], - ) - - # Prefill the state group caches with this group. - # It's fine to use the sequence like this as the state group map - # is immutable. (If the map wasn't immutable then this prefill could - # race with another update) - - current_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] == EventTypes.Member - } - txn.call_after( - self._state_group_members_cache.update, - self._state_group_members_cache.sequence, - key=state_group, - value=dict(current_member_state_ids), - ) - - current_non_member_state_ids = { - s: ev - for (s, ev) in iteritems(current_state_ids) - if s[0] != EventTypes.Member - } - txn.call_after( - self._state_group_cache.update, - self._state_group_cache.sequence, - key=state_group, - value=dict(current_non_member_state_ids), - ) - - return state_group - - return self.db.runInteraction("store_state_group", _store_state_group_txn) - @defer.inlineCallbacks def get_referenced_state_groups(self, state_groups): """Check if the state groups are referenced by events. @@ -1031,22 +289,14 @@ class StateGroupWorkerStore( return set(row["state_group"] for row in rows) -class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): +class MainStateBackgroundUpdateStore(SQLBaseStore): - STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" - STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" def __init__(self, database: Database, db_conn, hs): - super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, - ) - self.db.updates.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state - ) + super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_index_update( self.CURRENT_STATE_INDEX_UPDATE_NAME, index_name="current_state_events_member_index", @@ -1061,181 +311,8 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): columns=["state_group"], ) - @defer.inlineCallbacks - def _background_deduplicate_state(self, progress, batch_size): - """This background update will slowly deduplicate state by reencoding - them as deltas. - """ - last_state_group = progress.get("last_state_group", 0) - rows_inserted = progress.get("rows_inserted", 0) - max_group = progress.get("max_group", None) - BATCH_SIZE_SCALE_FACTOR = 100 - - batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) - - if max_group is None: - rows = yield self.db.execute( - "_background_deduplicate_state", - None, - "SELECT coalesce(max(id), 0) FROM state_groups", - ) - max_group = rows[0][0] - - def reindex_txn(txn): - new_last_state_group = last_state_group - for count in range(batch_size): - txn.execute( - "SELECT id, room_id FROM state_groups" - " WHERE ? < id AND id <= ?" - " ORDER BY id ASC" - " LIMIT 1", - (new_last_state_group, max_group), - ) - row = txn.fetchone() - if row: - state_group, room_id = row - - if not row or not state_group: - return True, count - - txn.execute( - "SELECT state_group FROM state_group_edges" - " WHERE state_group = ?", - (state_group,), - ) - - # If we reach a point where we've already started inserting - # edges we should stop. - if txn.fetchall(): - return True, count - - txn.execute( - "SELECT coalesce(max(id), 0) FROM state_groups" - " WHERE id < ? AND room_id = ?", - (state_group, room_id), - ) - (prev_group,) = txn.fetchone() - new_last_state_group = state_group - - if prev_group: - potential_hops = self._count_state_group_hops_txn(txn, prev_group) - if potential_hops >= MAX_STATE_DELTA_HOPS: - # We want to ensure chains are at most this long,# - # otherwise read performance degrades. - continue - - prev_state = self._get_state_groups_from_groups_txn( - txn, [prev_group] - ) - prev_state = prev_state[prev_group] - - curr_state = self._get_state_groups_from_groups_txn( - txn, [state_group] - ) - curr_state = curr_state[state_group] - - if not set(prev_state.keys()) - set(curr_state.keys()): - # We can only do a delta if the current has a strict super set - # of keys - - delta_state = { - key: value - for key, value in iteritems(curr_state) - if prev_state.get(key, None) != value - } - - self.db.simple_delete_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": state_group, - "prev_state_group": prev_group, - }, - ) - - self.db.simple_delete_txn( - txn, - table="state_groups_state", - keyvalues={"state_group": state_group}, - ) - - self.db.simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": state_group, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in iteritems(delta_state) - ], - ) - - progress = { - "last_state_group": state_group, - "rows_inserted": rows_inserted + batch_size, - "max_group": max_group, - } - - self.db.updates._background_update_progress_txn( - txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress - ) - - return False, batch_size - - finished, result = yield self.db.runInteraction( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn - ) - - if finished: - yield self.db.updates._end_background_update( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME - ) - - return result * BATCH_SIZE_SCALE_FACTOR - - @defer.inlineCallbacks - def _background_index_state(self, progress, batch_size): - def reindex_txn(conn): - conn.rollback() - if isinstance(self.database_engine, PostgresEngine): - # postgres insists on autocommit for the index - conn.set_session(autocommit=True) - try: - txn = conn.cursor() - txn.execute( - "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - finally: - conn.set_session(autocommit=False) - else: - txn = conn.cursor() - txn.execute( - "CREATE INDEX state_groups_state_type_idx" - " ON state_groups_state(state_group, type, state_key)" - ) - txn.execute("DROP INDEX IF EXISTS state_groups_state_id") - - yield self.db.runWithConnection(reindex_txn) - - yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) - - return 1 - - -class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore): +class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): """ Keeps track of the state at a given event. This is done by the concept of `state groups`. Every event is a assigned diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py new file mode 100644 index 0000000000..86e09f6229 --- /dev/null +++ b/synapse/storage/data_stores/state/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401 diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py new file mode 100644 index 0000000000..e8edaf9f7b --- /dev/null +++ b/synapse/storage/data_stores/state/bg_updates.py @@ -0,0 +1,374 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from six import iteritems + +from twisted.internet import defer + +from synapse.storage._base import SQLBaseStore +from synapse.storage.database import Database +from synapse.storage.engines import PostgresEngine +from synapse.storage.state import StateFilter + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class StateGroupBackgroundUpdateStore(SQLBaseStore): + """Defines functions related to state groups needed to run the state backgroud + updates. + """ + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """ + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + def _get_state_groups_from_groups_txn( + self, txn, groups, state_filter=StateFilter.all() + ): + results = {group: {} for group in groups} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT type, state_key, last_value(event_id) OVER ( + PARTITION BY type, state_key ORDER BY state_group ASC + ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING + ) AS event_id FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) + """ + + for group in groups: + args = [group] + args.extend(where_args) + + txn.execute(sql + where_clause, args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + for group in groups: + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in results[group] + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(results[group]) == max_entries_returned + ): + break + + next_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return results + + +class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore): + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx" + + def __init__(self, database: Database, db_conn, hs): + super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.db.updates.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state + ) + self.db.updates.register_background_index_update( + self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME, + index_name="state_groups_room_id_idx", + table="state_groups", + columns=["room_id"], + ) + + @defer.inlineCallbacks + def _background_deduplicate_state(self, progress, batch_size): + """This background update will slowly deduplicate state by reencoding + them as deltas. + """ + last_state_group = progress.get("last_state_group", 0) + rows_inserted = progress.get("rows_inserted", 0) + max_group = progress.get("max_group", None) + + BATCH_SIZE_SCALE_FACTOR = 100 + + batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR)) + + if max_group is None: + rows = yield self.db.execute( + "_background_deduplicate_state", + None, + "SELECT coalesce(max(id), 0) FROM state_groups", + ) + max_group = rows[0][0] + + def reindex_txn(txn): + new_last_state_group = last_state_group + for count in range(batch_size): + txn.execute( + "SELECT id, room_id FROM state_groups" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC" + " LIMIT 1", + (new_last_state_group, max_group), + ) + row = txn.fetchone() + if row: + state_group, room_id = row + + if not row or not state_group: + return True, count + + txn.execute( + "SELECT state_group FROM state_group_edges" + " WHERE state_group = ?", + (state_group,), + ) + + # If we reach a point where we've already started inserting + # edges we should stop. + if txn.fetchall(): + return True, count + + txn.execute( + "SELECT coalesce(max(id), 0) FROM state_groups" + " WHERE id < ? AND room_id = ?", + (state_group, room_id), + ) + (prev_group,) = txn.fetchone() + new_last_state_group = state_group + + if prev_group: + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if potential_hops >= MAX_STATE_DELTA_HOPS: + # We want to ensure chains are at most this long,# + # otherwise read performance degrades. + continue + + prev_state = self._get_state_groups_from_groups_txn( + txn, [prev_group] + ) + prev_state = prev_state[prev_group] + + curr_state = self._get_state_groups_from_groups_txn( + txn, [state_group] + ) + curr_state = curr_state[state_group] + + if not set(prev_state.keys()) - set(curr_state.keys()): + # We can only do a delta if the current has a strict super set + # of keys + + delta_state = { + key: value + for key, value in iteritems(curr_state) + if prev_state.get(key, None) != value + } + + self.db.simple_delete_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self.db.simple_delete_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_state) + ], + ) + + progress = { + "last_state_group": state_group, + "rows_inserted": rows_inserted + batch_size, + "max_group": max_group, + } + + self.db.updates._background_update_progress_txn( + txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress + ) + + return False, batch_size + + finished, result = yield self.db.runInteraction( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn + ) + + if finished: + yield self.db.updates._end_background_update( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME + ) + + return result * BATCH_SIZE_SCALE_FACTOR + + @defer.inlineCallbacks + def _background_index_state(self, progress, batch_size): + def reindex_txn(conn): + conn.rollback() + if isinstance(self.database_engine, PostgresEngine): + # postgres insists on autocommit for the index + conn.set_session(autocommit=True) + try: + txn = conn.cursor() + txn.execute( + "CREATE INDEX CONCURRENTLY state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + finally: + conn.set_session(autocommit=False) + else: + txn = conn.cursor() + txn.execute( + "CREATE INDEX state_groups_state_type_idx" + " ON state_groups_state(state_group, type, state_key)" + ) + txn.execute("DROP INDEX IF EXISTS state_groups_state_id") + + yield self.db.runWithConnection(reindex_txn) + + yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME) + + return 1 diff --git a/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql rename to synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/30/state_stream.sql rename to synapse/storage/data_stores/state/schema/delta/30/state_stream.sql diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql new file mode 100644 index 0000000000..1450313bfa --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql @@ -0,0 +1,19 @@ +/* Copyright 2016 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- The following indices are redundant, other indices are equivalent or +-- supersets +DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY diff --git a/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql rename to synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/state.sql rename to synapse/storage/data_stores/state/schema/delta/35/state.sql diff --git a/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql rename to synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql diff --git a/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py similarity index 100% rename from synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py rename to synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql new file mode 100644 index 0000000000..7916ef18b2 --- /dev/null +++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('state_groups_room_id_idx', '{}'); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql new file mode 100644 index 0000000000..35f97d6b3d --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql @@ -0,0 +1,37 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE state_groups ( + id BIGINT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_groups_state ( + state_group BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT NOT NULL +); + +CREATE TABLE state_group_edges ( + state_group BIGINT NOT NULL, + prev_state_group BIGINT NOT NULL +); + +CREATE INDEX state_group_edges_idx ON state_group_edges (state_group); +CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group); +CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key); diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres new file mode 100644 index 0000000000..fcd926c9fb --- /dev/null +++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres @@ -0,0 +1,21 @@ +/* Copyright 2019 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE state_group_id_seq + START WITH 1 + INCREMENT BY 1 + NO MINVALUE + NO MAXVALUE + CACHE 1; diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py new file mode 100644 index 0000000000..d53695f238 --- /dev/null +++ b/synapse/storage/data_stores/state/store.py @@ -0,0 +1,640 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from collections import namedtuple + +from six import iteritems +from six.moves import range + +from twisted.internet import defer + +from synapse.api.constants import EventTypes +from synapse.storage._base import SQLBaseStore +from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore +from synapse.storage.database import Database +from synapse.storage.state import StateFilter +from synapse.util.caches import get_cache_factor_for +from synapse.util.caches.descriptors import cached +from synapse.util.caches.dictionary_cache import DictionaryCache + +logger = logging.getLogger(__name__) + + +MAX_STATE_DELTA_HOPS = 100 + + +class _GetStateGroupDelta( + namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids")) +): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + + __slots__ = [] + + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 + + +class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): + """A data store for fetching/storing state groups. + """ + + def __init__(self, database: Database, db_conn, hs): + super(StateGroupDataStore, self).__init__(database, db_conn, hs) + + # Originally the state store used a single DictionaryCache to cache the + # event IDs for the state types in a given state group to avoid hammering + # on the state_group* tables. + # + # The point of using a DictionaryCache is that it can cache a subset + # of the state events for a given state group (i.e. a subset of the keys for a + # given dict which is an entry in the cache for a given state group ID). + # + # However, this poses problems when performing complicated queries + # on the store - for instance: "give me all the state for this group, but + # limit members to this subset of users", as DictionaryCache's API isn't + # rich enough to say "please cache any of these fields, apart from this subset". + # This is problematic when lazy loading members, which requires this behaviour, + # as without it the cache has no choice but to speculatively load all + # state events for the group, which negates the efficiency being sought. + # + # Rather than overcomplicating DictionaryCache's API, we instead split the + # state_group_cache into two halves - one for tracking non-member events, + # and the other for tracking member_events. This means that lazy loading + # queries can be made in a cache-friendly manner by querying both caches + # separately and then merging the result. So for the example above, you + # would query the members cache for a specific subset of state keys + # (which DictionaryCache will handle efficiently and fine) and the non-members + # cache for all state (which DictionaryCache will similarly handle fine) + # and then just merge the results together. + # + # We size the non-members cache to be smaller than the members cache as the + # vast majority of state in Matrix (today) is member events. + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", + # TODO: this hasn't been tuned yet + 50000 * get_cache_factor_for("stateGroupCache"), + ) + self._state_group_members_cache = DictionaryCache( + "*stateGroupMembersCache*", + 500000 * get_cache_factor_for("stateGroupMembersCache"), + ) + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + + def _get_state_group_delta_txn(txn): + prev_group = self.db.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": state_group}, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self.db.simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={"state_group": state_group}, + retcols=("type", "state_key", "event_id"), + ) + + return _GetStateGroupDelta( + prev_group, + {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids}, + ) + + return self.db.runInteraction( + "get_state_group_delta", _get_state_group_delta_txn + ) + + @defer.inlineCallbacks + def _get_state_groups_from_groups(self, groups, state_filter): + """Returns the state groups for a given set of groups, filtering on + types of state events. + + Args: + groups(list[int]): list of state group IDs to query + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + results = {} + + chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] + for chunk in chunks: + res = yield self.db.runInteraction( + "_get_state_groups_from_groups", + self._get_state_groups_from_groups_txn, + chunk, + state_filter, + ) + results.update(res) + + return results + + def _get_state_for_group_using_cache(self, cache, group, state_filter): + """Checks if group is in cache. See `_get_state_for_groups` + + Args: + cache(DictionaryCache): the state group cache to use + group(int): The state group to lookup + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns 2-tuple (`state_dict`, `got_all`). + `got_all` is a bool indicating if we successfully retrieved all + requests state from the cache, if False we need to query the DB for the + missing state. + """ + is_all, known_absent, state_dict_ids = cache.get(group) + + if is_all or state_filter.is_full(): + # Either we have everything or want everything, either way + # `is_all` tells us whether we've gotten everything. + return state_filter.filter_state(state_dict_ids), is_all + + # tracks whether any of our requested types are missing from the cache + missing_types = False + + if state_filter.has_wildcards(): + # We don't know if we fetched all the state keys for the types in + # the filter that are wildcards, so we have to assume that we may + # have missed some. + missing_types = True + else: + # There aren't any wild cards, so `concrete_types()` returns the + # complete list of event types we're wanting. + for key in state_filter.concrete_types(): + if key not in state_dict_ids and key not in known_absent: + missing_types = True + break + + return state_filter.filter_state(state_dict_ids), not missing_types + + @defer.inlineCallbacks + def _get_state_for_groups(self, groups, state_filter=StateFilter.all()): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + state_filter (StateFilter): The state filter used to fetch state + from the database. + Returns: + Deferred[dict[int, dict[tuple[str, str], str]]]: + dict of state_group_id -> (dict of (type, state_key) -> event id) + """ + + member_filter, non_member_filter = state_filter.get_member_split() + + # Now we look them up in the member and non-member caches + ( + non_member_state, + incomplete_groups_nm, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_cache, state_filter=non_member_filter + ) + + ( + member_state, + incomplete_groups_m, + ) = yield self._get_state_for_groups_using_cache( + groups, self._state_group_members_cache, state_filter=member_filter + ) + + state = dict(non_member_state) + for group in groups: + state[group].update(member_state[group]) + + # Now fetch any missing groups from the database + + incomplete_groups = incomplete_groups_m | incomplete_groups_nm + + if not incomplete_groups: + return state + + cache_sequence_nm = self._state_group_cache.sequence + cache_sequence_m = self._state_group_members_cache.sequence + + # Help the cache hit ratio by expanding the filter a bit + db_state_filter = state_filter.return_expanded() + + group_to_state_dict = yield self._get_state_groups_from_groups( + list(incomplete_groups), state_filter=db_state_filter + ) + + # Now lets update the caches + self._insert_into_cache( + group_to_state_dict, + db_state_filter, + cache_seq_num_members=cache_sequence_m, + cache_seq_num_non_members=cache_sequence_nm, + ) + + # And finally update the result dict, by filtering out any extra + # stuff we pulled out of the database. + for group, group_state_dict in iteritems(group_to_state_dict): + # We just replace any existing entries, as we will have loaded + # everything we need from the database anyway. + state[group] = state_filter.filter_state(group_state_dict) + + return state + + def _get_state_for_groups_using_cache(self, groups, cache, state_filter): + """Gets the state at each of a list of state groups, optionally + filtering by type/state_key, querying from a specific cache. + + Args: + groups (iterable[int]): list of state groups for which we want + to get the state. + cache (DictionaryCache): the cache of group ids to state dicts which + we will pass through - either the normal state cache or the specific + members state cache. + state_filter (StateFilter): The state filter used to fetch state + from the database. + + Returns: + tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of + dict of state_group_id -> (dict of (type, state_key) -> event id) + of entries in the cache, and the state group ids either missing + from the cache or incomplete. + """ + results = {} + incomplete_groups = set() + for group in set(groups): + state_dict_ids, got_all = self._get_state_for_group_using_cache( + cache, group, state_filter + ) + results[group] = state_dict_ids + + if not got_all: + incomplete_groups.add(group) + + return results, incomplete_groups + + def _insert_into_cache( + self, + group_to_state_dict, + state_filter, + cache_seq_num_members, + cache_seq_num_non_members, + ): + """Inserts results from querying the database into the relevant cache. + + Args: + group_to_state_dict (dict): The new entries pulled from database. + Map from state group to state dict + state_filter (StateFilter): The state filter used to fetch state + from the database. + cache_seq_num_members (int): Sequence number of member cache since + last lookup in cache + cache_seq_num_non_members (int): Sequence number of member cache since + last lookup in cache + """ + + # We need to work out which types we've fetched from the DB for the + # member vs non-member caches. This should be as accurate as possible, + # but can be an underestimate (e.g. when we have wild cards) + + member_filter, non_member_filter = state_filter.get_member_split() + if member_filter.is_full(): + # We fetched all member events + member_types = None + else: + # `concrete_types()` will only return a subset when there are wild + # cards in the filter, but that's fine. + member_types = member_filter.concrete_types() + + if non_member_filter.is_full(): + # We fetched all non member events + non_member_types = None + else: + non_member_types = non_member_filter.concrete_types() + + for group, group_state_dict in iteritems(group_to_state_dict): + state_dict_members = {} + state_dict_non_members = {} + + for k, v in iteritems(group_state_dict): + if k[0] == EventTypes.Member: + state_dict_members[k] = v + else: + state_dict_non_members[k] = v + + self._state_group_members_cache.update( + cache_seq_num_members, + key=group, + value=state_dict_members, + fetched_keys=member_types, + ) + + self._state_group_cache.update( + cache_seq_num_non_members, + key=group, + value=state_dict_non_members, + fetched_keys=non_member_types, + ) + + def store_state_group( + self, event_id, room_id, prev_group, delta_ids, current_state_ids + ): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self.db.simple_insert_txn( + txn, + table="state_groups", + values={"id": state_group, "room_id": room_id, "event_id": event_id}, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self.db.simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn(txn, prev_group) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self.db.simple_insert_txn( + txn, + table="state_group_edges", + values={"state_group": state_group, "prev_state_group": prev_group}, + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(delta_ids) + ], + ) + else: + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(current_state_ids) + ], + ) + + # Prefill the state group caches with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + + current_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] == EventTypes.Member + } + txn.call_after( + self._state_group_members_cache.update, + self._state_group_members_cache.sequence, + key=state_group, + value=dict(current_member_state_ids), + ) + + current_non_member_state_ids = { + s: ev + for (s, ev) in iteritems(current_state_ids) + if s[0] != EventTypes.Member + } + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_non_member_state_ids), + ) + + return state_group + + return self.db.runInteraction("store_state_group", _store_state_group_txn) + + def purge_unreferenced_state_groups( + self, room_id: str, state_groups_to_delete + ) -> defer.Deferred: + """Deletes no longer referenced state groups and de-deltas any state + groups that reference them. + + Args: + room_id: The room the state groups belong to (must all be in the + same room). + state_groups_to_delete (Collection[int]): Set of all state groups + to delete. + """ + + return self.db.runInteraction( + "purge_unreferenced_state_groups", + self._purge_unreferenced_state_groups, + room_id, + state_groups_to_delete, + ) + + def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete): + logger.info( + "[purge] found %i state groups to delete", len(state_groups_to_delete) + ) + + rows = self.db.simple_select_many_txn( + txn, + table="state_group_edges", + column="prev_state_group", + iterable=state_groups_to_delete, + keyvalues={}, + retcols=("state_group",), + ) + + remaining_state_groups = set( + row["state_group"] + for row in rows + if row["state_group"] not in state_groups_to_delete + ) + + logger.info( + "[purge] de-delta-ing %i remaining state groups", + len(remaining_state_groups), + ) + + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) + curr_state = self._get_state_groups_from_groups_txn(txn, [sg]) + curr_state = curr_state[sg] + + self.db.simple_delete_txn( + txn, table="state_groups_state", keyvalues={"state_group": sg} + ) + + self.db.simple_delete_txn( + txn, table="state_group_edges", keyvalues={"state_group": sg} + ) + + self.db.simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": sg, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in iteritems(curr_state) + ], + ) + + logger.info("[purge] removing redundant state groups") + txn.executemany( + "DELETE FROM state_groups_state WHERE state_group = ?", + ((sg,) for sg in state_groups_to_delete), + ) + txn.executemany( + "DELETE FROM state_groups WHERE id = ?", + ((sg,) for sg in state_groups_to_delete), + ) + + @defer.inlineCallbacks + def get_previous_state_groups(self, state_groups): + """Fetch the previous groups of the given state groups. + + Args: + state_groups (Iterable[int]) + + Returns: + Deferred[dict[int, int]]: mapping from state group to previous + state group. + """ + + rows = yield self.db.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("prev_state_group", "state_group"), + desc="get_previous_state_groups", + ) + + return {row["state_group"]: row["prev_state_group"] for row in rows} + + def purge_room_state(self, room_id, state_groups_to_delete): + """Deletes all record of a room from state tables + + Args: + room_id (str): + state_groups_to_delete (list[int]): State groups to delete + """ + + return self.db.runInteraction( + "purge_room_state", + self._purge_room_state_txn, + room_id, + state_groups_to_delete, + ) + + def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete): + # first we have to delete the state groups states + logger.info("[purge] removing %s from state_groups_state", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups_state", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state group edges + logger.info("[purge] removing %s from state_group_edges", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_group_edges", + column="state_group", + iterable=state_groups_to_delete, + keyvalues={}, + ) + + # ... and the state groups + logger.info("[purge] removing %s from state_groups", room_id) + + self.db.simple_delete_many_txn( + txn, + table="state_groups", + column="id", + iterable=state_groups_to_delete, + keyvalues={}, + ) diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index fa03ca9ff7..1ed44925fc 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -183,7 +183,7 @@ class EventsPersistenceStorage(object): # so we use separate variables here even though they point to the same # store for now. self.main_store = stores.main - self.state_store = stores.main + self.state_store = stores.state self._clock = hs.get_clock() self.is_mine_id = hs.is_mine_id diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 403848ad03..e70026b80a 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -42,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn, database_engine, config, data_stores=["main"]): +def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py index a368182034..d6a7bd7834 100644 --- a/synapse/storage/purge_events.py +++ b/synapse/storage/purge_events.py @@ -58,7 +58,7 @@ class PurgeEventsStorage(object): sg_to_delete = yield self._find_unreferenced_groups(state_groups) - yield self.stores.main.purge_unreferenced_state_groups(room_id, sg_to_delete) + yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) @defer.inlineCallbacks def _find_unreferenced_groups(self, state_groups): @@ -102,7 +102,7 @@ class PurgeEventsStorage(object): # groups that are referenced. current_search -= referenced - edges = yield self.stores.main.get_previous_state_groups(current_search) + edges = yield self.stores.state.get_previous_state_groups(current_search) prevs = set(edges.values()) # We don't bother re-handling groups we've already seen diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 3735846899..cbeb586014 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -342,7 +342,7 @@ class StateGroupStorage(object): (prev_group, delta_ids) """ - return self.stores.main.get_state_group_delta(state_group) + return self.stores.state.get_state_group_delta(state_group) @defer.inlineCallbacks def get_state_groups_ids(self, _room_id, event_ids): @@ -362,7 +362,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups(groups) + group_to_state = yield self.stores.state._get_state_for_groups(groups) return group_to_state @@ -423,7 +423,7 @@ class StateGroupStorage(object): dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_groups_from_groups(groups, state_filter) + return self.stores.state._get_state_groups_from_groups(groups, state_filter) @defer.inlineCallbacks def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): @@ -439,7 +439,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -476,7 +476,7 @@ class StateGroupStorage(object): event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) groups = set(itervalues(event_to_groups)) - group_to_state = yield self.stores.main._get_state_for_groups( + group_to_state = yield self.stores.state._get_state_for_groups( groups, state_filter ) @@ -532,7 +532,7 @@ class StateGroupStorage(object): Deferred[dict[int, dict[tuple[str, str], str]]]: dict of state_group_id -> (dict of (type, state_key) -> event id) """ - return self.stores.main._get_state_for_groups(groups, state_filter) + return self.stores.state._get_state_for_groups(groups, state_filter) def store_state_group( self, event_id, room_id, prev_group, delta_ids, current_state_ids @@ -552,6 +552,6 @@ class StateGroupStorage(object): Returns: Deferred[int]: The state group ID """ - return self.stores.main.store_state_group( + return self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 43200654f1..d6ecf102f8 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -35,7 +35,7 @@ class StateStoreTestCase(tests.unittest.TestCase): self.store = hs.get_datastore() self.storage = hs.get_storage() - self.state_datastore = self.store + self.state_datastore = self.storage.state.stores.state self.event_builder_factory = hs.get_event_builder_factory() self.event_creation_handler = hs.get_event_creation_handler() diff --git a/tests/utils.py b/tests/utils.py index 9f5bf40b4b..e2e9cafd79 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -231,7 +231,7 @@ def setup_test_homeserver( "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1}, } - database = DatabaseConnectionConfig("master", database_config, ["main"]) + database = DatabaseConnectionConfig("master", database_config) config.database.databases = [database] db_engine = create_engine(database.config) From 7c6b853558feac49f1379b08a4fffb1aff1da481 Mon Sep 17 00:00:00 2001 From: dopple <54467709+doppled@users.noreply.github.com> Date: Sun, 22 Dec 2019 14:16:56 -0800 Subject: [PATCH 69/82] Update reverse proxy file name (#6590) s/reverse_proxy.rst/reverse_proxy.md/ --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index ae51d6ab39..2691dfc23d 100644 --- a/README.rst +++ b/README.rst @@ -393,4 +393,4 @@ something like the following in their logs:: 2019-09-11 19:32:04,271 - synapse.federation.transport.server - 288 - WARNING - GET-11752 - authenticate_request failed: 401: Invalid signature for server with key ed25519:a_EqML: Unable to verify signature for This is normally caused by a misconfiguration in your reverse-proxy. See -``_ and double-check that your settings are correct. +``_ and double-check that your settings are correct. From 32779b59fab0b4a64198f8fe617d7c495aeb1ede Mon Sep 17 00:00:00 2001 From: Aaron Raimist Date: Thu, 2 Jan 2020 04:28:20 -0600 Subject: [PATCH 70/82] Reword sections of federate.md that explained delegation at time of Synapse 1.0 transition (#6601) * Remove sections of federate.md explaining delegation at time of Synapse 1.0 transition Signed-off-by: Aaron Raimist * Add changelog Signed-off-by: Aaron Raimist --- changelog.d/6601.doc | 1 + docs/federate.md | 24 +++--------------------- 2 files changed, 4 insertions(+), 21 deletions(-) create mode 100644 changelog.d/6601.doc diff --git a/changelog.d/6601.doc b/changelog.d/6601.doc new file mode 100644 index 0000000000..08c5b3d215 --- /dev/null +++ b/changelog.d/6601.doc @@ -0,0 +1 @@ +Reword sections of federate.md that explained delegation at time of Synapse 1.0 transition. \ No newline at end of file diff --git a/docs/federate.md b/docs/federate.md index 193e2d2dfe..f9f17fcca5 100644 --- a/docs/federate.md +++ b/docs/federate.md @@ -66,10 +66,6 @@ therefore cannot gain access to the necessary certificate. With .well-known, federation servers will check for a valid TLS certificate for the delegated hostname (in our example: ``synapse.example.com``). -.well-known support first appeared in Synapse v0.99.0. To federate with older -servers you may need to additionally configure SRV delegation. Alternatively, -encourage the server admin in question to upgrade :). - ### DNS SRV delegation To use this delegation method, you need to have write access to your @@ -111,29 +107,15 @@ giving it a `server_name` of `example.com`, and once [ACME](acme.md) support is it would automatically generate a valid TLS certificate for you via Let's Encrypt and no SRV record or .well-known URI would be needed. -This is the common case, although you can add an SRV record or -`.well-known/matrix/server` URI for completeness if you wish. - **However**, if your server does not listen on port 8448, or if your `server_name` does not point to the host that your homeserver runs on, you will need to let other servers know how to find it. The way to do this is via .well-known or an SRV record. -#### I have created a .well-known URI. Do I still need an SRV record? +#### I have created a .well-known URI. Do I also need an SRV record? -As of Synapse 0.99, Synapse will first check for the existence of a .well-known -URI and follow any delegation it suggests. It will only then check for the -existence of an SRV record. - -That means that the SRV record will often be redundant. However, you should -remember that there may still be older versions of Synapse in the federation -which do not understand .well-known URIs, so if you removed your SRV record -you would no longer be able to federate with them. - -It is therefore best to leave the SRV record in place for now. Synapse 0.34 and -earlier will follow the SRV record (and not care about the invalid -certificate). Synapse 0.99 and later will follow the .well-known URI, with the -correct certificate chain. +No. You can use either `.well-known` delegation or use an SRV record for delegation. You +do not need to use both to delegate to the same location. #### Can I manage my own certificates rather than having Synapse renew certificates itself? From 0495097a7f6c6f5230972e0a5f32c0c9c42ef61b Mon Sep 17 00:00:00 2001 From: ewaf1 <59422220+ewaf1@users.noreply.github.com> Date: Thu, 2 Jan 2020 11:41:30 +0100 Subject: [PATCH 71/82] Added the section 'Configuration' in /docs/turn-howto.md (#6614) put the 2nd part of the "source installation"-section into a new section, because it also applies to Debian packages --- changelog.d/6614.doc | 1 + docs/turn-howto.md | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 changelog.d/6614.doc diff --git a/changelog.d/6614.doc b/changelog.d/6614.doc new file mode 100644 index 0000000000..38b962b062 --- /dev/null +++ b/changelog.d/6614.doc @@ -0,0 +1 @@ +Added the section 'Configuration' in /docs/turn-howto.md. diff --git a/docs/turn-howto.md b/docs/turn-howto.md index 4a983621e5..1bd3943f54 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md @@ -39,6 +39,8 @@ The TURN daemon `coturn` is available from a variety of sources such as native p make make install +### Configuration + 1. Create or edit the config file in `/etc/turnserver.conf`. The relevant lines, with example values, are: From 6964ea095bbd474b5c2b9dfe99c817cd370987bf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 14:19:09 +0000 Subject: [PATCH 72/82] Reduce the reconnect time when replication fails. (#6617) --- changelog.d/6617.misc | 1 + synapse/replication/tcp/client.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/6617.misc diff --git a/changelog.d/6617.misc b/changelog.d/6617.misc new file mode 100644 index 0000000000..94aa271d38 --- /dev/null +++ b/changelog.d/6617.misc @@ -0,0 +1 @@ +Reduce the reconnect time when worker replication fails, to make it easier to catch up. diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index fead78388c..bbcb84646c 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -46,7 +46,8 @@ class ReplicationClientFactory(ReconnectingClientFactory): is required. """ - maxDelay = 30 # Try at least once every N seconds + initialDelay = 0.1 + maxDelay = 1 # Try at least once every N seconds def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler): self.client_name = client_name From b6b57ecb4e845490fc26a537ff57df8cae1587b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 14:19:48 +0000 Subject: [PATCH 73/82] Kill off redundant SynapseRequestFactory (#6619) We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. --- changelog.d/6619.misc | 1 + synapse/http/site.py | 18 +++--------------- tests/server.py | 6 ++++-- 3 files changed, 8 insertions(+), 17 deletions(-) create mode 100644 changelog.d/6619.misc diff --git a/changelog.d/6619.misc b/changelog.d/6619.misc new file mode 100644 index 0000000000..b608133219 --- /dev/null +++ b/changelog.d/6619.misc @@ -0,0 +1 @@ +Simplify http handling by removing redundant SynapseRequestFactory. diff --git a/synapse/http/site.py b/synapse/http/site.py index ff8184a3d0..9f2d035fa0 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -47,9 +47,9 @@ class SynapseRequest(Request): logcontext(LoggingContext) : the log context for this request """ - def __init__(self, site, channel, *args, **kw): + def __init__(self, channel, *args, **kw): Request.__init__(self, channel, *args, **kw) - self.site = site + self.site = channel.site self._channel = channel # this is used by the tests self.authenticated_entity = None self.start_time = 0 @@ -331,18 +331,6 @@ class XForwardedForRequest(SynapseRequest): ) -class SynapseRequestFactory(object): - def __init__(self, site, x_forwarded_for): - self.site = site - self.x_forwarded_for = x_forwarded_for - - def __call__(self, *args, **kwargs): - if self.x_forwarded_for: - return XForwardedForRequest(self.site, *args, **kwargs) - else: - return SynapseRequest(self.site, *args, **kwargs) - - class SynapseSite(Site): """ Subclass of a twisted http Site that does access logging with python's @@ -364,7 +352,7 @@ class SynapseSite(Site): self.site_tag = site_tag proxied = config.get("x_forwarded", False) - self.requestFactory = SynapseRequestFactory(self, proxied) + self.requestFactory = XForwardedForRequest if proxied else SynapseRequest self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") diff --git a/tests/server.py b/tests/server.py index a554dfdd57..1644710aa0 100644 --- a/tests/server.py +++ b/tests/server.py @@ -20,6 +20,7 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers +from twisted.web.server import Site from synapse.http.site import SynapseRequest from synapse.util import Clock @@ -42,6 +43,7 @@ class FakeChannel(object): wire). """ + site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @@ -176,9 +178,9 @@ def make_request( content = content.encode("utf8") site = FakeSite() - channel = FakeChannel(reactor) + channel = FakeChannel(site, reactor) - req = request(site, channel) + req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) From 98247c4a0e169ee5f201fe5f0e404604d6628566 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 17:10:52 +0000 Subject: [PATCH 74/82] Remove unused, undocumented "content repo" resource (#6628) This looks like it got half-killed back in #888. Fixes #6567. --- changelog.d/6628.removal | 1 + docs/sample_config.yaml | 4 - synapse/api/urls.py | 1 - synapse/app/homeserver.py | 10 +- synapse/app/media_repository.py | 6 +- synapse/config/repository.py | 5 - synapse/rest/media/v0/__init__.py | 0 synapse/rest/media/v0/content_repository.py | 103 -------------------- tox.ini | 1 - 9 files changed, 3 insertions(+), 128 deletions(-) create mode 100644 changelog.d/6628.removal delete mode 100644 synapse/rest/media/v0/__init__.py delete mode 100644 synapse/rest/media/v0/content_repository.py diff --git a/changelog.d/6628.removal b/changelog.d/6628.removal new file mode 100644 index 0000000000..66cd6aeca4 --- /dev/null +++ b/changelog.d/6628.removal @@ -0,0 +1 @@ +Remove unused, undocumented /_matrix/content API. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index e3b05423b8..fad5f968b5 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -692,10 +692,6 @@ media_store_path: "DATADIR/media_store" # config: # directory: /mnt/some/other/directory -# Directory where in-progress uploads are stored. -# -uploads_path: "DATADIR/uploads" - # The largest allowed upload size in bytes # #max_upload_size: 10M diff --git a/synapse/api/urls.py b/synapse/api/urls.py index ff1f39e86c..f34434bd67 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -29,7 +29,6 @@ FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2" FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable" STATIC_PREFIX = "/_matrix/static" WEB_CLIENT_PREFIX = "/_matrix/client" -CONTENT_REPO_PREFIX = "/_matrix/content" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" MEDIA_PREFIX = "/_matrix/media/r0" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 0e9bf7f53a..6208deb646 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -39,7 +39,6 @@ import synapse import synapse.config.logger from synapse import events from synapse.api.urls import ( - CONTENT_REPO_PREFIX, FEDERATION_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, @@ -65,7 +64,6 @@ from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.key.v2 import KeyApiV2Resource -from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -223,13 +221,7 @@ class SynapseHomeServer(HomeServer): if self.get_config().enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( - { - MEDIA_PREFIX: media_repo, - LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), - } + {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} ) elif name == "media": raise ConfigError( diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py index 4c80f257e2..a63c53dc44 100644 --- a/synapse/app/media_repository.py +++ b/synapse/app/media_repository.py @@ -21,7 +21,7 @@ from twisted.web.resource import NoResource import synapse from synapse import events -from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX +from synapse.api.urls import LEGACY_MEDIA_PREFIX, MEDIA_PREFIX from synapse.app import _base from synapse.config._base import ConfigError from synapse.config.homeserver import HomeServerConfig @@ -37,7 +37,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto from synapse.replication.slave.storage.transactions import SlavedTransactionStore from synapse.replication.tcp.client import ReplicationClientHandler from synapse.rest.admin import register_servlets_for_media_repo -from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.server import HomeServer from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore from synapse.util.httpresourcetree import create_resource_tree @@ -82,9 +81,6 @@ class MediaRepositoryServer(HomeServer): { MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo, - CONTENT_REPO_PREFIX: ContentRepoResource( - self, self.config.uploads_path - ), "/_synapse/admin": admin_resource, } ) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index d0205e14b9..7d2dd27fd0 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -156,7 +156,6 @@ class ContentRepositoryConfig(Config): (provider_class, parsed_config, wrapper_config) ) - self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads")) self.dynamic_thumbnails = config.get("dynamic_thumbnails", False) self.thumbnail_requirements = parse_thumbnail_requirements( config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES) @@ -231,10 +230,6 @@ class ContentRepositoryConfig(Config): # config: # directory: /mnt/some/other/directory - # Directory where in-progress uploads are stored. - # - uploads_path: "%(uploads_path)s" - # The largest allowed upload size in bytes # #max_upload_size: 10M diff --git a/synapse/rest/media/v0/__init__.py b/synapse/rest/media/v0/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py deleted file mode 100644 index 86884c0ef4..0000000000 --- a/synapse/rest/media/v0/content_repository.py +++ /dev/null @@ -1,103 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import logging -import os -import re - -from canonicaljson import json - -from twisted.protocols.basic import FileSender -from twisted.web import resource, server - -from synapse.api.errors import Codes, cs_error -from synapse.http.server import finish_request, respond_with_json_bytes - -logger = logging.getLogger(__name__) - - -class ContentRepoResource(resource.Resource): - """Provides file uploading and downloading. - - Uploads are POSTed to wherever this Resource is linked to. This resource - returns a "content token" which can be used to GET this content again. The - token is typically a path, but it may not be. Tokens can expire, be - one-time uses, etc. - - In this case, the token is a path to the file and contains 3 interesting - sections: - - User ID base64d (for namespacing content to each user) - - random 24 char string - - Content type base64d (so we can return it when clients GET it) - - """ - - isLeaf = True - - def __init__(self, hs, directory): - resource.Resource.__init__(self) - self.hs = hs - self.directory = directory - - def render_GET(self, request): - # no auth here on purpose, to allow anyone to view, even across home - # servers. - - # TODO: A little crude here, we could do this better. - filename = request.path.decode("ascii").split("/")[-1] - # be paranoid - filename = re.sub("[^0-9A-z.-_]", "", filename) - - file_path = self.directory + "/" + filename - - logger.debug("Searching for %s", file_path) - - if os.path.isfile(file_path): - # filename has the content type - base64_contentype = filename.split(".")[1] - content_type = base64.urlsafe_b64decode(base64_contentype) - logger.info("Sending file %s", file_path) - f = open(file_path, "rb") - request.setHeader("Content-Type", content_type) - - # cache for at least a day. - # XXX: we might want to turn this off for data we don't want to - # recommend caching as it's sensitive or private - or at least - # select private. don't bother setting Expires as all our matrix - # clients are smart enough to be happy with Cache-Control (right?) - request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400") - - d = FileSender().beginFileTransfer(f, request) - - # after the file has been sent, clean up and finish the request - def cbFinished(ignored): - f.close() - finish_request(request) - - d.addCallback(cbFinished) - else: - respond_with_json_bytes( - request, - 404, - json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)), - send_cors=True, - ) - - return server.NOT_DONE_YET - - def render_OPTIONS(self, request): - respond_with_json_bytes(request, 200, {}, send_cors=True) - return server.NOT_DONE_YET diff --git a/tox.ini b/tox.ini index 1d6428f64f..0ab6d5666b 100644 --- a/tox.ini +++ b/tox.ini @@ -182,7 +182,6 @@ commands = mypy \ synapse/logging/ \ synapse/module_api \ synapse/rest/consent \ - synapse/rest/media/v0 \ synapse/rest/saml2 \ synapse/spam_checker_api \ synapse/storage/engines \ From e484101306787988adacf6d6de4fcd565368dec4 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 17:11:29 +0000 Subject: [PATCH 75/82] Raise an error if someone tries to use the log_file config option (#6626) This has caused some confusion for people who didn't notice it going away. --- changelog.d/6626.feature | 1 + synapse/app/homeserver.py | 2 +- synapse/config/logger.py | 17 +++++++++++++++-- 3 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 changelog.d/6626.feature diff --git a/changelog.d/6626.feature b/changelog.d/6626.feature new file mode 100644 index 0000000000..15798fa59b --- /dev/null +++ b/changelog.d/6626.feature @@ -0,0 +1 @@ +Raise an error if someone tries to use the log_file config option. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 6208deb646..e5b44a5eed 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -310,7 +310,7 @@ def setup(config_options): "Synapse Homeserver", config_options ) except ConfigError as e: - sys.stderr.write("\n" + str(e) + "\n") + sys.stderr.write("\nERROR: %s\n" % (e,)) sys.exit(1) if not config: diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 75bb904718..3c455610d9 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import argparse import logging import logging.config import os @@ -37,7 +37,7 @@ from synapse.logging._structured import ( from synapse.logging.context import LoggingContextFilter from synapse.util.versionstring import get_version_string -from ._base import Config +from ._base import Config, ConfigError DEFAULT_LOG_CONFIG = Template( """ @@ -81,11 +81,18 @@ disable_existing_loggers: false """ ) +LOG_FILE_ERROR = """\ +Support for the log_file configuration option and --log-file command-line option was +removed in Synapse 1.3.0. You should instead set up a separate log configuration file. +""" + class LoggingConfig(Config): section = "logging" def read_config(self, config, **kwargs): + if config.get("log_file"): + raise ConfigError(LOG_FILE_ERROR) self.log_config = self.abspath(config.get("log_config")) self.no_redirect_stdio = config.get("no_redirect_stdio", False) @@ -106,6 +113,8 @@ class LoggingConfig(Config): def read_arguments(self, args): if args.no_redirect_stdio is not None: self.no_redirect_stdio = args.no_redirect_stdio + if args.log_file is not None: + raise ConfigError(LOG_FILE_ERROR) @staticmethod def add_arguments(parser): @@ -118,6 +127,10 @@ class LoggingConfig(Config): help="Do not redirect stdout/stderr to the log", ) + logging_group.add_argument( + "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS, + ) + def generate_files(self, config, config_dir_path): log_config = config.get("log_config") if log_config and not os.path.exists(log_config): From 08815566bca79d001ad1bf58b2b082e435b6e5df Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 3 Jan 2020 17:14:00 +0000 Subject: [PATCH 76/82] Automate generation of the sample and debian log configs (#6627) --- changelog.d/6627.misc | 1 + debian/build_virtualenv | 3 +++ debian/changelog | 6 +++++ debian/install | 1 - debian/log.yaml | 36 ------------------------- docs/sample_log_config.yaml | 4 +-- scripts-dev/generate_sample_config | 10 +++++++ scripts/generate_log_config | 43 ++++++++++++++++++++++++++++++ synapse/config/logger.py | 9 ++++++- 9 files changed, 73 insertions(+), 40 deletions(-) create mode 100644 changelog.d/6627.misc delete mode 100644 debian/log.yaml create mode 100755 scripts/generate_log_config diff --git a/changelog.d/6627.misc b/changelog.d/6627.misc new file mode 100644 index 0000000000..702f067070 --- /dev/null +++ b/changelog.d/6627.misc @@ -0,0 +1 @@ +Automate generation of the sample log config. diff --git a/debian/build_virtualenv b/debian/build_virtualenv index 2791896052..d892fd5c9d 100755 --- a/debian/build_virtualenv +++ b/debian/build_virtualenv @@ -85,6 +85,9 @@ PYTHONPATH="$tmpdir" \ ' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml" +# build the log config file +"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \ + --output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml" # add a dependency on the right version of python to substvars. PYPKG=`basename $SNAKE` diff --git a/debian/changelog b/debian/changelog index 31791c127c..75fe89fa97 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.7.3ubuntu1) UNRELEASED; urgency=medium + + * Automate generation of the default log configuration file. + + -- Richard van der Hoff Fri, 03 Jan 2020 13:55:38 +0000 + matrix-synapse-py3 (1.7.3) stable; urgency=medium * New synapse release 1.7.3. diff --git a/debian/install b/debian/install index 43dc8c6904..da8b726a2b 100644 --- a/debian/install +++ b/debian/install @@ -1,2 +1 @@ -debian/log.yaml etc/matrix-synapse debian/manage_debconf.pl /opt/venvs/matrix-synapse/lib/ diff --git a/debian/log.yaml b/debian/log.yaml deleted file mode 100644 index 95b655dd35..0000000000 --- a/debian/log.yaml +++ /dev/null @@ -1,36 +0,0 @@ - -version: 1 - -formatters: - precise: - format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s- %(message)s' - -filters: - context: - (): synapse.logging.context.LoggingContextFilter - request: "" - -handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: /var/log/matrix-synapse/homeserver.log - maxBytes: 104857600 - backupCount: 10 - filters: [context] - encoding: utf8 - console: - class: logging.StreamHandler - formatter: precise - level: WARN - -loggers: - synapse: - level: INFO - - synapse.storage.SQL: - level: INFO - -root: - level: INFO - handlers: [file, console] diff --git a/docs/sample_log_config.yaml b/docs/sample_log_config.yaml index 11e8f35f41..1a2739455e 100644 --- a/docs/sample_log_config.yaml +++ b/docs/sample_log_config.yaml @@ -1,4 +1,4 @@ -# Example log config file for synapse. +# Log configuration for Synapse. # # This is a YAML file containing a standard Python logging configuration # dictionary. See [1] for details on the valid settings. @@ -20,7 +20,7 @@ handlers: file: class: logging.handlers.RotatingFileHandler formatter: precise - filename: /home/rav/work/synapse/homeserver.log + filename: /var/log/matrix-synapse/homeserver.log maxBytes: 104857600 backupCount: 10 filters: [context] diff --git a/scripts-dev/generate_sample_config b/scripts-dev/generate_sample_config index 5e33b9b549..9cb4630a5c 100755 --- a/scripts-dev/generate_sample_config +++ b/scripts-dev/generate_sample_config @@ -7,12 +7,22 @@ set -e cd `dirname $0`/.. SAMPLE_CONFIG="docs/sample_config.yaml" +SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml" + +check() { + diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1 +} if [ "$1" == "--check" ]; then diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || { echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 exit 1 } + diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || { + echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2 + exit 1 + } else ./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG" + ./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG" fi diff --git a/scripts/generate_log_config b/scripts/generate_log_config new file mode 100755 index 0000000000..b6957f48a3 --- /dev/null +++ b/scripts/generate_log_config @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 + +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import sys + +from synapse.config.logger import DEFAULT_LOG_CONFIG + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-o", + "--output-file", + type=argparse.FileType("w"), + default=sys.stdout, + help="File to write the configuration to. Default: stdout", + ) + + parser.add_argument( + "-f", + "--log-file", + type=str, + default="/var/log/matrix-synapse/homeserver.log", + help="name of the log file", + ) + + args = parser.parse_args() + args.output_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file)) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index 3c455610d9..a25c70e928 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -40,7 +40,14 @@ from synapse.util.versionstring import get_version_string from ._base import Config, ConfigError DEFAULT_LOG_CONFIG = Template( - """ + """\ +# Log configuration for Synapse. +# +# This is a YAML file containing a standard Python logging configuration +# dictionary. See [1] for details on the valid settings. +# +# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema + version: 1 formatters: From 01c3c6c9298d0bbdbbc6e829e9c9f1e1a52e8332 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 6 Jan 2020 04:53:07 -0500 Subject: [PATCH 77/82] Fix power levels being incorrectly set in old and new rooms after a room upgrade (#6633) Modify a copy of an upgraded room's PL before sending to the new room --- changelog.d/6633.bugfix | 1 + synapse/handlers/room.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6633.bugfix diff --git a/changelog.d/6633.bugfix b/changelog.d/6633.bugfix new file mode 100644 index 0000000000..4bacf26021 --- /dev/null +++ b/changelog.d/6633.bugfix @@ -0,0 +1 @@ +Fix bug where a moderator upgraded a room and became an admin in the new room. \ No newline at end of file diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 89c9118b26..4f489762fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -16,6 +16,7 @@ # limitations under the License. """Contains functions for performing events on rooms.""" +import copy import itertools import logging import math @@ -271,7 +272,7 @@ class RoomCreationHandler(BaseHandler): except AuthError as e: logger.warning("Unable to update PLs in old room: %s", e) - logger.info("Setting correct PLs in new room") + logger.info("Setting correct PLs in new room to %s", old_room_pl_state.content) yield self.event_creation_handler.create_and_send_nonmember_event( requester, { @@ -365,13 +366,15 @@ class RoomCreationHandler(BaseHandler): needed_power_level = max(state_default, ban, max(event_power_levels.values())) # Raise the requester's power level in the new room if necessary - current_power_level = power_levels["users"][requester.user.to_string()] + current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - # Assign this power level to the requester - power_levels["users"][requester.user.to_string()] = needed_power_level + # Perform a deepcopy in order to not modify the original power levels in a + # room, as its contents are preserved as the state for the old room later on + new_power_levels = copy.deepcopy(power_levels) + initial_state[(EventTypes.PowerLevels, "")] = new_power_levels - # Set the power levels to the modified state - initial_state[(EventTypes.PowerLevels, "")] = power_levels + # Assign this power level to the requester + new_power_levels["users"][user_id] = needed_power_level yield self._send_events_for_new_room( requester, @@ -733,7 +736,7 @@ class RoomCreationHandler(BaseHandler): initial_state, creation_content, room_alias=None, - power_level_content_override=None, + power_level_content_override=None, # Doesn't apply when initial state has power level state event content creator_join_profile=None, ): def create(etype, content, **kwargs): From 18674eebb1fa5d7445952d7e201afe33bd040523 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:28:58 +0000 Subject: [PATCH 78/82] Workaround for error when fetching notary's own key (#6620) * Kill off redundant SynapseRequestFactory We already get the Site via the Channel, so there's no need for a dedicated RequestFactory: we can just use the right constructor. * Workaround for error when fetching notary's own key As a notary server, when we return our own keys, include all of our signing keys in verify_keys. This is a workaround for #6596. --- changelog.d/6620.misc | 1 + synapse/rest/key/v2/remote_key_resource.py | 30 ++-- tests/rest/key/v2/test_remote_key_resource.py | 130 ++++++++++++++++++ tests/unittest.py | 11 +- 4 files changed, 163 insertions(+), 9 deletions(-) create mode 100644 changelog.d/6620.misc create mode 100644 tests/rest/key/v2/test_remote_key_resource.py diff --git a/changelog.d/6620.misc b/changelog.d/6620.misc new file mode 100644 index 0000000000..8bfb78fb20 --- /dev/null +++ b/changelog.d/6620.misc @@ -0,0 +1 @@ +Add a workaround for synapse raising exceptions when fetching the notary's own key from the notary. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e7fc3f0431..bf5e0eb844 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,6 +15,7 @@ import logging from canonicaljson import encode_canonical_json, json +from signedjson.key import encode_verify_key_base64 from signedjson.sign import sign_json from twisted.internet import defer @@ -216,15 +217,28 @@ class RemoteKey(DirectServeResource): if cache_misses and query_remote_on_cache_miss: yield self.fetcher.get_keys(cache_misses) yield self.query_keys(request, query, query_remote_on_cache_miss=False) - else: - signed_keys = [] - for key_json in json_results: - key_json = json.loads(key_json) + return + + signed_keys = [] + for key_json in json_results: + key_json = json.loads(key_json) + + # backwards-compatibility hack for #6596: if the requested key belongs + # to us, make sure that all of the signing keys appear in the + # "verify_keys" section. + if key_json["server_name"] == self.config.server_name: + verify_keys = key_json["verify_keys"] for signing_key in self.config.key_server_signing_keys: - key_json = sign_json(key_json, self.config.server_name, signing_key) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + verify_keys[key_id] = { + "key": encode_verify_key_base64(signing_key.verify_key) + } - signed_keys.append(key_json) + for signing_key in self.config.key_server_signing_keys: + key_json = sign_json(key_json, self.config.server_name, signing_key) - results = {"server_keys": signed_keys} + signed_keys.append(key_json) - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + results = {"server_keys": signed_keys} + + respond_with_json_bytes(request, 200, encode_canonical_json(results)) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py new file mode 100644 index 0000000000..d8246b4e78 --- /dev/null +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import urllib.parse +from io import BytesIO + +from mock import Mock + +import signedjson.key +from nacl.signing import SigningKey +from signedjson.sign import sign_json + +from twisted.web.resource import NoResource + +from synapse.http.site import SynapseRequest +from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.util.httpresourcetree import create_resource_tree + +from tests import unittest +from tests.server import FakeChannel, wait_until_result + + +class RemoteKeyResourceTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + return self.setup_test_homeserver(http_client=self.http_client) + + def create_test_json_resource(self): + return create_resource_tree( + {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource() + ) + + def expect_outgoing_key_request( + self, server_name: str, signing_key: SigningKey + ) -> None: + """ + Tell the mock http client to expect an outgoing GET request for the given key + """ + + def get_json(destination, path, ignore_backoff=False, **kwargs): + self.assertTrue(ignore_backoff) + self.assertEqual(destination, server_name) + key_id = "%s:%s" % (signing_key.alg, signing_key.version) + self.assertEqual( + path, "/_matrix/key/v2/server/%s" % (urllib.parse.quote(key_id),) + ) + + response = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": 200 * 1000, + "verify_keys": { + key_id: { + "key": signedjson.key.encode_verify_key_base64( + signing_key.verify_key + ) + } + }, + } + sign_json(response, server_name, signing_key) + return response + + self.http_client.get_json.side_effect = get_json + + def make_notary_request(self, server_name: str, key_id: str) -> dict: + """Send a GET request to the test server requesting the given key. + + Checks that the response is a 200 and returns the decoded json body. + """ + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/key/v2/query/%s/%s" + % (server_name.encode("utf-8"), key_id.encode("utf-8")), + b"1.1", + ) + wait_until_result(self.reactor, req) + self.assertEqual(channel.code, 200) + resp = channel.json_body + return resp + + def test_get_key(self): + """Fetch a remote key""" + SERVER_NAME = "remote.server" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(SERVER_NAME, testkey) + + resp = self.make_notary_request(SERVER_NAME, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertEqual(len(keys[0]["verify_keys"]), 1) + + # it should be signed by both the origin server and the notary + self.assertIn(SERVER_NAME, keys[0]["signatures"]) + self.assertIn(self.hs.hostname, keys[0]["signatures"]) + + def test_get_own_key(self): + """Fetch our own key""" + testkey = signedjson.key.generate_signing_key("ver1") + self.expect_outgoing_key_request(self.hs.hostname, testkey) + + resp = self.make_notary_request(self.hs.hostname, "ed25519:ver1") + keys = resp["server_keys"] + self.assertEqual(len(keys), 1) + + # it should be signed by both itself, and the notary signing key + sigs = keys[0]["signatures"] + self.assertEqual(len(sigs), 1) + self.assertIn(self.hs.hostname, sigs) + oursigs = sigs[self.hs.hostname] + self.assertEqual(len(oursigs), 2) + + # and both keys should be present in the verify_keys section + self.assertIn("ed25519:ver1", keys[0]["verify_keys"]) + self.assertIn("ed25519:a_lPym", keys[0]["verify_keys"]) diff --git a/tests/unittest.py b/tests/unittest.py index b30b7d1718..cbda237278 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -36,7 +36,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.ratelimiting import FederationRateLimitConfig from synapse.federation.transport import server as federation_server from synapse.http.server import JsonResource -from synapse.http.site import SynapseRequest +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester @@ -210,6 +210,15 @@ class HomeserverTestCase(TestCase): # Register the resources self.resource = self.create_test_json_resource() + # create a site to wrap the resource. + self.site = SynapseSite( + logger_name="synapse.access.http.fake", + site_tag="test", + config={}, + resource=self.resource, + server_version_string="1", + ) + from tests.rest.client.v1.utils import RestHelper self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None)) From 4b36b482e0cc1a63db27534c4ea5d9608cdb6a79 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 12:33:56 +0000 Subject: [PATCH 79/82] Fix exception when fetching notary server's old keys (#6625) Lift the restriction that *all* the keys used for signing v2 key responses be present in verify_keys. Fixes #6596. --- changelog.d/6625.bugfix | 1 + synapse/crypto/keyring.py | 13 ++-- tests/crypto/test_keyring.py | 147 +++++++++++++++++++++++------------ 3 files changed, 107 insertions(+), 54 deletions(-) create mode 100644 changelog.d/6625.bugfix diff --git a/changelog.d/6625.bugfix b/changelog.d/6625.bugfix new file mode 100644 index 0000000000..a8dc5587dc --- /dev/null +++ b/changelog.d/6625.bugfix @@ -0,0 +1 @@ +Fix exception when fetching the `matrix.org:ed25519:auto` key. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 7cfad192e8..6fe5a6a26a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -511,17 +511,18 @@ class BaseV2KeyFetcher(object): server_name = response_json["server_name"] verified = False for key_id in response_json["signatures"].get(server_name, {}): - # each of the keys used for the signature must be present in the response - # json. key = verify_keys.get(key_id) if not key: - raise KeyLookupError( - "Key response is signed by key id %s:%s but that key is not " - "present in the response" % (server_name, key_id) - ) + # the key may not be present in verify_keys if: + # * we got the key from the notary server, and: + # * the key belongs to the notary server, and: + # * the notary server is using a different key to sign notary + # responses. + continue verify_signed_json(response_json, server_name, key.verify_key) verified = True + break if not verified: raise KeyLookupError( diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 8efd39c7f7..34d5895f18 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,6 +19,7 @@ from mock import Mock import canonicaljson import signedjson.key import signedjson.sign +from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key from twisted.internet import defer @@ -412,6 +413,49 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): handlers=None, http_client=self.http_client, config=config ) + def build_perspectives_response( + self, server_name: str, signing_key: SigningKey, valid_until_ts: int, + ) -> dict: + """ + Build a valid perspectives server response to a request for the given key + """ + verify_key = signedjson.key.get_verify_key(signing_key) + verifykey_id = "%s:%s" % (verify_key.alg, verify_key.version) + + response = { + "server_name": server_name, + "old_verify_keys": {}, + "valid_until_ts": valid_until_ts, + "verify_keys": { + verifykey_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) + } + }, + } + # the response must be signed by both the origin server and the perspectives + # server. + signedjson.sign.sign_json(response, server_name, signing_key) + self.mock_perspective_server.sign_response(response) + return response + + def expect_outgoing_key_query( + self, expected_server_name: str, expected_key_id: str, response: dict + ) -> None: + """ + Tell the mock http client to expect a perspectives-server key query + """ + + def post_json(destination, path, data, **kwargs): + self.assertEqual(destination, self.mock_perspective_server.server_name) + self.assertEqual(path, "/_matrix/key/v2/query") + + # check that the request is for the expected key + q = data["server_keys"] + self.assertEqual(list(q[expected_server_name].keys()), [expected_key_id]) + return {"server_keys": [response]} + + self.http_client.post_json.side_effect = post_json + def test_get_keys_from_perspectives(self): # arbitrarily advance the clock a bit self.reactor.advance(100) @@ -424,33 +468,61 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): testverifykey_id = "ed25519:ver1" VALID_UNTIL_TS = 200 * 1000 - # valid response - response = { - "server_name": SERVER_NAME, - "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, - "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) - } - }, - } + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS, + ) - # the response must be signed by both the origin server and the perspectives - # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) - self.mock_perspective_server.sign_response(response) + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") + keys_to_fetch = {SERVER_NAME: {"key1": 0}} + keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + self.assertIn(SERVER_NAME, keys) + k = keys[SERVER_NAME][testverifykey_id] + self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) + self.assertEqual(k.verify_key, testverifykey) + self.assertEqual(k.verify_key.alg, "ed25519") + self.assertEqual(k.verify_key.version, "ver1") - # check that the request is for the expected key - q = data["server_keys"] - self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"]) - return {"server_keys": [response]} + # check that the perspectives store is correctly updated + lookup_triplet = (SERVER_NAME, testverifykey_id, None) + key_json = self.get_success( + self.hs.get_datastore().get_server_keys_json([lookup_triplet]) + ) + res = key_json[lookup_triplet] + self.assertEqual(len(res), 1) + res = res[0] + self.assertEqual(res["key_id"], testverifykey_id) + self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) + self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) + self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) - self.http_client.post_json.side_effect = post_json + self.assertEqual( + bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) + ) + + def test_get_perspectives_own_key(self): + """Check that we can get the perspectives server's own keys + + This is slightly complicated by the fact that the perspectives server may + use different keys for signing notary responses. + """ + + # arbitrarily advance the clock a bit + self.reactor.advance(100) + + fetcher = PerspectivesKeyFetcher(self.hs) + + SERVER_NAME = self.mock_perspective_server.server_name + testkey = signedjson.key.generate_signing_key("ver1") + testverifykey = signedjson.key.get_verify_key(testkey) + testverifykey_id = "ed25519:ver1" + VALID_UNTIL_TS = 200 * 1000 + + response = self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) + + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) keys_to_fetch = {SERVER_NAME: {"key1": 0}} keys = self.get_success(fetcher.get_keys(keys_to_fetch)) @@ -490,35 +562,14 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): VALID_UNTIL_TS = 200 * 1000 def build_response(): - # valid response - response = { - "server_name": SERVER_NAME, - "old_verify_keys": {}, - "valid_until_ts": VALID_UNTIL_TS, - "verify_keys": { - testverifykey_id: { - "key": signedjson.key.encode_verify_key_base64(testverifykey) - } - }, - } - - # the response must be signed by both the origin server and the perspectives - # server. - signedjson.sign.sign_json(response, SERVER_NAME, testkey) - self.mock_perspective_server.sign_response(response) - return response + return self.build_perspectives_response( + SERVER_NAME, testkey, VALID_UNTIL_TS + ) def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) keys_to_fetch = {SERVER_NAME: {"key1": 0}} - - def post_json(destination, path, data, **kwargs): - self.assertEqual(destination, self.mock_perspective_server.server_name) - self.assertEqual(path, "/_matrix/key/v2/query") - return {"server_keys": [response]} - - self.http_client.post_json.side_effect = post_json - + self.expect_outgoing_key_query(SERVER_NAME, "key1", response) return self.get_success(fetcher.get_keys(keys_to_fetch)) # start with a valid response so we can check we are testing the right thing From ab4b4ee6a7e15d1d6e83c4b826051da7df7f83e3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 14:34:02 +0000 Subject: [PATCH 80/82] Fix an error which was thrown by the PresenceHandler _on_shutdown handler. (#6640) --- changelog.d/6640.bugfix | 1 + synapse/handlers/presence.py | 9 ++------- 2 files changed, 3 insertions(+), 7 deletions(-) create mode 100644 changelog.d/6640.bugfix diff --git a/changelog.d/6640.bugfix b/changelog.d/6640.bugfix new file mode 100644 index 0000000000..8c2a129933 --- /dev/null +++ b/changelog.d/6640.bugfix @@ -0,0 +1 @@ +Fix an error which was thrown by the PresenceHandler _on_shutdown handler. diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 240c4add12..202aa9294f 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -95,12 +95,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER class PresenceHandler(object): - def __init__(self, hs): - """ - - Args: - hs (synapse.server.HomeServer): - """ + def __init__(self, hs: "synapse.server.HomeServer"): self.hs = hs self.is_mine = hs.is_mine self.is_mine_id = hs.is_mine_id @@ -230,7 +225,7 @@ class PresenceHandler(object): is some spurious presence changes that will self-correct. """ # If the DB pool has already terminated, don't try updating - if not self.store.database.is_running(): + if not self.store.db.is_running(): return logger.info( From 9f6c1befbbb0279dca261b105148e633c3d45453 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jan 2020 14:44:01 +0000 Subject: [PATCH 81/82] Add experimental 'databases' config (#6580) --- changelog.d/6580.feature | 1 + synapse/config/database.py | 55 +++++++++++++++++++------ synapse/storage/data_stores/__init__.py | 21 ++++++++++ 3 files changed, 64 insertions(+), 13 deletions(-) create mode 100644 changelog.d/6580.feature diff --git a/changelog.d/6580.feature b/changelog.d/6580.feature new file mode 100644 index 0000000000..233c589c66 --- /dev/null +++ b/changelog.d/6580.feature @@ -0,0 +1 @@ +Add experimental config option to specify multiple databases. diff --git a/synapse/config/database.py b/synapse/config/database.py index 134824789c..219b32f670 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -15,7 +15,6 @@ import logging import os from textwrap import indent -from typing import List import yaml @@ -30,16 +29,13 @@ class DatabaseConnectionConfig: Args: name: A label for the database, used for logging. db_config: The config for a particular database, as per `database` - section of main config. Has two fields: `name` for database - module name, and `args` for the args to give to the database - connector. - data_stores: The list of data stores that should be provisioned on the - database. Defaults to all data stores. + section of main config. Has three fields: `name` for database + module name, `args` for the args to give to the database + connector, and optional `data_stores` that is a list of stores to + provision on this database (defaulting to all). """ - def __init__( - self, name: str, db_config: dict, data_stores: List[str] = ["main", "state"] - ): + def __init__(self, name: str, db_config: dict): if db_config["name"] not in ("sqlite3", "psycopg2"): raise ConfigError("Unsupported database type %r" % (db_config["name"],)) @@ -48,6 +44,10 @@ class DatabaseConnectionConfig: {"cp_min": 1, "cp_max": 1, "check_same_thread": False} ) + data_stores = db_config.get("data_stores") + if data_stores is None: + data_stores = ["main", "state"] + self.name = name self.config = db_config self.data_stores = data_stores @@ -59,14 +59,43 @@ class DatabaseConfig(Config): def read_config(self, config, **kwargs): self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) + # We *experimentally* support specifying multiple databases via the + # `databases` key. This is a map from a label to database config in the + # same format as the `database` config option, plus an extra + # `data_stores` key to specify which data store goes where. For example: + # + # databases: + # master: + # name: psycopg2 + # data_stores: ["main"] + # args: {} + # state: + # name: psycopg2 + # data_stores: ["state"] + # args: {} + + multi_database_config = config.get("databases") database_config = config.get("database") - if database_config is None: - database_config = {"name": "sqlite3", "args": {}} + if multi_database_config and database_config: + raise ConfigError("Can't specify both 'database' and 'datbases' in config") - self.databases = [DatabaseConnectionConfig("master", database_config)] + if multi_database_config: + if config.get("database_path"): + raise ConfigError("Can't specify 'database_path' with 'databases'") - self.set_databasepath(config.get("database_path")) + self.databases = [ + DatabaseConnectionConfig(name, db_conf) + for name, db_conf in multi_database_config.items() + ] + + else: + if database_config is None: + database_config = {"name": "sqlite3", "args": {}} + + self.databases = [DatabaseConnectionConfig("master", database_config)] + + self.set_databasepath(config.get("database_path")) def generate_config_section(self, data_dir_path, database_conf, **kwargs): if not database_conf: diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py index d20df5f076..092e803799 100644 --- a/synapse/storage/data_stores/__init__.py +++ b/synapse/storage/data_stores/__init__.py @@ -37,6 +37,8 @@ class DataStores(object): # store. self.databases = [] + self.main = None + self.state = None for database_config in hs.config.database.databases: db_name = database_config.name @@ -54,10 +56,22 @@ class DataStores(object): if "main" in database_config.data_stores: logger.info("Starting 'main' data store") + + # Sanity check we don't try and configure the main store on + # multiple databases. + if self.main: + raise Exception("'main' data store already configured") + self.main = main_store_class(database, db_conn, hs) if "state" in database_config.data_stores: logger.info("Starting 'state' data store") + + # Sanity check we don't try and configure the state store on + # multiple databases. + if self.state: + raise Exception("'state' data store already configured") + self.state = StateGroupDataStore(database, db_conn, hs) db_conn.commit() @@ -65,3 +79,10 @@ class DataStores(object): self.databases.append(database) logger.info("Database %r prepared", db_name) + + # Sanity check that we have actually configured all the required stores. + if not self.main: + raise Exception("No 'main' data store configured") + + if not self.state: + raise Exception("No 'main' data store configured") From ba897a75903129a453d4fb853190dd31f7d1193b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 6 Jan 2020 15:22:46 +0000 Subject: [PATCH 82/82] Fix some test failures when frozen_dicts are enabled (#6642) Fixes #4026 --- changelog.d/6642.misc | 1 + synapse/crypto/event_signing.py | 9 ++++++--- synapse/handlers/room.py | 15 +++++++++------ synapse/handlers/room_member.py | 2 ++ synapse/storage/data_stores/main/state.py | 4 ++-- 5 files changed, 20 insertions(+), 11 deletions(-) create mode 100644 changelog.d/6642.misc diff --git a/changelog.d/6642.misc b/changelog.d/6642.misc new file mode 100644 index 0000000000..a480bbd134 --- /dev/null +++ b/changelog.d/6642.misc @@ -0,0 +1 @@ +Fix errors when frozen_dicts are enabled. diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py index ccaa8a9920..e65bd61d97 100644 --- a/synapse/crypto/event_signing.py +++ b/synapse/crypto/event_signing.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +import collections.abc import hashlib import logging @@ -40,8 +40,11 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256): # some malformed events lack a 'hashes'. Protect against it being missing # or a weird type by basically treating it the same as an unhashed event. hashes = event.get("hashes") - if not isinstance(hashes, dict): - raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED) + # nb it might be a frozendict or a dict + if not isinstance(hashes, collections.abc.Mapping): + raise SynapseError( + 400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED + ) if name not in hashes: raise SynapseError( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 4f489762fc..9cab2adbfb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -16,7 +16,7 @@ # limitations under the License. """Contains functions for performing events on rooms.""" -import copy + import itertools import logging import math @@ -368,13 +368,16 @@ class RoomCreationHandler(BaseHandler): # Raise the requester's power level in the new room if necessary current_power_level = power_levels["users"][user_id] if current_power_level < needed_power_level: - # Perform a deepcopy in order to not modify the original power levels in a - # room, as its contents are preserved as the state for the old room later on - new_power_levels = copy.deepcopy(power_levels) - initial_state[(EventTypes.PowerLevels, "")] = new_power_levels + # make sure we copy the event content rather than overwriting it. + # note that if frozen_dicts are enabled, `power_levels` will be a frozen + # dict so we can't just copy.deepcopy it. - # Assign this power level to the requester + new_power_levels = {k: v for k, v in power_levels.items() if k != "users"} + new_power_levels["users"] = { + k: v for k, v in power_levels.get("users", {}).items() if k != user_id + } new_power_levels["users"][user_id] = needed_power_level + initial_state[(EventTypes.PowerLevels, "")] = new_power_levels yield self._send_events_for_new_room( requester, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 44c5e3239c..dbb0c3dda2 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -507,6 +507,8 @@ class RoomMemberHandler(object): Returns: Deferred """ + logger.info("Transferring room state from %s to %s", old_room_id, room_id) + # Find all local users that were in the old room and copy over each user's state users = yield self.store.get_users_in_room(old_room_id) yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users) diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py index 0dc39f139c..d07440e3ed 100644 --- a/synapse/storage/data_stores/main/state.py +++ b/synapse/storage/data_stores/main/state.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import collections.abc import logging from collections import namedtuple from typing import Iterable, Tuple @@ -107,7 +107,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): predecessor = create_event.content.get("predecessor", None) # Ensure the key is a dictionary - if not isinstance(predecessor, dict): + if not isinstance(predecessor, collections.abc.Mapping): return None return predecessor