mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 15:53:51 +01:00
Wait for previous attempts at fetching keys for a given server before trying to fetch more
This commit is contained in:
parent
b5f55a1d85
commit
f0dd568e16
1 changed files with 68 additions and 15 deletions
|
@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes
|
||||||
from synapse.util.retryutils import get_retry_limiter
|
from synapse.util.retryutils import get_retry_limiter
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
|
|
||||||
|
from synapse.util.async import ObservableDeferred
|
||||||
|
|
||||||
from OpenSSL import crypto
|
from OpenSSL import crypto
|
||||||
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
@ -88,6 +90,8 @@ class Keyring(object):
|
||||||
"Not signed with a supported algorithm",
|
"Not signed with a supported algorithm",
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
))
|
))
|
||||||
|
else:
|
||||||
|
deferreds[group_id] = defer.Deferred()
|
||||||
|
|
||||||
group = KeyGroup(server_name, group_id, key_ids)
|
group = KeyGroup(server_name, group_id, key_ids)
|
||||||
|
|
||||||
|
@ -133,10 +137,41 @@ class Keyring(object):
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
deferreds.update(self.get_server_verify_keys(
|
server_to_deferred = {
|
||||||
group_id_to_group
|
server_name: defer.Deferred()
|
||||||
))
|
for server_name, _ in server_and_json
|
||||||
|
}
|
||||||
|
|
||||||
|
# We want to wait for any previous lookups to complete before
|
||||||
|
# proceeding.
|
||||||
|
wait_on_deferred = self.wait_for_previous_lookups(
|
||||||
|
[server_name for server_name, _ in server_and_json],
|
||||||
|
server_to_deferred,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Actually start fetching keys.
|
||||||
|
wait_on_deferred.addBoth(
|
||||||
|
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
|
||||||
|
)
|
||||||
|
|
||||||
|
# When we've finished fetching all the keys for a given server_name,
|
||||||
|
# resolve the deferred passed to `wait_for_previous_lookups` so that
|
||||||
|
# any lookups waiting will proceed.
|
||||||
|
server_to_gids = {}
|
||||||
|
|
||||||
|
def remove_deferreds(res, server_name, group_id):
|
||||||
|
server_to_gids[server_name].discard(group_id)
|
||||||
|
if not server_to_gids[server_name]:
|
||||||
|
server_to_deferred.pop(server_name).callback(None)
|
||||||
|
return res
|
||||||
|
|
||||||
|
for g_id, deferred in deferreds.items():
|
||||||
|
server_name = group_id_to_group[g_id].server_name
|
||||||
|
server_to_gids.setdefault(server_name, set()).add(g_id)
|
||||||
|
deferred.addBoth(remove_deferreds, server_name, g_id)
|
||||||
|
|
||||||
|
# Pass those keys to handle_key_deferred so that the json object
|
||||||
|
# signatures can be verified
|
||||||
return [
|
return [
|
||||||
handle_key_deferred(
|
handle_key_deferred(
|
||||||
group_id_to_group[g_id],
|
group_id_to_group[g_id],
|
||||||
|
@ -145,7 +180,30 @@ class Keyring(object):
|
||||||
for g_id in group_ids
|
for g_id in group_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_server_verify_keys(self, group_id_to_group):
|
@defer.inlineCallbacks
|
||||||
|
def wait_for_previous_lookups(self, server_names, server_to_deferred):
|
||||||
|
"""Waits for any previous key lookups for the given servers to finish.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_names (list): list of server_names we want to lookup
|
||||||
|
server_to_deferred (dict): server_name to deferred which gets
|
||||||
|
resolved once we've finished looking up keys for that server
|
||||||
|
"""
|
||||||
|
while True:
|
||||||
|
wait_on = [
|
||||||
|
self.key_downloads[server_name]
|
||||||
|
for server_name in server_names
|
||||||
|
if server_name in self.key_downloads
|
||||||
|
]
|
||||||
|
if wait_on:
|
||||||
|
yield defer.DeferredList(wait_on)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
for server_name, deferred in server_to_deferred:
|
||||||
|
self.key_downloads[server_name] = ObservableDeferred(deferred)
|
||||||
|
|
||||||
|
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
|
||||||
"""Takes a dict of KeyGroups and tries to find at least one key for
|
"""Takes a dict of KeyGroups and tries to find at least one key for
|
||||||
each group.
|
each group.
|
||||||
"""
|
"""
|
||||||
|
@ -157,11 +215,6 @@ class Keyring(object):
|
||||||
self.get_keys_from_server, # Then try directly
|
self.get_keys_from_server, # Then try directly
|
||||||
)
|
)
|
||||||
|
|
||||||
group_deferreds = {
|
|
||||||
group_id: defer.Deferred()
|
|
||||||
for group_id in group_id_to_group
|
|
||||||
}
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_iterations():
|
def do_iterations():
|
||||||
merged_results = {}
|
merged_results = {}
|
||||||
|
@ -182,7 +235,7 @@ class Keyring(object):
|
||||||
for group in group_id_to_group.values():
|
for group in group_id_to_group.values():
|
||||||
for key_id in group.key_ids:
|
for key_id in group.key_ids:
|
||||||
if key_id in merged_results[group.server_name]:
|
if key_id in merged_results[group.server_name]:
|
||||||
group_deferreds.pop(group.group_id).callback((
|
group_id_to_deferred[group.group_id].callback((
|
||||||
group.group_id,
|
group.group_id,
|
||||||
group.server_name,
|
group.server_name,
|
||||||
key_id,
|
key_id,
|
||||||
|
@ -205,7 +258,7 @@ class Keyring(object):
|
||||||
}
|
}
|
||||||
|
|
||||||
for group in missing_groups.values():
|
for group in missing_groups.values():
|
||||||
group_deferreds.pop(group.group_id).errback(SynapseError(
|
group_id_to_deferred[group.group_id].errback(SynapseError(
|
||||||
401,
|
401,
|
||||||
"No key for %s with id %s" % (
|
"No key for %s with id %s" % (
|
||||||
group.server_name, group.key_ids,
|
group.server_name, group.key_ids,
|
||||||
|
@ -214,13 +267,13 @@ class Keyring(object):
|
||||||
))
|
))
|
||||||
|
|
||||||
def on_err(err):
|
def on_err(err):
|
||||||
for deferred in group_deferreds.values():
|
for deferred in group_id_to_deferred.values():
|
||||||
deferred.errback(err)
|
if not deferred.called:
|
||||||
group_deferreds.clear()
|
deferred.errback(err)
|
||||||
|
|
||||||
do_iterations().addErrback(on_err)
|
do_iterations().addErrback(on_err)
|
||||||
|
|
||||||
return group_deferreds
|
return group_id_to_deferred
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_keys_from_store(self, server_name_and_key_ids):
|
def get_keys_from_store(self, server_name_and_key_ids):
|
||||||
|
|
Loading…
Reference in a new issue