Clean up and document handling of logcontexts in Keyring (#2452)

I'm still unclear on what the intended behaviour for
`verify_json_objects_for_server` is, but at least I now understand the
behaviour of most of the things it calls...
This commit is contained in:
Richard van der Hoff 2017-09-18 18:31:01 +01:00 committed by GitHub
parent 77c81ca6ea
commit 290777b3d9
2 changed files with 110 additions and 28 deletions

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,10 +16,9 @@
from synapse.crypto.keyclient import fetch_server_key from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError, logcontext
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import ( from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn preserve_fn
) )
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -74,6 +74,11 @@ class Keyring(object):
self.perspective_servers = self.config.perspectives self.perspective_servers = self.config.perspectives
self.hs = hs self.hs = hs
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {} self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object): def verify_json_for_server(self, server_name, json_object):
@ -82,7 +87,7 @@ class Keyring(object):
)[0] )[0]
def verify_json_objects_for_server(self, server_and_json): def verify_json_objects_for_server(self, server_and_json):
"""Bulk verfies signatures of json objects, bulk fetching keys as """Bulk verifies signatures of json objects, bulk fetching keys as
necessary. necessary.
Args: Args:
@ -212,7 +217,13 @@ class Keyring(object):
Args: Args:
server_names (list): list of server_names we want to lookup server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
Returns: a Deferred which resolves once all key lookups for the given
servers have completed. Follows the synapse rules of logcontext
preservation.
""" """
while True: while True:
wait_on = [ wait_on = [
@ -226,15 +237,13 @@ class Keyring(object):
else: else:
break break
def rm(r, server_name_):
self.key_downloads.pop(server_name_, None)
return r
for server_name, deferred in server_to_deferred.items(): for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(preserve_context_over_deferred(deferred)) self.key_downloads[server_name] = deferred
self.key_downloads[server_name] = d deferred.addBoth(rm, server_name)
def rm(r, server_name):
self.key_downloads.pop(server_name, None)
return r
d.addBoth(rm, server_name)
def get_server_verify_keys(self, verify_requests): def get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request """Tries to find at least one key for each verify request
@ -333,7 +342,7 @@ class Keyring(object):
Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
server_name -> key_id -> VerifyKey server_name -> key_id -> VerifyKey
""" """
res = yield preserve_context_over_deferred(defer.gatherResults( res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.store.get_server_verify_keys)( preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids server_name, key_ids
@ -341,7 +350,7 @@ class Keyring(object):
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
defer.returnValue(dict(res)) defer.returnValue(dict(res))
@ -362,13 +371,13 @@ class Keyring(object):
) )
defer.returnValue({}) defer.returnValue({})
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(get_key)(p_name, p_keys) preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items() for p_name, p_keys in self.perspective_servers.items()
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
union_of_keys = {} union_of_keys = {}
for result in results: for result in results:
@ -402,13 +411,13 @@ class Keyring(object):
defer.returnValue(keys) defer.returnValue(keys)
results = yield preserve_context_over_deferred(defer.gatherResults( results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(get_key)(server_name, key_ids) preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids for server_name, key_ids in server_name_and_key_ids
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
merged = {} merged = {}
for result in results: for result in results:
@ -485,7 +494,7 @@ class Keyring(object):
for server_name, response_keys in processed_response.items(): for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys) keys.setdefault(server_name, {}).update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults( yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.store_keys)( preserve_fn(self.store_keys)(
server_name=server_name, server_name=server_name,
@ -495,7 +504,7 @@ class Keyring(object):
for server_name, response_keys in keys.items() for server_name, response_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
defer.returnValue(keys) defer.returnValue(keys)
@ -543,7 +552,7 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield preserve_context_over_deferred(defer.gatherResults( yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.store_keys)( preserve_fn(self.store_keys)(
server_name=key_server_name, server_name=key_server_name,
@ -553,7 +562,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items() for key_server_name, verify_keys in keys.items()
], ],
consumeErrors=True consumeErrors=True
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
defer.returnValue(keys) defer.returnValue(keys)
@ -619,7 +628,7 @@ class Keyring(object):
response_keys.update(verify_keys) response_keys.update(verify_keys)
response_keys.update(old_verify_keys) response_keys.update(old_verify_keys)
yield preserve_context_over_deferred(defer.gatherResults( yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_keys_json)( preserve_fn(self.store.store_server_keys_json)(
server_name=server_name, server_name=server_name,
@ -632,7 +641,7 @@ class Keyring(object):
for key_id in updated_key_ids for key_id in updated_key_ids
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))
results[server_name] = response_keys results[server_name] = response_keys
@ -710,7 +719,6 @@ class Keyring(object):
defer.returnValue(verify_keys) defer.returnValue(verify_keys)
@defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys): def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server """Store a collection of verify keys for a given server
Args: Args:
@ -721,7 +729,7 @@ class Keyring(object):
A deferred that completes when the keys are stored. A deferred that completes when the keys are stored.
""" """
# TODO(markjh): Store whether the keys have expired. # TODO(markjh): Store whether the keys have expired.
yield preserve_context_over_deferred(defer.gatherResults( return logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
preserve_fn(self.store.store_server_verify_key)( preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key server_name, server_name, key.time_added, key
@ -729,4 +737,4 @@ class Keyring(object):
for key_id, key in verify_keys.items() for key_id, key in verify_keys.items()
], ],
consumeErrors=True, consumeErrors=True,
)).addErrback(unwrapFirstError) ).addErrback(unwrapFirstError))

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.crypto import keyring
from synapse.util.logcontext import LoggingContext
from tests import utils, unittest
from twisted.internet import defer
class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(handlers=None)
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
sentinel_context = LoggingContext.current_context()
kr = keyring.Keyring(self.hs)
def check_context(_, expected):
self.assertEquals(
LoggingContext.current_context().test_key, expected
)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
context_one.test_key = "one"
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"],
{"server1": lookup_1_deferred},
)
# there were no previous lookups, so the deferred should be ready
self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one)
wait_1_deferred.addBoth(check_context, "one")
with LoggingContext("two") as context_two:
context_two.test_key = "two"
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = kr.wait_for_previous_lookups(
["server1"],
{"server1": lookup_2_deferred},
)
self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext.
self.assertIs(LoggingContext.current_context(), sentinel_context)
wait_2_deferred.addBoth(check_context, "two")
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
# now the second wait should complete and restore our
# loggingcontext.
yield wait_2_deferred