From 895b79ac2ece74500fb8a4ea158a6aec2adc0856 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 9 Apr 2019 18:28:17 +0100 Subject: [PATCH] Factor out KeyFetchers from KeyRing Rather than have three methods which have to have the same interface, factor out a separate interface which is provided by three implementations. I find it easier to grok the code this way. --- changelog.d/5244.misc | 1 + synapse/crypto/keyring.py | 459 +++++++++++++++++++---------------- tests/crypto/test_keyring.py | 34 ++- 3 files changed, 276 insertions(+), 218 deletions(-) create mode 100644 changelog.d/5244.misc diff --git a/changelog.d/5244.misc b/changelog.d/5244.misc new file mode 100644 index 000000000..9cc1fb869 --- /dev/null +++ b/changelog.d/5244.misc @@ -0,0 +1 @@ +Refactor synapse.crypto.keyring to use a KeyFetcher interface. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 14a27288f..eaf41b983 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -80,12 +80,13 @@ class KeyLookupError(ValueError): class Keyring(object): def __init__(self, hs): - self.store = hs.get_datastore() self.clock = hs.get_clock() - self.client = hs.get_http_client() - self.config = hs.get_config() - self.perspective_servers = self.config.perspectives - self.hs = hs + + self._key_fetchers = ( + StoreKeyFetcher(hs), + PerspectivesKeyFetcher(hs), + ServerKeyFetcher(hs), + ) # map from server name to Deferred. Has an entry for each server with # an ongoing key download; the Deferred completes once the download @@ -271,13 +272,6 @@ class Keyring(object): verify_requests (list[VerifyKeyRequest]): list of verify requests """ - # These are functions that produce keys given a list of key ids - key_fetch_fns = ( - self.get_keys_from_store, # First try the local store - self.get_keys_from_perspectives, # Then try via perspectives - self.get_keys_from_server, # Then try directly - ) - @defer.inlineCallbacks def do_iterations(): with Measure(self.clock, "get_server_verify_keys"): @@ -288,8 +282,8 @@ class Keyring(object): verify_request.key_ids ) - for fn in key_fetch_fns: - results = yield fn(missing_keys.items()) + for f in self._key_fetchers: + results = yield f.get_keys(missing_keys.items()) # We now need to figure out which verify requests we have keys # for and which we don't @@ -348,8 +342,9 @@ class Keyring(object): run_in_background(do_iterations).addErrback(on_err) - @defer.inlineCallbacks - def get_keys_from_store(self, server_name_and_key_ids): + +class KeyFetcher(object): + def get_keys(self, server_name_and_key_ids): """ Args: server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]): @@ -359,6 +354,18 @@ class Keyring(object): Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: map from server_name -> key_id -> FetchKeyResult """ + raise NotImplementedError + + +class StoreKeyFetcher(KeyFetcher): + """KeyFetcher impl which fetches keys from our data store""" + + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" keys_to_fetch = ( (server_name, key_id) for server_name, key_ids in server_name_and_key_ids @@ -370,203 +377,11 @@ class Keyring(object): keys.setdefault(server_name, {})[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks - def get_keys_from_perspectives(self, server_name_and_key_ids): - @defer.inlineCallbacks - def get_key(perspective_name, perspective_keys): - try: - result = yield self.get_server_verify_key_v2_indirect( - server_name_and_key_ids, perspective_name, perspective_keys - ) - defer.returnValue(result) - except KeyLookupError as e: - logger.warning("Key lookup failed from %r: %s", perspective_name, e) - except Exception as e: - logger.exception( - "Unable to get key from %r: %s %s", - perspective_name, - type(e).__name__, - str(e), - ) - defer.returnValue({}) - - results = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background(get_key, p_name, p_keys) - for p_name, p_keys in self.perspective_servers.items() - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - union_of_keys = {} - for result in results: - for server_name, keys in result.items(): - union_of_keys.setdefault(server_name, {}).update(keys) - - defer.returnValue(union_of_keys) - - @defer.inlineCallbacks - def get_keys_from_server(self, server_name_and_key_ids): - results = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.get_server_verify_key_v2_direct, server_name, key_ids - ) - for server_name, key_ids in server_name_and_key_ids - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) - ) - - merged = {} - for result in results: - merged.update(result) - - defer.returnValue( - {server_name: keys for server_name, keys in merged.items() if keys} - ) - - @defer.inlineCallbacks - def get_server_verify_key_v2_indirect( - self, server_names_and_key_ids, perspective_name, perspective_keys - ): - """ - Args: - server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): - list of (server_name, iterable[key_id]) tuples to fetch keys for - perspective_name (str): name of the notary server to query for the keys - perspective_keys (dict[str, VerifyKey]): map of key_id->key for the - notary server - - Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map - from server_name -> key_id -> FetchKeyResult - """ - # TODO(mark): Set the minimum_valid_until_ts to that needed by - # the events being validated or the current time if validating - # an incoming request. - try: - query_response = yield self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - u"server_keys": { - server_name: { - key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids - } - for server_name, key_ids in server_names_and_key_ids - } - }, - long_retries=True, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - raise_from(KeyLookupError("Failed to connect to remote server"), e) - except HttpResponseException as e: - raise_from(KeyLookupError("Remote server returned an error"), e) - - keys = {} - added_keys = [] - - time_now_ms = self.clock.time_msec() - - for response in query_response["server_keys"]: - if ( - u"signatures" not in response - or perspective_name not in response[u"signatures"] - ): - raise KeyLookupError( - "Key response not signed by perspective server" - " %r" % (perspective_name,) - ) - - verified = False - for key_id in response[u"signatures"][perspective_name]: - if key_id in perspective_keys: - verify_signed_json( - response, perspective_name, perspective_keys[key_id] - ) - verified = True - - if not verified: - logging.info( - "Response from perspective server %r not signed with a" - " known key, signed with: %r, known keys: %r", - perspective_name, - list(response[u"signatures"][perspective_name]), - list(perspective_keys), - ) - raise KeyLookupError( - "Response not signed with a known key for perspective" - " server %r" % (perspective_name,) - ) - - processed_response = yield self.process_v2_response( - perspective_name, response, time_added_ms=time_now_ms - ) - server_name = response["server_name"] - - added_keys.extend( - (server_name, key_id, key) for key_id, key in processed_response.items() - ) - keys.setdefault(server_name, {}).update(processed_response) - - yield self.store.store_server_verify_keys( - perspective_name, time_now_ms, added_keys - ) - - defer.returnValue(keys) - - @defer.inlineCallbacks - def get_server_verify_key_v2_direct(self, server_name, key_ids): - keys = {} # type: dict[str, FetchKeyResult] - - for requested_key_id in key_ids: - if requested_key_id in keys: - continue - - time_now_ms = self.clock.time_msec() - try: - response = yield self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - ) - except (NotRetryingDestination, RequestSendFailed) as e: - raise_from(KeyLookupError("Failed to connect to remote server"), e) - except HttpResponseException as e: - raise_from(KeyLookupError("Remote server returned an error"), e) - - if ( - u"signatures" not in response - or server_name not in response[u"signatures"] - ): - raise KeyLookupError("Key response not signed by remote server") - - if response["server_name"] != server_name: - raise KeyLookupError( - "Expected a response for server %r not %r" - % (server_name, response["server_name"]) - ) - - response_keys = yield self.process_v2_response( - from_server=server_name, - requested_ids=[requested_key_id], - response_json=response, - time_added_ms=time_now_ms, - ) - yield self.store.store_server_verify_keys( - server_name, - time_now_ms, - ((server_name, key_id, key) for key_id, key in response_keys.items()), - ) - keys.update(response_keys) - - defer.returnValue({server_name: keys}) +class BaseV2KeyFetcher(object): + def __init__(self, hs): + self.store = hs.get_datastore() + self.config = hs.get_config() @defer.inlineCallbacks def process_v2_response( @@ -670,6 +485,226 @@ class Keyring(object): defer.returnValue(verify_keys) +class PerspectivesKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the "perspectives" servers""" + + def __init__(self, hs): + super(PerspectivesKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + self.perspective_servers = self.config.perspectives + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + + @defer.inlineCallbacks + def get_key(perspective_name, perspective_keys): + try: + result = yield self.get_server_verify_key_v2_indirect( + server_name_and_key_ids, perspective_name, perspective_keys + ) + defer.returnValue(result) + except KeyLookupError as e: + logger.warning("Key lookup failed from %r: %s", perspective_name, e) + except Exception as e: + logger.exception( + "Unable to get key from %r: %s %s", + perspective_name, + type(e).__name__, + str(e), + ) + + defer.returnValue({}) + + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background(get_key, p_name, p_keys) + for p_name, p_keys in self.perspective_servers.items() + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + union_of_keys = {} + for result in results: + for server_name, keys in result.items(): + union_of_keys.setdefault(server_name, {}).update(keys) + + defer.returnValue(union_of_keys) + + @defer.inlineCallbacks + def get_server_verify_key_v2_indirect( + self, server_names_and_key_ids, perspective_name, perspective_keys + ): + """ + Args: + server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]): + list of (server_name, iterable[key_id]) tuples to fetch keys for + perspective_name (str): name of the notary server to query for the keys + perspective_keys (dict[str, VerifyKey]): map of key_id->key for the + notary server + + Returns: + Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + from server_name -> key_id -> FetchKeyResult + """ + # TODO(mark): Set the minimum_valid_until_ts to that needed by + # the events being validated or the current time if validating + # an incoming request. + try: + query_response = yield self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + u"server_keys": { + server_name: { + key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids + } + for server_name, key_ids in server_names_and_key_ids + } + }, + long_retries=True, + ) + except (NotRetryingDestination, RequestSendFailed) as e: + raise_from(KeyLookupError("Failed to connect to remote server"), e) + except HttpResponseException as e: + raise_from(KeyLookupError("Remote server returned an error"), e) + + keys = {} + added_keys = [] + + time_now_ms = self.clock.time_msec() + + for response in query_response["server_keys"]: + if ( + u"signatures" not in response + or perspective_name not in response[u"signatures"] + ): + raise KeyLookupError( + "Key response not signed by perspective server" + " %r" % (perspective_name,) + ) + + verified = False + for key_id in response[u"signatures"][perspective_name]: + if key_id in perspective_keys: + verify_signed_json( + response, perspective_name, perspective_keys[key_id] + ) + verified = True + + if not verified: + logging.info( + "Response from perspective server %r not signed with a" + " known key, signed with: %r, known keys: %r", + perspective_name, + list(response[u"signatures"][perspective_name]), + list(perspective_keys), + ) + raise KeyLookupError( + "Response not signed with a known key for perspective" + " server %r" % (perspective_name,) + ) + + processed_response = yield self.process_v2_response( + perspective_name, response, time_added_ms=time_now_ms + ) + server_name = response["server_name"] + + added_keys.extend( + (server_name, key_id, key) for key_id, key in processed_response.items() + ) + keys.setdefault(server_name, {}).update(processed_response) + + yield self.store.store_server_verify_keys( + perspective_name, time_now_ms, added_keys + ) + + defer.returnValue(keys) + + +class ServerKeyFetcher(BaseV2KeyFetcher): + """KeyFetcher impl which fetches keys from the origin servers""" + + def __init__(self, hs): + super(ServerKeyFetcher, self).__init__(hs) + self.clock = hs.get_clock() + self.client = hs.get_http_client() + + @defer.inlineCallbacks + def get_keys(self, server_name_and_key_ids): + """see KeyFetcher.get_keys""" + results = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + run_in_background( + self.get_server_verify_key_v2_direct, server_name, key_ids + ) + for server_name, key_ids in server_name_and_key_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + ) + + merged = {} + for result in results: + merged.update(result) + + defer.returnValue( + {server_name: keys for server_name, keys in merged.items() if keys} + ) + + @defer.inlineCallbacks + def get_server_verify_key_v2_direct(self, server_name, key_ids): + keys = {} # type: dict[str, FetchKeyResult] + + for requested_key_id in key_ids: + if requested_key_id in keys: + continue + + time_now_ms = self.clock.time_msec() + try: + response = yield self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + ) + except (NotRetryingDestination, RequestSendFailed) as e: + raise_from(KeyLookupError("Failed to connect to remote server"), e) + except HttpResponseException as e: + raise_from(KeyLookupError("Remote server returned an error"), e) + + if ( + u"signatures" not in response + or server_name not in response[u"signatures"] + ): + raise KeyLookupError("Key response not signed by remote server") + + if response["server_name"] != server_name: + raise KeyLookupError( + "Expected a response for server %r not %r" + % (server_name, response["server_name"]) + ) + + response_keys = yield self.process_v2_response( + from_server=server_name, + requested_ids=[requested_key_id], + response_json=response, + time_added_ms=time_now_ms, + ) + yield self.store.store_server_verify_keys( + server_name, + time_now_ms, + ((server_name, key_id, key) for key_id, key in response_keys.items()), + ) + keys.update(response_keys) + + defer.returnValue({server_name: keys}) + + @defer.inlineCallbacks def _handle_key_deferred(verify_request): """Waits for the key to become available, and then performs a verification diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 83de32b05..de61bad15 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -24,7 +24,11 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import ( + KeyLookupError, + PerspectivesKeyFetcher, + ServerKeyFetcher, +) from synapse.storage.keys import FetchKeyResult from synapse.util import logcontext from synapse.util.logcontext import LoggingContext @@ -218,12 +222,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.assertFalse(d.called) self.get_success(d) + +class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.http_client = Mock() + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + return hs + def test_get_keys_from_server(self): # arbitrarily advance the clock a bit self.reactor.advance(100) SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) + fetcher = ServerKeyFetcher(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" @@ -250,7 +261,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids)) + keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) @@ -278,15 +289,26 @@ class KeyringTestCase(unittest.HomeserverTestCase): # change the server name: it should cause a rejection response["server_name"] = "OTHER_SERVER" self.get_failure( - kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError + fetcher.get_keys(server_name_and_key_ids), KeyLookupError ) + +class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): + def make_homeserver(self, reactor, clock): + self.mock_perspective_server = MockPerspectiveServer() + self.http_client = Mock() + hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client) + keys = self.mock_perspective_server.get_verify_keys() + hs.config.perspectives = {self.mock_perspective_server.server_name: keys} + return hs + def test_get_keys_from_perspectives(self): # arbitrarily advance the clock a bit self.reactor.advance(100) + fetcher = PerspectivesKeyFetcher(self.hs) + SERVER_NAME = "server2" - kr = keyring.Keyring(self.hs) testkey = signedjson.key.generate_signing_key("ver1") testverifykey = signedjson.key.get_verify_key(testkey) testverifykey_id = "ed25519:ver1" @@ -320,7 +342,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): self.http_client.post_json.side_effect = post_json server_name_and_key_ids = [(SERVER_NAME, ("key1",))] - keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids)) + keys = self.get_success(fetcher.get_keys(server_name_and_key_ids)) self.assertIn(SERVER_NAME, keys) k = keys[SERVER_NAME][testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)