Merge pull request #5001 from matrix-org/rav/keyring_cleanups

Cleanups in the Keyring
This commit is contained in:
Richard van der Hoff 2019-04-08 12:47:09 +01:00 committed by GitHub
commit 7fc1e17f4c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 181 additions and 33 deletions

1
changelog.d/5001.misc Normal file
View file

@ -0,0 +1 @@
Clean up some code in the server-key Keyring.

View file

@ -20,6 +20,7 @@ from collections import namedtuple
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
import nacl.signing
from signedjson.key import ( from signedjson.key import (
decode_verify_key_bytes, decode_verify_key_bytes,
encode_verify_key_base64, encode_verify_key_base64,
@ -494,11 +495,11 @@ class Keyring(object):
) )
processed_response = yield self.process_v2_response( processed_response = yield self.process_v2_response(
perspective_name, response, only_from_server=False perspective_name, response
) )
server_name = response["server_name"]
for server_name, response_keys in processed_response.items(): keys.setdefault(server_name, {}).update(processed_response)
keys.setdefault(server_name, {}).update(response_keys)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield logcontext.make_deferred_yieldable(defer.gatherResults(
[ [
@ -517,7 +518,7 @@ class Keyring(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids): def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} keys = {} # type: dict[str, nacl.signing.VerifyKey]
for requested_key_id in key_ids: for requested_key_id in key_ids:
if requested_key_id in keys: if requested_key_id in keys:
@ -542,6 +543,11 @@ class Keyring(object):
or server_name not in response[u"signatures"]): or server_name not in response[u"signatures"]):
raise KeyLookupError("Key response not signed by remote server") raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name:
raise KeyLookupError("Expected a response for server %r not %r" % (
server_name, response["server_name"]
))
response_keys = yield self.process_v2_response( response_keys = yield self.process_v2_response(
from_server=server_name, from_server=server_name,
requested_ids=[requested_key_id], requested_ids=[requested_key_id],
@ -550,24 +556,45 @@ class Keyring(object):
keys.update(response_keys) keys.update(response_keys)
yield logcontext.make_deferred_yieldable(defer.gatherResults( yield self.store_keys(
[ server_name=server_name,
run_in_background( from_server=server_name,
self.store_keys, verify_keys=keys,
server_name=key_server_name, )
from_server=server_name, defer.returnValue({server_name: keys})
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError))
defer.returnValue(keys)
@defer.inlineCallbacks @defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, def process_v2_response(
requested_ids=[], only_from_server=True): self, from_server, response_json, requested_ids=[],
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
"""
time_now_ms = self.clock.time_msec() time_now_ms = self.clock.time_msec()
response_keys = {} response_keys = {}
verify_keys = {} verify_keys = {}
@ -589,15 +616,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"] server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}): for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]: if key_id not in response_json["verify_keys"]:
raise KeyLookupError( raise KeyLookupError(
@ -643,9 +662,7 @@ class Keyring(object):
consumeErrors=True, consumeErrors=True,
).addErrback(unwrapFirstError)) ).addErrback(unwrapFirstError))
results[server_name] = response_keys defer.returnValue(response_keys)
defer.returnValue(results)
def store_keys(self, server_name, from_server, verify_keys): def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server """Store a collection of verify keys for a given server

View file

@ -188,8 +188,8 @@ class KeyStore(SQLBaseStore):
Args: Args:
server_keys (list): List of (server_name, key_id, source) triplets. server_keys (list): List of (server_name, key_id, source) triplets.
Returns: Returns:
Dict mapping (server_name, key_id, source) triplets to dicts with Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
"ts_valid_until_ms" and "key_json" keys. Dict mapping (server_name, key_id, source) triplets to lists of dicts
""" """
def _get_server_keys_json_txn(txn): def _get_server_keys_json_txn(txn):

View file

@ -16,6 +16,7 @@ import time
from mock import Mock from mock import Mock
import canonicaljson
import signedjson.key import signedjson.key
import signedjson.sign import signedjson.sign
@ -23,6 +24,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.crypto import keyring from synapse.crypto import keyring
from synapse.crypto.keyring import KeyLookupError
from synapse.util import logcontext from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext from synapse.util.logcontext import LoggingContext
@ -48,6 +50,9 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)} key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
}, },
} }
return self.get_signed_response(res)
def get_signed_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key) signedjson.sign.sign_json(res, self.server_name, self.key)
return res return res
@ -202,6 +207,131 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called) self.assertFalse(d.called)
self.get_success(d) self.get_success(d)
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 1000
# valid response
response = {
"server_name": SERVER_NAME,
"old_verify_keys": {},
"valid_until_ts": VALID_UNTIL_TS,
"verify_keys": {
testverifykey_id: {
"key": signedjson.key.encode_verify_key_base64(testverifykey)
}
},
}
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
def get_json(destination, path, **kwargs):
self.assertEqual(destination, SERVER_NAME)
self.assertEqual(path, "/_matrix/key/v2/server/key1")
return response
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastore().get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
)
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 200 * 1000
# valid response
response = {
"server_name": SERVER_NAME,
"old_verify_keys": {},
"valid_until_ts": VALID_UNTIL_TS,
"verify_keys": {
testverifykey_id: {
"key": signedjson.key.encode_verify_key_base64(testverifykey)
}
},
}
persp_resp = {
"server_keys": [self.mock_perspective_server.get_signed_response(response)]
}
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
return persp_resp
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastore().get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
self.assertEqual(
bytes(res["key_json"]),
canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
)
@defer.inlineCallbacks @defer.inlineCallbacks
def run_in_context(f, *args, **kwargs): def run_in_context(f, *args, **kwargs):