mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 15:53:51 +01:00
Wait for previous attempts at fetching keys for a given server before trying to fetch more
This commit is contained in:
parent
b5f55a1d85
commit
f0dd568e16
1 changed files with 68 additions and 15 deletions
|
@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes
|
|||
from synapse.util.retryutils import get_retry_limiter
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
from collections import namedtuple
|
||||
|
@ -88,6 +90,8 @@ class Keyring(object):
|
|||
"Not signed with a supported algorithm",
|
||||
Codes.UNAUTHORIZED,
|
||||
))
|
||||
else:
|
||||
deferreds[group_id] = defer.Deferred()
|
||||
|
||||
group = KeyGroup(server_name, group_id, key_ids)
|
||||
|
||||
|
@ -133,10 +137,41 @@ class Keyring(object):
|
|||
Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
deferreds.update(self.get_server_verify_keys(
|
||||
group_id_to_group
|
||||
))
|
||||
server_to_deferred = {
|
||||
server_name: defer.Deferred()
|
||||
for server_name, _ in server_and_json
|
||||
}
|
||||
|
||||
# We want to wait for any previous lookups to complete before
|
||||
# proceeding.
|
||||
wait_on_deferred = self.wait_for_previous_lookups(
|
||||
[server_name for server_name, _ in server_and_json],
|
||||
server_to_deferred,
|
||||
)
|
||||
|
||||
# Actually start fetching keys.
|
||||
wait_on_deferred.addBoth(
|
||||
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||
)
|
||||
|
||||
# When we've finished fetching all the keys for a given server_name,
|
||||
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||
# any lookups waiting will proceed.
|
||||
server_to_gids = {}
|
||||
|
||||
def remove_deferreds(res, server_name, group_id):
|
||||
server_to_gids[server_name].discard(group_id)
|
||||
if not server_to_gids[server_name]:
|
||||
server_to_deferred.pop(server_name).callback(None)
|
||||
return res
|
||||
|
||||
for g_id, deferred in deferreds.items():
|
||||
server_name = group_id_to_group[g_id].server_name
|
||||
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||
|
||||
# Pass those keys to handle_key_deferred so that the json object
|
||||
# signatures can be verified
|
||||
return [
|
||||
handle_key_deferred(
|
||||
group_id_to_group[g_id],
|
||||
|
@ -145,7 +180,30 @@ class Keyring(object):
|
|||
for g_id in group_ids
|
||||
]
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group):
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||
"""Waits for any previous key lookups for the given servers to finish.
|
||||
|
||||
Args:
|
||||
server_names (list): list of server_names we want to lookup
|
||||
server_to_deferred (dict): server_name to deferred which gets
|
||||
resolved once we've finished looking up keys for that server
|
||||
"""
|
||||
while True:
|
||||
wait_on = [
|
||||
self.key_downloads[server_name]
|
||||
for server_name in server_names
|
||||
if server_name in self.key_downloads
|
||||
]
|
||||
if wait_on:
|
||||
yield defer.DeferredList(wait_on)
|
||||
else:
|
||||
break
|
||||
|
||||
for server_name, deferred in server_to_deferred:
|
||||
self.key_downloads[server_name] = ObservableDeferred(deferred)
|
||||
|
||||
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||
each group.
|
||||
"""
|
||||
|
@ -157,11 +215,6 @@ class Keyring(object):
|
|||
self.get_keys_from_server, # Then try directly
|
||||
)
|
||||
|
||||
group_deferreds = {
|
||||
group_id: defer.Deferred()
|
||||
for group_id in group_id_to_group
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def do_iterations():
|
||||
merged_results = {}
|
||||
|
@ -182,7 +235,7 @@ class Keyring(object):
|
|||
for group in group_id_to_group.values():
|
||||
for key_id in group.key_ids:
|
||||
if key_id in merged_results[group.server_name]:
|
||||
group_deferreds.pop(group.group_id).callback((
|
||||
group_id_to_deferred[group.group_id].callback((
|
||||
group.group_id,
|
||||
group.server_name,
|
||||
key_id,
|
||||
|
@ -205,7 +258,7 @@ class Keyring(object):
|
|||
}
|
||||
|
||||
for group in missing_groups.values():
|
||||
group_deferreds.pop(group.group_id).errback(SynapseError(
|
||||
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||
401,
|
||||
"No key for %s with id %s" % (
|
||||
group.server_name, group.key_ids,
|
||||
|
@ -214,13 +267,13 @@ class Keyring(object):
|
|||
))
|
||||
|
||||
def on_err(err):
|
||||
for deferred in group_deferreds.values():
|
||||
deferred.errback(err)
|
||||
group_deferreds.clear()
|
||||
for deferred in group_id_to_deferred.values():
|
||||
if not deferred.called:
|
||||
deferred.errback(err)
|
||||
|
||||
do_iterations().addErrback(on_err)
|
||||
|
||||
return group_deferreds
|
||||
return group_id_to_deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_keys_from_store(self, server_name_and_key_ids):
|
||||
|
|
Loading…
Reference in a new issue