0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-12 13:01:34 +01:00

Convert the crypto module to async/await. (#8003)

This commit is contained in:
Patrick Cloke 2020-08-03 08:29:01 -04:00 committed by GitHub
parent b6c6fb7950
commit 2a89ce8cd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 133 deletions

1
changelog.d/8003.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -223,8 +223,7 @@ class Keyring(object):
return results return results
@defer.inlineCallbacks async def _start_key_lookups(self, verify_requests):
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request """Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved. Once each fetch completes, verify_request.key_ready will be resolved.
@ -245,7 +244,7 @@ class Keyring(object):
server_to_request_ids.setdefault(server_name, set()).add(request_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding. # Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys()) await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in # take out a lock on each of the servers by sticking a Deferred in
# key_downloads # key_downloads
@ -283,15 +282,14 @@ class Keyring(object):
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@defer.inlineCallbacks async def wait_for_previous_lookups(self, server_names) -> None:
def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_names (Iterable[str]): list of servers which we want to look up server_names (Iterable[str]): list of servers which we want to look up
Returns: Returns:
Deferred[None]: resolves once all key lookups for the given servers have Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation. completed. Follows the synapse rules of logcontext preservation.
""" """
loop_count = 1 loop_count = 1
@ -309,7 +307,7 @@ class Keyring(object):
loop_count, loop_count,
) )
with PreserveLoggingContext(): with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on)) await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1 loop_count += 1
@ -326,13 +324,15 @@ class Keyring(object):
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@defer.inlineCallbacks async def do_iterations():
def do_iterations(): try:
with Measure(self.clock, "get_server_verify_keys"): with Measure(self.clock, "get_server_verify_keys"):
for f in self._key_fetchers: for f in self._key_fetchers:
if not remaining_requests: if not remaining_requests:
return return
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) await self._attempt_key_fetches_with_fetcher(
f, remaining_requests
)
# look for any requests which weren't satisfied # look for any requests which weren't satisfied
with PreserveLoggingContext(): with PreserveLoggingContext():
@ -349,8 +349,7 @@ class Keyring(object):
Codes.UNAUTHORIZED, Codes.UNAUTHORIZED,
) )
) )
except Exception as err:
def on_err(err):
# we don't really expect to get here, because any errors should already # we don't really expect to get here, because any errors should already
# have been caught and logged. But if we do, let's log the error and make # have been caught and logged. But if we do, let's log the error and make
# sure that all of the deferreds are resolved. # sure that all of the deferreds are resolved.
@ -360,10 +359,9 @@ class Keyring(object):
if not verify_request.key_ready.called: if not verify_request.key_ready.called:
verify_request.key_ready.errback(err) verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err) run_in_background(do_iterations)
@defer.inlineCallbacks async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests """Use a key fetcher to attempt to satisfy some key requests
Args: Args:
@ -390,7 +388,7 @@ class Keyring(object):
verify_request.minimum_valid_until_ts, verify_request.minimum_valid_until_ts,
) )
results = yield fetcher.get_keys(missing_keys) results = await fetcher.get_keys(missing_keys)
completed = [] completed = []
for verify_request in remaining_requests: for verify_request in remaining_requests:
@ -423,7 +421,7 @@ class Keyring(object):
class KeyFetcher(object): class KeyFetcher(object):
def get_keys(self, keys_to_fetch): async def get_keys(self, keys_to_fetch):
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch (dict[str, dict[str, int]]):
@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks async def get_keys(self, keys_to_fetch):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
keys_to_fetch = ( keys_to_fetch = (
@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
for key_id in keys_for_server.keys() for key_id in keys_for_server.keys()
) )
res = yield self.store.get_server_verify_keys(keys_to_fetch) res = await self.store.get_server_verify_keys(keys_to_fetch)
keys = {} keys = {}
for (server_name, key_id), key in res.items(): for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key keys.setdefault(server_name, {})[key_id] = key
@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.config = hs.get_config() self.config = hs.get_config()
@defer.inlineCallbacks async def process_v2_response(self, from_server, response_json, time_added_ms):
def process_v2_response(self, from_server, response_json, time_added_ms):
"""Parse a 'Server Keys' structure from the result of a /key request """Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from This is used to parse either the entirety of the response from
@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
key_json_bytes = encode_canonical_json(response_json) key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[ [
run_in_background( run_in_background(
@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client() self.client = hs.get_http_client()
self.key_servers = self.config.key_servers self.key_servers = self.config.key_servers
@defer.inlineCallbacks async def get_keys(self, keys_to_fetch):
def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys""" """see KeyFetcher.get_keys"""
@defer.inlineCallbacks async def get_key(key_server):
def get_key(key_server):
try: try:
result = yield self.get_server_verify_key_v2_indirect( result = await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server keys_to_fetch, key_server
) )
return result return result
@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return {} return {}
results = yield make_deferred_yieldable( results = await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers], [run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True, consumeErrors=True,
@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys return union_of_keys
@defer.inlineCallbacks async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
""" """
Args: Args:
keys_to_fetch (dict[str, dict[str, int]]): keys_to_fetch (dict[str, dict[str, int]]):
@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
the keys the keys
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
from server_name -> key_id -> FetchKeyResult from server_name -> key_id -> FetchKeyResult
Raises: Raises:
@ -632,8 +625,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
try: try:
query_response = yield defer.ensureDeferred( query_response = await self.client.post_json(
self.client.post_json(
destination=perspective_name, destination=perspective_name,
path="/_matrix/key/v2/query", path="/_matrix/key/v2/query",
data={ data={
@ -646,7 +638,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
} }
}, },
) )
)
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve upon # these both have str() representations which we can't really improve upon
raise KeyLookupError(str(e)) raise KeyLookupError(str(e))
@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
try: try:
self._validate_perspectives_response(key_server, response) self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response( processed_response = await self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms perspective_name, response, time_added_ms=time_now_ms
) )
except KeyLookupError as e: except KeyLookupError as e:
@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
) )
keys.setdefault(server_name, {}).update(processed_response) keys.setdefault(server_name, {}).update(processed_response)
yield self.store.store_server_verify_keys( await self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys perspective_name, time_now_ms, added_keys
) )
@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.client = hs.get_http_client() self.client = hs.get_http_client()
def get_keys(self, keys_to_fetch): async def get_keys(self, keys_to_fetch):
""" """
Args: Args:
keys_to_fetch (dict[str, iterable[str]]): keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids the keys to be fetched. server_name -> key_ids
Returns: Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
map from server_name -> key_id -> FetchKeyResult map from server_name -> key_id -> FetchKeyResult
""" """
results = {} results = {}
@defer.inlineCallbacks async def get_key(key_to_fetch_item):
def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item server_name, key_ids = key_to_fetch_item
try: try:
keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys results[server_name] = keys
except KeyLookupError as e: except KeyLookupError as e:
logger.warning( logger.warning(
@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception: except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name) logger.exception("Error getting keys %s from %s", key_ids, server_name)
return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( return await yieldable_gather_results(
lambda _: results get_key, keys_to_fetch.items()
) ).addCallback(lambda _: results)
@defer.inlineCallbacks async def get_server_verify_key_v2_direct(self, server_name, key_ids):
def get_server_verify_key_v2_direct(self, server_name, key_ids):
""" """
Args: Args:
@ -794,8 +783,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
try: try:
response = yield defer.ensureDeferred( response = await self.client.get_json(
self.client.get_json(
destination=server_name, destination=server_name,
path="/_matrix/key/v2/server/" path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id), + urllib.parse.quote(requested_key_id),
@ -813,7 +801,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
# read the response). # read the response).
timeout=10000, timeout=10000,
) )
)
except (NotRetryingDestination, RequestSendFailed) as e: except (NotRetryingDestination, RequestSendFailed) as e:
# these both have str() representations which we can't really improve # these both have str() representations which we can't really improve
# upon # upon
@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"]) % (server_name, response["server_name"])
) )
response_keys = yield self.process_v2_response( response_keys = await self.process_v2_response(
from_server=server_name, from_server=server_name,
response_json=response, response_json=response,
time_added_ms=time_now_ms, time_added_ms=time_now_ms,
) )
yield self.store.store_server_verify_keys( await self.store.store_server_verify_keys(
server_name, server_name,
time_now_ms, time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()), ((server_name, key_id, key) for key_id, key in response_keys.items()),
@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys return keys
@defer.inlineCallbacks async def _handle_key_deferred(verify_request) -> None:
def _handle_key_deferred(verify_request):
"""Waits for the key to become available, and then performs a verification """Waits for the key to become available, and then performs a verification
Args: Args:
verify_request (VerifyJsonRequest): verify_request (VerifyJsonRequest):
Returns:
Deferred[None]
Raises: Raises:
SynapseError if there was a problem performing the verification SynapseError if there was a problem performing the verification
""" """
server_name = verify_request.server_name server_name = verify_request.server_name
with PreserveLoggingContext(): with PreserveLoggingContext():
_, key_id, verify_key = yield verify_request.key_ready _, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object json_object = verify_request.json_object

View file

@ -40,6 +40,7 @@ from synapse.logging.context import (
from synapse.storage.keys import FetchKeyResult from synapse.storage.keys import FetchKeyResult
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable
class MockPerspectiveServer(object): class MockPerspectiveServer(object):
@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
with a null `ts_valid_until_ms` with a null `ts_valid_until_ms`
""" """
mock_fetcher = keyring.KeyFetcher() mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
kr = keyring.Keyring( kr = keyring.Keyring(
self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""Two requests for the same key should be deduped.""" """Two requests for the same key should be deduped."""
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
def get_keys(keys_to_fetch): async def get_keys(keys_to_fetch):
# there should only be one request object (with the max validity) # there should only be one request object (with the max validity)
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed( return {
{
"server1": { "server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
} }
} }
)
mock_fetcher = keyring.KeyFetcher() mock_fetcher = keyring.KeyFetcher()
mock_fetcher.get_keys = Mock(side_effect=get_keys) mock_fetcher.get_keys = Mock(side_effect=get_keys)
@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
"""If the first fetcher cannot provide a recent enough key, we fall back""" """If the first fetcher cannot provide a recent enough key, we fall back"""
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)
def get_keys1(keys_to_fetch): async def get_keys1(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed( return {
{ "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
"server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
} }
}
)
def get_keys2(keys_to_fetch): async def get_keys2(keys_to_fetch):
self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
return defer.succeed( return {
{
"server1": { "server1": {
get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
} }
} }
)
mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1 = keyring.KeyFetcher()
mock_fetcher1.get_keys = Mock(side_effect=get_keys1) mock_fetcher1.get_keys = Mock(side_effect=get_keys1)