0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-17 07:21:37 +01:00

Fix stack overflow in Keyring (#5724)

* Refactor Keyring._start_key_lookups

There's an awful lot of deferreds and dictionaries flying around here. The
whole thing can be made much simpler and achieve the same effect.

* Add a delay to key lookup lock release to fix stack overflow

A tactical call_later here should fix #5723

* changelog
This commit is contained in:
Richard van der Hoff 2019-07-22 13:51:22 +01:00 committed by GitHub
commit 0cb72812f9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 79 deletions

1
changelog.d/5724.bugfix Normal file
View file

@ -0,0 +1 @@
Fix stack overflow in server key lookup code.

View file

@ -238,27 +238,9 @@ class Keyring(object):
""" """
try: try:
# create a deferred for each server we're going to look up the keys ctx = LoggingContext.current_context()
# for; we'll resolve them once we have completed our lookups.
# These will be passed into wait_for_previous_lookups to block
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before # map from server name to a set of outstanding request ids
# proceeding.
yield self.wait_for_previous_lookups(server_to_deferred)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
#
# map from server name to a set of request ids
server_to_request_ids = {} server_to_request_ids = {}
for verify_request in verify_requests: for verify_request in verify_requests:
@ -266,40 +248,61 @@ class Keyring(object):
request_id = id(verify_request) request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id) server_to_request_ids.setdefault(server_name, set()).add(request_id)
def remove_deferreds(res, verify_request): # Wait for any previous lookups to complete before proceeding.
yield self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
for server_name in server_to_request_ids.keys():
self.key_downloads[server_name] = defer.Deferred()
logger.debug("Got key lookup lock on %s", server_name)
# When we've finished fetching all the keys for a given server_name,
# drop the lock by resolving the deferred in key_downloads.
def drop_server_lock(server_name):
d = self.key_downloads.pop(server_name)
d.callback(None)
def lookup_done(res, verify_request):
server_name = verify_request.server_name server_name = verify_request.server_name
request_id = id(verify_request) server_requests = server_to_request_ids[server_name]
server_to_request_ids[server_name].discard(request_id) server_requests.remove(id(verify_request))
if not server_to_request_ids[server_name]:
d = server_to_deferred.pop(server_name, None) # if there are no more requests for this server, we can drop the lock.
if d: if not server_requests:
d.callback(None) with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name)
# ... but not immediately, as that can cause stack explosions if
# we get a long queue of lookups.
self.clock.call_later(0, drop_server_lock, server_name)
return res return res
for verify_request in verify_requests: for verify_request in verify_requests:
verify_request.key_ready.addBoth(remove_deferreds, verify_request) verify_request.key_ready.addBoth(lookup_done, verify_request)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
except Exception: except Exception:
logger.exception("Error starting key lookups") logger.exception("Error starting key lookups")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_previous_lookups(self, server_to_deferred): def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish. """Waits for any previous key lookups for the given servers to finish.
Args: Args:
server_to_deferred (dict[str, Deferred]): server_name to deferred which gets server_names (Iterable[str]): list of servers which we want to look up
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
Returns: a Deferred which resolves once all key lookups for the given Returns:
servers have completed. Follows the synapse rules of logcontext Deferred[None]: resolves once all key lookups for the given servers have
preservation. completed. Follows the synapse rules of logcontext preservation.
""" """
loop_count = 1 loop_count = 1
while True: while True:
wait_on = [ wait_on = [
(server_name, self.key_downloads[server_name]) (server_name, self.key_downloads[server_name])
for server_name in server_to_deferred.keys() for server_name in server_names
if server_name in self.key_downloads if server_name in self.key_downloads
] ]
if not wait_on: if not wait_on:
@ -314,19 +317,6 @@ class Keyring(object):
loop_count += 1 loop_count += 1
ctx = LoggingContext.current_context()
def rm(r, server_name_):
with PreserveLoggingContext(ctx):
logger.debug("Releasing key lookup lock on %s", server_name_)
self.key_downloads.pop(server_name_, None)
return r
for server_name, deferred in server_to_deferred.items():
logger.debug("Got key lookup lock on %s", server_name)
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
def _get_server_verify_keys(self, verify_requests): def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request

View file

@ -86,35 +86,6 @@ class KeyringTestCase(unittest.HomeserverTestCase):
getattr(LoggingContext.current_context(), "request", None), expected getattr(LoggingContext.current_context(), "request", None), expected
) )
def test_wait_for_previous_lookups(self):
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
self.successResultOf(wait_1_deferred)
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
# now the second wait should complete.
self.successResultOf(wait_2_deferred)
def test_verify_json_objects_for_server_awaits_previous_requests(self): def test_verify_json_objects_for_server_awaits_previous_requests(self):
key1 = signedjson.key.generate_signing_key(1) key1 = signedjson.key.generate_signing_key(1)