forked from MirrorHub/synapse
Rewrite get_server_verify_keys, again.
Attempt to simplify the logic in get_server_verify_keys by splitting it into two methods.
This commit is contained in:
parent
99113e40ba
commit
a82c96b87f
2 changed files with 54 additions and 48 deletions
1
changelog.d/5299.misc
Normal file
1
changelog.d/5299.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Rewrite get_server_verify_keys, again.
|
|
@ -270,59 +270,21 @@ class Keyring(object):
|
||||||
verify_requests (list[VerifyKeyRequest]): list of verify requests
|
verify_requests (list[VerifyKeyRequest]): list of verify requests
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
remaining_requests = set(
|
||||||
|
(rq for rq in verify_requests if not rq.deferred.called)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_iterations():
|
def do_iterations():
|
||||||
with Measure(self.clock, "get_server_verify_keys"):
|
with Measure(self.clock, "get_server_verify_keys"):
|
||||||
# dict[str, set(str)]: keys to fetch for each server
|
|
||||||
missing_keys = {}
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
missing_keys.setdefault(verify_request.server_name, set()).update(
|
|
||||||
verify_request.key_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
for f in self._key_fetchers:
|
for f in self._key_fetchers:
|
||||||
results = yield f.get_keys(missing_keys.items())
|
if not remaining_requests:
|
||||||
|
return
|
||||||
# We now need to figure out which verify requests we have keys
|
yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
|
||||||
# for and which we don't
|
|
||||||
missing_keys = {}
|
|
||||||
requests_missing_keys = []
|
|
||||||
for verify_request in verify_requests:
|
|
||||||
if verify_request.deferred.called:
|
|
||||||
# We've already called this deferred, which probably
|
|
||||||
# means that we've already found a key for it.
|
|
||||||
continue
|
|
||||||
|
|
||||||
server_name = verify_request.server_name
|
|
||||||
|
|
||||||
# see if any of the keys we got this time are sufficient to
|
|
||||||
# complete this VerifyKeyRequest.
|
|
||||||
result_keys = results.get(server_name, {})
|
|
||||||
for key_id in verify_request.key_ids:
|
|
||||||
fetch_key_result = result_keys.get(key_id)
|
|
||||||
if fetch_key_result:
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
verify_request.deferred.callback(
|
|
||||||
(
|
|
||||||
server_name,
|
|
||||||
key_id,
|
|
||||||
fetch_key_result.verify_key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
# The else block is only reached if the loop above
|
|
||||||
# doesn't break.
|
|
||||||
missing_keys.setdefault(server_name, set()).update(
|
|
||||||
verify_request.key_ids
|
|
||||||
)
|
|
||||||
requests_missing_keys.append(verify_request)
|
|
||||||
|
|
||||||
if not missing_keys:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
# look for any requests which weren't satisfied
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
for verify_request in requests_missing_keys:
|
for verify_request in remaining_requests:
|
||||||
verify_request.deferred.errback(
|
verify_request.deferred.errback(
|
||||||
SynapseError(
|
SynapseError(
|
||||||
401,
|
401,
|
||||||
|
@ -333,13 +295,56 @@ class Keyring(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
|
# we don't really expect to get here, because any errors should already
|
||||||
|
# have been caught and logged. But if we do, let's log the error and make
|
||||||
|
# sure that all of the deferreds are resolved.
|
||||||
|
logger.error("Unexpected error in _get_server_verify_keys: %s", err)
|
||||||
with PreserveLoggingContext():
|
with PreserveLoggingContext():
|
||||||
for verify_request in verify_requests:
|
for verify_request in remaining_requests:
|
||||||
if not verify_request.deferred.called:
|
if not verify_request.deferred.called:
|
||||||
verify_request.deferred.errback(err)
|
verify_request.deferred.errback(err)
|
||||||
|
|
||||||
run_in_background(do_iterations).addErrback(on_err)
|
run_in_background(do_iterations).addErrback(on_err)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
|
||||||
|
"""Use a key fetcher to attempt to satisfy some key requests
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fetcher (KeyFetcher): fetcher to use to fetch the keys
|
||||||
|
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
|
||||||
|
Any successfully-completed requests will be reomved from the list.
|
||||||
|
"""
|
||||||
|
# dict[str, set(str)]: keys to fetch for each server
|
||||||
|
missing_keys = {}
|
||||||
|
for verify_request in remaining_requests:
|
||||||
|
# any completed requests should already have been removed
|
||||||
|
assert not verify_request.deferred.called
|
||||||
|
missing_keys.setdefault(verify_request.server_name, set()).update(
|
||||||
|
verify_request.key_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
results = yield fetcher.get_keys(missing_keys.items())
|
||||||
|
|
||||||
|
completed = list()
|
||||||
|
for verify_request in remaining_requests:
|
||||||
|
server_name = verify_request.server_name
|
||||||
|
|
||||||
|
# see if any of the keys we got this time are sufficient to
|
||||||
|
# complete this VerifyKeyRequest.
|
||||||
|
result_keys = results.get(server_name, {})
|
||||||
|
for key_id in verify_request.key_ids:
|
||||||
|
key = result_keys.get(key_id)
|
||||||
|
if key:
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
verify_request.deferred.callback(
|
||||||
|
(server_name, key_id, key.verify_key)
|
||||||
|
)
|
||||||
|
completed.append(verify_request)
|
||||||
|
break
|
||||||
|
|
||||||
|
remaining_requests.difference_update(completed)
|
||||||
|
|
||||||
|
|
||||||
class KeyFetcher(object):
|
class KeyFetcher(object):
|
||||||
def get_keys(self, server_name_and_key_ids):
|
def get_keys(self, server_name_and_key_ids):
|
||||||
|
|
Loading…
Reference in a new issue