0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 19:23:53 +01:00

Merge pull request #5727 from matrix-org/uhoreg/e2e_cross-signing2-part3

Cross-signing [4/4] -- federation edition
This commit is contained in:
Hubert Chathi 2019-10-31 23:59:35 -04:00 committed by GitHub
commit 53d7680e32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 241 additions and 47 deletions

1
changelog.d/5727.feature Normal file
View file

@ -0,0 +1 @@
Add federation support for cross-signing.

View file

@ -360,20 +360,20 @@ class PerDestinationQueue(object):
last_device_list = self._last_device_list_stream_id last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination # Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote( now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit self._destination, last_device_list, limit=limit
) )
edus = [ edus = [
Edu( Edu(
origin=self._server_name, origin=self._server_name,
destination=self._destination, destination=self._destination,
edu_type="m.device_list_update", edu_type=edu_type,
content=content, content=content,
) )
for content in results for (edu_type, content) in results
] ]
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
return (edus, now_stream_id) return (edus, now_stream_id)

View file

@ -459,7 +459,18 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id): def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
return {"user_id": user_id, "stream_id": stream_id, "devices": devices} master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
return {
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
"master_key": master_key,
"self_signing_key": self_signing_key,
}
@defer.inlineCallbacks @defer.inlineCallbacks
def user_left_room(self, user, room_id): def user_left_room(self, user, room_id):

View file

@ -36,6 +36,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key, get_verify_key_from_cross_signing_key,
) )
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,10 +51,19 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine self.is_mine = hs.is_mine
self.clock = hs.get_clock() self.clock = hs.get_clock()
self._edu_updater = SigningKeyEduUpdater(hs, self)
federation_registry = hs.get_federation_registry()
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
federation_registry.register_edu_handler(
"org.matrix.signing_key_update",
self._edu_updater.incoming_signing_key_update,
)
# doesn't really work as part of the generic query API, because the # doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the # query request requires an object POST, but we abuse the
# "query handler" interface. # "query handler" interface.
hs.get_federation_registry().register_query_handler( federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys "client_keys", self.on_federation_query_client_keys
) )
@ -208,13 +219,15 @@ class E2eKeysHandler(object):
if user_id in destination_query: if user_id in destination_query:
results[user_id] = keys results[user_id] = keys
for user_id, key in remote_result["master_keys"].items(): if "master_keys" in remote_result:
if user_id in destination_query: for user_id, key in remote_result["master_keys"].items():
cross_signing_keys["master_keys"][user_id] = key if user_id in destination_query:
cross_signing_keys["master_keys"][user_id] = key
for user_id, key in remote_result["self_signing_keys"].items(): if "self_signing_keys" in remote_result:
if user_id in destination_query: for user_id, key in remote_result["self_signing_keys"].items():
cross_signing_keys["self_signing_keys"][user_id] = key if user_id in destination_query:
cross_signing_keys["self_signing_keys"][user_id] = key
except Exception as e: except Exception as e:
failure = _exception_to_failure(e) failure = _exception_to_failure(e)
@ -252,7 +265,7 @@ class E2eKeysHandler(object):
Returns: Returns:
defer.Deferred[dict[str, dict[str, dict]]]: map from defer.Deferred[dict[str, dict[str, dict]]]: map from
(master|self_signing|user_signing) -> user_id -> key (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
""" """
master_keys = {} master_keys = {}
self_signing_keys = {} self_signing_keys = {}
@ -344,7 +357,16 @@ class E2eKeysHandler(object):
""" """
device_keys_query = query_body.get("device_keys", {}) device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query) res = yield self.query_local_devices(device_keys_query)
return {"device_keys": res} ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
ret.update(cross_signing_keys)
return ret
@trace @trace
@defer.inlineCallbacks @defer.inlineCallbacks
@ -1058,3 +1080,100 @@ class SignatureListItem:
target_user_id = attr.ib() target_user_id = attr.ib()
target_device_id = attr.ib() target_device_id = attr.ib()
signature = attr.ib() signature = attr.ib()
class SigningKeyEduUpdater(object):
"""Handles incoming signing key updates from federation and updates the DB"""
def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
self.federation = hs.get_federation_client()
self.clock = hs.get_clock()
self.e2e_keys_handler = e2e_keys_handler
self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="signing_key_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
@defer.inlineCallbacks
def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
Args:
origin (string): the server that sent the EDU
edu_content (dict): the contents of the EDU
"""
user_id = edu_content.pop("user_id")
master_key = edu_content.pop("master_key", None)
self_signing_key = edu_content.pop("self_signing_key", None)
if get_domain_from_id(user_id) != origin:
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
self._pending_updates.setdefault(user_id, []).append(
(master_key, self_signing_key)
)
yield self._handle_signing_key_updates(user_id)
@defer.inlineCallbacks
def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
user_id (string): the user whose updates we are processing
"""
device_handler = self.e2e_keys_handler.device_handler
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
return
device_ids = []
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(
self_signing_key
)
device_ids.append(verify_key.version)
yield device_handler.notify_device_update(user_id, device_ids)

View file

@ -37,6 +37,7 @@ from synapse.storage._base import (
make_in_list_sql_clause, make_in_list_sql_clause,
) )
from synapse.storage.background_updates import BackgroundUpdateStore from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@ -90,13 +91,18 @@ class DeviceWorkerStore(SQLBaseStore):
@trace @trace
@defer.inlineCallbacks @defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit): def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers """Get a stream of device updates to send to the given remote server.
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
limit (int): Maximum number of device updates to return
Returns: Returns:
Deferred[tuple[int, list[dict]]]: Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the current stream id (ie, the stream id of the last update included in the
response), and the list of updates response), and the list of updates, where each update is a pair of EDU
type and EDU contents
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
@ -117,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
# stream_id; the rationale being that such a large device list update # stream_id; the rationale being that such a large device list update
# is likely an error. # is likely an error.
updates = yield self.runInteraction( updates = yield self.runInteraction(
"get_devices_by_remote", "get_device_updates_by_remote",
self._get_devices_by_remote_txn, self._get_device_updates_by_remote_txn,
destination, destination,
from_stream_id, from_stream_id,
now_stream_id, now_stream_id,
@ -129,6 +135,37 @@ class DeviceWorkerStore(SQLBaseStore):
if not updates: if not updates:
return now_stream_id, [] return now_stream_id, []
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
for user in users:
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
master_key_by_user[user] = {
"key_info": cross_signing_key,
"device_id": verify_key.version,
}
cross_signing_key = yield self.get_e2e_cross_signing_key(
user, "self_signing"
)
if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
"device_id": verify_key.version,
}
# if we have exceeded the limit, we need to exclude any results with the # if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row. # same stream_id as the last row.
if len(updates) > limit: if len(updates) > limit:
@ -153,20 +190,33 @@ class DeviceWorkerStore(SQLBaseStore):
# context which created the Edu. # context which created the Edu.
query_map = {} query_map = {}
for update in updates: cross_signing_keys_by_user = {}
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff: for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates # Stop processing updates
break break
key = (update[0], update[1]) if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
elif (
user_id in self_signing_key_by_user
and device_id == self_signing_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
else:
key = (user_id, device_id)
update_context = update[3] previous_update_stream_id, _ = query_map.get(key, (0, None))
update_stream_id = update[2]
previous_update_stream_id, _ = query_map.get(key, (0, None)) if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it # If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same # means that there are more than limit updates all of which have the same
@ -176,16 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just # devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next # skip that stream_id and return an empty list, and continue with the next
# stream_id next time. # stream_id next time.
if not query_map: if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, [] return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote( results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map destination, from_stream_id, query_map
) )
# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))
return now_stream_id, results return now_stream_id, results
def _get_devices_by_remote_txn( def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit self, txn, destination, from_stream_id, now_stream_id, limit
): ):
"""Return device update information for a given remote destination """Return device update information for a given remote destination
@ -200,6 +256,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns: Returns:
List: List of device updates List: List of device updates
""" """
# get the list of device updates that need to be sent
sql = """ sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
@ -225,12 +282,16 @@ class DeviceWorkerStore(SQLBaseStore):
List[Dict]: List of objects representing an device update EDU List[Dict]: List of objects representing an device update EDU
""" """
devices = yield self.runInteraction( devices = (
"_get_e2e_device_keys_txn", yield self.runInteraction(
self._get_e2e_device_keys_txn, "_get_e2e_device_keys_txn",
query_map.keys(), self._get_e2e_device_keys_txn,
include_all_devices=True, query_map.keys(),
include_deleted_devices=True, include_all_devices=True,
include_deleted_devices=True,
)
if query_map
else {}
) )
results = [] results = []
@ -262,7 +323,7 @@ class DeviceWorkerStore(SQLBaseStore):
else: else:
result["deleted"] = True result["deleted"] = True
results.append(result) results.append(("m.device_list_update", result))
return results return results

View file

@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote", "get_device_updates_by_remote",
# Bits that user_directory needs # Bits that user_directory needs
"get_user_directory_stream_pos", "get_user_directory_stream_pos",
"get_current_state_deltas", "get_current_state_deltas",
@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res retry_timings_res
) )
self.datastore.get_devices_by_remote.return_value = (0, []) self.datastore.get_device_updates_by_remote.return_value = (0, [])
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)

View file

@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
) )
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_devices_by_remote(self): def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id # Add two device updates with a single stream_id
@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
) )
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_devices_by_remote( now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100 "somehost", -1, limit=100
) )
@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self._check_devices_in_updates(device_ids, device_updates) self._check_devices_in_updates(device_ids, device_updates)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_devices_by_remote_limited(self): def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments # Test breaking the update limit in 1, 101, and 1 device_id segments
# first add one device # first add one device
@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
# #
# first we should get a single update # first we should get a single update
now_stream_id, device_updates = yield self.store.get_devices_by_remote( now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100 "someotherhost", -1, limit=100
) )
self._check_devices_in_updates(device_ids1, device_updates) self._check_devices_in_updates(device_ids1, device_updates)
# Then we should get an empty list back as the 101 devices broke the limit # Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_devices_by_remote( now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100 "someotherhost", now_stream_id, limit=100
) )
self.assertEqual(len(device_updates), 0) self.assertEqual(len(device_updates), 0)
# The 101 devices should've been cleared, so we should now just get one device # The 101 devices should've been cleared, so we should now just get one device
# update # update
now_stream_id, device_updates = yield self.store.get_devices_by_remote( now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100 "someotherhost", now_stream_id, limit=100
) )
self._check_devices_in_updates(device_ids3, device_updates) self._check_devices_in_updates(device_ids3, device_updates)
@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
"""Check that an specific device ids exist in a list of device update EDUs""" """Check that an specific device ids exist in a list of device update EDUs"""
self.assertEqual(len(device_updates), len(expected_device_ids)) self.assertEqual(len(device_updates), len(expected_device_ids))
received_device_ids = {update["device_id"] for update in device_updates} received_device_ids = {
update["device_id"] for edu_type, update in device_updates
}
self.assertEqual(received_device_ids, set(expected_device_ids)) self.assertEqual(received_device_ids, set(expected_device_ids))
@defer.inlineCallbacks @defer.inlineCallbacks