Merge pull request #2459 from matrix-org/rav/keyring_cleanups

Clean up Keyring code

Commit: c94ab5976a
5 changed files with 355 additions and 216 deletions
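The changes below share one theme: deferreds returned by the Keyring and the federation helpers now run their callbacks in the sentinel logcontext, and callers wrap them with logcontext.make_deferred_yieldable before yielding. A minimal usage sketch of that convention, based only on the API as it appears in this diff (error handling omitted):

from twisted.internet import defer

from synapse.util import logcontext


@defer.inlineCallbacks
def check_signatures(keyring, server_and_json):
    # verify_json_objects_for_server returns one deferred per
    # (server_name, json_object) pair; each runs its callbacks with no
    # logcontext, so the caller wraps them before yielding.
    deferreds = keyring.verify_json_objects_for_server(server_and_json)
    for d in deferreds:
        yield logcontext.make_deferred_yieldable(d)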
synapse/crypto/keyring.py

@@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
 from synapse.api.errors import SynapseError, Codes
 from synapse.util import unwrapFirstError, logcontext
 from synapse.util.logcontext import (
-    preserve_context_over_fn, PreserveLoggingContext,
+    PreserveLoggingContext,
     preserve_fn
 )
 from synapse.util.metrics import Measure

@@ -57,7 +57,8 @@ Attributes:
     json_object(dict): The JSON object to verify.
     deferred(twisted.internet.defer.Deferred):
         A deferred (server_name, key_id, verify_key) tuple that resolves when
-        a verify key has been fetched
+        a verify key has been fetched. The deferreds' callbacks are run with no
+        logcontext.
 """


@@ -82,9 +83,11 @@ class Keyring(object):
         self.key_downloads = {}

     def verify_json_for_server(self, server_name, json_object):
-        return self.verify_json_objects_for_server(
-            [(server_name, json_object)]
-        )[0]
+        return logcontext.make_deferred_yieldable(
+            self.verify_json_objects_for_server(
+                [(server_name, json_object)]
+            )[0]
+        )

     def verify_json_objects_for_server(self, server_and_json):
         """Bulk verifies signatures of json objects, bulk fetching keys as

@@ -94,8 +97,10 @@ class Keyring(object):
             server_and_json (list): List of pairs of (server_name, json_object)

         Returns:
-            list of deferreds indicating success or failure to verify each
-            json object's signature for the given server_name.
+            List<Deferred>: for each input pair, a deferred indicating success
+            or failure to verify each json object's signature for the given
+            server_name. The deferreds run their callbacks in the sentinel
+            logcontext.
         """
         verify_requests = []

@@ -122,73 +127,59 @@ class Keyring(object):

             verify_requests.append(verify_request)

+        preserve_fn(self._start_key_lookups)(verify_requests)
+
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
+        handle = preserve_fn(_handle_key_deferred)
+        return [
+            handle(rq) for rq in verify_requests
+        ]
+
-        @defer.inlineCallbacks
-        def handle_key_deferred(verify_request):
-            server_name = verify_request.server_name
-            try:
-                _, key_id, verify_key = yield verify_request.deferred
-            except IOError as e:
-                logger.warn(
-                    "Got IOError when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    502,
-                    "Error downloading keys for %s" % (server_name,),
-                    Codes.UNAUTHORIZED,
-                )
-            except Exception as e:
-                logger.exception(
-                    "Got Exception when downloading keys for %s: %s %s",
-                    server_name, type(e).__name__, str(e.message),
-                )
-                raise SynapseError(
-                    401,
-                    "No key for %s with id %s" % (server_name, key_ids),
-                    Codes.UNAUTHORIZED,
-                )
-
-            json_object = verify_request.json_object
-
-            logger.debug("Got key %s %s:%s for server %s, verifying" % (
-                key_id, verify_key.alg, verify_key.version, server_name,
-            ))
-            try:
-                verify_signed_json(json_object, server_name, verify_key)
-            except:
-                raise SynapseError(
-                    401,
-                    "Invalid signature for server %s with key %s:%s" % (
-                        server_name, verify_key.alg, verify_key.version
-                    ),
-                    Codes.UNAUTHORIZED,
-                )
+    @defer.inlineCallbacks
+    def _start_key_lookups(self, verify_requests):
+        """Sets off the key fetches for each verify request
+
+        Once each fetch completes, verify_request.deferred will be resolved.
+
+        Args:
+            verify_requests (List[VerifyKeyRequest]):
+        """

+        # create a deferred for each server we're going to look up the keys
+        # for; we'll resolve them once we have completed our lookups.
+        # These will be passed into wait_for_previous_lookups to block
+        # any other lookups until we have finished.
+        # The deferreds are called with no logcontext.
         server_to_deferred = {
-            server_name: defer.Deferred()
-            for server_name, _ in server_and_json
+            rq.server_name: defer.Deferred()
+            for rq in verify_requests
         }

-        with PreserveLoggingContext():
-
         # We want to wait for any previous lookups to complete before
         # proceeding.
-        wait_on_deferred = self.wait_for_previous_lookups(
-            [server_name for server_name, _ in server_and_json],
+        yield self.wait_for_previous_lookups(
+            [rq.server_name for rq in verify_requests],
             server_to_deferred,
         )

         # Actually start fetching keys.
-        wait_on_deferred.addBoth(
-            lambda _: self.get_server_verify_keys(verify_requests)
-        )
+        self._get_server_verify_keys(verify_requests)

         # When we've finished fetching all the keys for a given server_name,
         # resolve the deferred passed to `wait_for_previous_lookups` so that
         # any lookups waiting will proceed.
+        #
+        # map from server name to a set of request ids
         server_to_request_ids = {}

-        def remove_deferreds(res, server_name, verify_request):
+        for verify_request in verify_requests:
+            server_name = verify_request.server_name
+            request_id = id(verify_request)
+            server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+        def remove_deferreds(res, verify_request):
+            server_name = verify_request.server_name
             request_id = id(verify_request)
             server_to_request_ids[server_name].discard(request_id)
             if not server_to_request_ids[server_name]:

@@ -198,20 +189,10 @@ class Keyring(object):
             return res

         for verify_request in verify_requests:
-            server_name = verify_request.server_name
-            request_id = id(verify_request)
-            server_to_request_ids.setdefault(server_name, set()).add(request_id)
             verify_request.deferred.addBoth(
-                remove_deferreds, server_name, verify_request,
+                remove_deferreds, verify_request,
             )

-        # Pass those keys to handle_key_deferred so that the json object
-        # signatures can be verified
-        return [
-            preserve_context_over_fn(handle_key_deferred, verify_request)
-            for verify_request in verify_requests
-        ]
-
     @defer.inlineCallbacks
     def wait_for_previous_lookups(self, server_names, server_to_deferred):
         """Waits for any previous key lookups for the given servers to finish.

@@ -247,7 +228,7 @@ class Keyring(object):
         self.key_downloads[server_name] = deferred
         deferred.addBoth(rm, server_name)

-    def get_server_verify_keys(self, verify_requests):
+    def _get_server_verify_keys(self, verify_requests):
         """Tries to find at least one key for each verify request

         For each verify_request, verify_request.deferred is called back with

@@ -316,6 +297,7 @@ class Keyring(object):
                 if not missing_keys:
                     break

-            for verify_request in requests_missing_keys.values():
-                verify_request.deferred.errback(SynapseError(
-                    401,
+            with PreserveLoggingContext():
+                for verify_request in requests_missing_keys.values():
+                    verify_request.deferred.errback(SynapseError(
+                        401,

@@ -326,11 +308,12 @@ class Keyring(object):
                     ))

         def on_err(err):
-            for verify_request in verify_requests:
-                if not verify_request.deferred.called:
-                    verify_request.deferred.errback(err)
+            with PreserveLoggingContext():
+                for verify_request in verify_requests:
+                    if not verify_request.deferred.called:
+                        verify_request.deferred.errback(err)

-        do_iterations().addErrback(on_err)
+        preserve_fn(do_iterations)().addErrback(on_err)

     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):

@@ -740,3 +723,47 @@ class Keyring(object):
             ],
             consumeErrors=True,
         ).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+    server_name = verify_request.server_name
+    try:
+        with PreserveLoggingContext():
+            _, key_id, verify_key = yield verify_request.deferred
+    except IOError as e:
+        logger.warn(
+            "Got IOError when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e.message),
+        )
+        raise SynapseError(
+            502,
+            "Error downloading keys for %s" % (server_name,),
+            Codes.UNAUTHORIZED,
+        )
+    except Exception as e:
+        logger.exception(
+            "Got Exception when downloading keys for %s: %s %s",
+            server_name, type(e).__name__, str(e.message),
+        )
+        raise SynapseError(
+            401,
+            "No key for %s with id %s" % (server_name, verify_request.key_ids),
+            Codes.UNAUTHORIZED,
+        )
+
+    json_object = verify_request.json_object
+
+    logger.debug("Got key %s %s:%s for server %s, verifying" % (
+        key_id, verify_key.alg, verify_key.version, server_name,
+    ))
+    try:
+        verify_signed_json(json_object, server_name, verify_key)
+    except:
+        raise SynapseError(
+            401,
+            "Invalid signature for server %s with key %s:%s" % (
+                server_name, verify_key.alg, verify_key.version
+            ),
+            Codes.UNAUTHORIZED,
+        )
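Stripped of the key-fetching detail, the new shape of verify_json_objects_for_server is: start the lookups in the background with preserve_fn, then hand back one deferred per request. A simplified sketch of that control flow, where start_lookups and handle_one are stand-ins rather than the real Keyring internals:

from synapse.util import logcontext


def verify_many(start_lookups, handle_one, verify_requests):
    # Kick off the key lookups without waiting for them; preserve_fn keeps
    # the caller's logcontext intact even though the work continues in the
    # background.
    logcontext.preserve_fn(start_lookups)(verify_requests)

    # Hand back one deferred per request. Each handler is also run via
    # preserve_fn, so the returned deferreds run their callbacks with no
    # logcontext, matching the docstring added in this commit.
    handle = logcontext.preserve_fn(handle_one)
    return [handle(rq) for rq in verify_requests]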
synapse/federation/federation_base.py

@@ -18,8 +18,7 @@ from synapse.api.errors import SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.events import spamcheck
 from synapse.events.utils import prune_event
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_context_over_deferred, preserve_fn
+from synapse.util import unwrapFirstError, logcontext
 from twisted.internet import defer

 logger = logging.getLogger(__name__)

@@ -51,56 +50,52 @@ class FederationBase(object):
         """
         deferreds = self._check_sigs_and_hashes(pdus)

-        def callback(pdu):
-            return pdu
-
-        def errback(failure, pdu):
-            failure.trap(SynapseError)
-            return None
+        @defer.inlineCallbacks
+        def handle_check_result(pdu, deferred):
+            try:
+                res = yield logcontext.make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None

-        def try_local_db(res, pdu):
             if not res:
                 # Check local db.
-                return self.store.get_event(
+                res = yield self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-            return res

-        def try_remote(res, pdu):
             if not res and pdu.origin != origin:
-                return self.get_pdu(
-                    destinations=[pdu.origin],
-                    event_id=pdu.event_id,
-                    outlier=outlier,
-                    timeout=10000,
-                ).addErrback(lambda e: None)
-            return res
+                try:
+                    res = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                        timeout=10000,
+                    )
+                except SynapseError:
+                    pass

-        def warn(res, pdu):
             if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
-            return res

-        for pdu, deferred in zip(pdus, deferreds):
-            deferred.addCallbacks(
-                callback, errback, errbackArgs=[pdu]
-            ).addCallback(
-                try_local_db, pdu
-            ).addCallback(
-                try_remote, pdu
-            ).addCallback(
-                warn, pdu
-            )
+            defer.returnValue(res)

-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
-            deferreds,
-            consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        handle = logcontext.preserve_fn(handle_check_result)
+        deferreds2 = [
+            handle(pdu, deferred)
+            for pdu, deferred in zip(pdus, deferreds)
+        ]
+
+        valid_pdus = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                deferreds2,
+                consumeErrors=True,
+            )
+        ).addErrback(unwrapFirstError)

         if include_none:
             defer.returnValue(valid_pdus)

@@ -108,7 +103,9 @@ class FederationBase(object):
         defer.returnValue([p for p in valid_pdus if p])

     def _check_sigs_and_hash(self, pdu):
-        return self._check_sigs_and_hashes([pdu])[0]
+        return logcontext.make_deferred_yieldable(
+            self._check_sigs_and_hashes([pdu])[0],
+        )

     def _check_sigs_and_hashes(self, pdus):
         """Checks that each of the received events is correctly signed by the

@@ -123,6 +120,7 @@ class FederationBase(object):
             * returns a redacted version of the event (if the signature
               matched but the hash did not)
             * throws a SynapseError if the signature check failed.
+        The deferreds run their callbacks in the sentinel logcontext.
         """

         redacted_pdus = [

@@ -130,12 +128,15 @@ class FederationBase(object):
             for pdu in pdus
         ]

-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])

+        ctx = logcontext.LoggingContext.current_context()
+
         def callback(_, pdu, redacted):
-            if not check_event_content_hash(pdu):
-                logger.warn(
-                    "Event content has been tampered, redacting %s: %s",
+            with logcontext.PreserveLoggingContext(ctx):
+                if not check_event_content_hash(pdu):
+                    logger.warn(
+                        "Event content has been tampered, redacting %s: %s",

@@ -154,6 +155,7 @@ class FederationBase(object):

         def errback(failure, pdu):
             failure.trap(SynapseError)
-            logger.warn(
-                "Signature check failed for %s",
-                pdu.event_id,
+            with logcontext.PreserveLoggingContext(ctx):
+                logger.warn(
+                    "Signature check failed for %s",
+                    pdu.event_id,
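The callback chain in _check_sigs_and_hash_and_fetch becomes a single inlineCallbacks helper run once per PDU. A simplified sketch of that per-PDU handler; the free function and the store argument are stand-ins for the real method and self.store:

from twisted.internet import defer

from synapse.util import logcontext


@defer.inlineCallbacks
def handle_check_result(pdu, deferred, store):
    # Consume one signature-check deferred. make_deferred_yieldable lets us
    # yield it without leaking out of the caller's logcontext; a failed
    # check leaves res as None so we can fall back to other sources.
    try:
        res = yield logcontext.make_deferred_yieldable(deferred)
    except Exception:
        res = None

    if not res:
        # fall back to the local database, as the real method does
        res = yield store.get_event(
            pdu.event_id, allow_rejected=True, allow_none=True,
        )
    defer.returnValue(res)

In the real method the helper is then fanned out with handle = logcontext.preserve_fn(handle_check_result) and one call per (pdu, deferred) pair, so each invocation starts in the caller's logcontext.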
synapse/federation/federation_client.py

@@ -22,7 +22,7 @@ from synapse.api.constants import Membership
 from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred

@@ -189,10 +189,10 @@ class FederationClient(FederationBase):
         ]

         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))

         defer.returnValue(pdus)

@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
             pdu = pdu_list[0]

             # Check signatures are correct.
-            signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+            signed_pdu = yield self._check_sigs_and_hash(pdu)

             break

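The federation_client changes are call-site updates to match: note that the closing parenthesis moves so the addErrback is attached to the gatherResults deferred itself, and the whole chain is then wrapped so it can be yielded. A condensed sketch of that gathering idiom under those assumptions:

from twisted.internet import defer

from synapse.util import unwrapFirstError, logcontext


@defer.inlineCallbacks
def gather_checked(check_deferreds):
    # consumeErrors=True stops individual failures being reported as
    # unhandled; unwrapFirstError re-raises the first one as the original
    # exception once the gather fails.
    results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
        check_deferreds,
        consumeErrors=True,
    ).addErrback(unwrapFirstError))
    defer.returnValue(results)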
synapse/storage/keys.py

@@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
             keys[key_id] = key
         defer.returnValue(keys)

-    @defer.inlineCallbacks
     def store_server_verify_key(self, server_name, from_server, time_now_ms,
                                 verify_key):
         """Stores a NACL verification key for the given server.
         Args:
             server_name (str): The name of the server.
-            key_id (str): The version of the key for the server.
             from_server (str): Where the verification key was looked up
-            ts_now_ms (int): The time now in milliseconds
-            verification_key (VerifyKey): The NACL verify key.
+            time_now_ms (int): The time now in milliseconds
+            verify_key (nacl.signing.VerifyKey): The NACL verify key.
         """
-        yield self._simple_upsert(
-            table="server_signature_keys",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
-            },
-            values={
-                "from_server": from_server,
-                "ts_added_ms": time_now_ms,
-                "verify_key": buffer(verify_key.encode()),
-            },
-            desc="store_server_verify_key",
-        )
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+        def _txn(txn):
+            self._simple_upsert_txn(
+                txn,
+                table="server_signature_keys",
+                keyvalues={
+                    "server_name": server_name,
+                    "key_id": key_id,
+                },
+                values={
+                    "from_server": from_server,
+                    "ts_added_ms": time_now_ms,
+                    "verify_key": buffer(verify_key.encode()),
+                },
+            )
+            txn.call_after(
+                self._get_server_verify_key.invalidate,
+                (server_name, key_id)
+            )
+
+        return self.runInteraction("store_server_verify_key", _txn)

     def store_server_keys_json(self, server_name, key_id, from_server,
                                ts_now_ms, ts_expires_ms, key_json_bytes):
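store_server_verify_key now does the upsert and the cache invalidation inside one runInteraction, using txn.call_after so the invalidation only happens if the transaction commits. A toy illustration of that ordering, with a plain dict standing in for the store's key cache and hypothetical key values:

class Txn(object):
    """Toy transaction: queues side effects to run only after commit."""

    def __init__(self):
        self._after = []

    def call_after(self, fn, *args):
        self._after.append((fn, args))

    def commit(self):
        for fn, args in self._after:
            fn(*args)


# usage: the cached verify key is dropped only once the write has committed
cache = {("server1", "ed25519:key1"): "stale entry"}
txn = Txn()
# ... the _simple_upsert_txn write would run here, inside the transaction ...
txn.call_after(cache.pop, ("server1", "ed25519:key1"), None)
txn.commit()
assert ("server1", "ed25519:key1") not in cache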
tests/crypto/test_keyring.py

@@ -12,39 +12,72 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import signedjson
+import time
+
+import signedjson.key
+import signedjson.sign
 from mock import Mock
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
-from synapse.util import async
+from synapse.util import async, logcontext
 from synapse.util.logcontext import LoggingContext
 from tests import unittest, utils
 from twisted.internet import defer


+class MockPerspectiveServer(object):
+    def __init__(self):
+        self.server_name = "mock_server"
+        self.key = signedjson.key.generate_signing_key(0)
+
+    def get_verify_keys(self):
+        vk = signedjson.key.get_verify_key(self.key)
+        return {
+            "%s:%s" % (vk.alg, vk.version): vk,
+        }
+
+    def get_signed_key(self, server_name, verify_key):
+        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+        res = {
+            "server_name": server_name,
+            "old_verify_keys": {},
+            "valid_until_ts": time.time() * 1000 + 3600,
+            "verify_keys": {
+                key_id: {
+                    "key": signedjson.key.encode_verify_key_base64(verify_key)
+                }
+            }
+        }
+        signedjson.sign.sign_json(res, self.server_name, self.key)
+        return res
+
+
 class KeyringTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
+        self.mock_perspective_server = MockPerspectiveServer()
         self.http_client = Mock()
         self.hs = yield utils.setup_test_homeserver(
             handlers=None,
             http_client=self.http_client,
         )
         self.hs.config.perspectives = {
-            "persp_server": {"k": "v"}
+            self.mock_perspective_server.server_name:
+                self.mock_perspective_server.get_verify_keys()
         }

+    def check_context(self, _, expected):
+        self.assertEquals(
+            getattr(LoggingContext.current_context(), "test_key", None),
+            expected
+        )
+
     @defer.inlineCallbacks
     def test_wait_for_previous_lookups(self):
         sentinel_context = LoggingContext.current_context()

         kr = keyring.Keyring(self.hs)

-        def check_context(_, expected):
-            self.assertEquals(
-                LoggingContext.current_context().test_key, expected
-            )
-
         lookup_1_deferred = defer.Deferred()
         lookup_2_deferred = defer.Deferred()

@@ -60,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
             self.assertTrue(wait_1_deferred.called)
             # ... so we should have preserved the LoggingContext.
             self.assertIs(LoggingContext.current_context(), context_one)
-            wait_1_deferred.addBoth(check_context, "one")
+            wait_1_deferred.addBoth(self.check_context, "one")

         with LoggingContext("two") as context_two:
             context_two.test_key = "two"

@@ -74,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
             self.assertFalse(wait_2_deferred.called)
             # ... so we should have reset the LoggingContext.
             self.assertIs(LoggingContext.current_context(), sentinel_context)
-            wait_2_deferred.addBoth(check_context, "two")
+            wait_2_deferred.addBoth(self.check_context, "two")

             # let the first lookup complete (in the sentinel context)
             lookup_1_deferred.callback(None)

@@ -89,18 +122,40 @@ class KeyringTestCase(unittest.TestCase):

         kr = keyring.Keyring(self.hs)
         json1 = {}
-        signedjson.sign.sign_json(json1, "server1", key1)
+        signedjson.sign.sign_json(json1, "server10", key1)

-        self.http_client.post_json.return_value = defer.Deferred()
+        persp_resp = {
+            "server_keys": [
+                self.mock_perspective_server.get_signed_key(
+                    "server10",
+                    signedjson.key.get_verify_key(key1)
+                ),
+            ]
+        }
+        persp_deferred = defer.Deferred()
+
+        @defer.inlineCallbacks
+        def get_perspectives(**kwargs):
+            self.assertEquals(
+                LoggingContext.current_context().test_key, "11",
+            )
+            with logcontext.PreserveLoggingContext():
+                yield persp_deferred
+            defer.returnValue(persp_resp)
+        self.http_client.post_json.side_effect = get_perspectives
+
+        with LoggingContext("11") as context_11:
+            context_11.test_key = "11"

             # start off a first set of lookups
             res_deferreds = kr.verify_json_objects_for_server(
-                [("server1", json1),
-                 ("server2", {})
+                [("server10", json1),
+                 ("server11", {})
                  ]
             )

             # the unsigned json should be rejected pretty quickly
+            self.assertTrue(res_deferreds[1].called)
             try:
                 yield res_deferreds[1]
                 self.assertFalse("unsigned json didn't cause a failure")

@@ -108,19 +163,67 @@ class KeyringTestCase(unittest.TestCase):
                 pass

             self.assertFalse(res_deferreds[0].called)
+            res_deferreds[0].addBoth(self.check_context, None)

             # wait a tick for it to send the request to the perspectives server
             # (it first tries the datastore)
             yield async.sleep(0.005)
             self.http_client.post_json.assert_called_once()

-        # a second request for a server with outstanding requests should
-        # block rather than start a second call
-        self.http_client.post_json.reset_mock()
-        self.http_client.post_json.return_value = defer.Deferred()
+            self.assertIs(LoggingContext.current_context(), context_11)

-        kr.verify_json_objects_for_server(
-            [("server1", json1)],
-        )
-        yield async.sleep(0.005)
-        self.http_client.post_json.assert_not_called()
+        context_12 = LoggingContext("12")
+        context_12.test_key = "12"
+        with logcontext.PreserveLoggingContext(context_12):
+            # a second request for a server with outstanding requests
+            # should block rather than start a second call
+            self.http_client.post_json.reset_mock()
+            self.http_client.post_json.return_value = defer.Deferred()
+
+            res_deferreds_2 = kr.verify_json_objects_for_server(
+                [("server10", json1)],
+            )
+            yield async.sleep(0.005)
+            self.http_client.post_json.assert_not_called()
+            res_deferreds_2[0].addBoth(self.check_context, None)
+
+        # complete the first request
+        with logcontext.PreserveLoggingContext():
+            persp_deferred.callback(persp_resp)
+        self.assertIs(LoggingContext.current_context(), context_11)
+
+        with logcontext.PreserveLoggingContext():
+            yield res_deferreds[0]
+            yield res_deferreds_2[0]
+
+    @defer.inlineCallbacks
+    def test_verify_json_for_server(self):
+        kr = keyring.Keyring(self.hs)
+
+        key1 = signedjson.key.generate_signing_key(1)
+        yield self.hs.datastore.store_server_verify_key(
+            "server9", "", time.time() * 1000,
+            signedjson.key.get_verify_key(key1),
+        )
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server9", key1)
+
+        sentinel_context = LoggingContext.current_context()
+
+        with LoggingContext("one") as context_one:
+            context_one.test_key = "one"
+
+            defer = kr.verify_json_for_server("server9", {})
+            try:
+                yield defer
+                self.fail("should fail on unsigned json")
+            except SynapseError:
+                pass
+            self.assertIs(LoggingContext.current_context(), context_one)
+
+            defer = kr.verify_json_for_server("server9", json1)
+            self.assertFalse(defer.called)
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            yield defer
+
+            self.assertIs(LoggingContext.current_context(), context_one)
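The new tests mostly assert logcontext behaviour rather than key verification itself. The core pattern, condensed into a helper (make_request is a placeholder for a call such as kr.verify_json_for_server):

from twisted.internet import defer

from synapse.util.logcontext import LoggingContext


@defer.inlineCallbacks
def assert_follows_logcontext_rules(make_request):
    sentinel = LoggingContext.current_context()
    with LoggingContext("one") as context_one:
        context_one.test_key = "one"

        d = make_request()
        if not d.called:
            # while the deferred is pending we should have been dropped back
            # into the sentinel context...
            assert LoggingContext.current_context() is sentinel
        yield d
        # ...and our own context should be restored once it completes.
        assert LoggingContext.current_context() is context_one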