mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-15 17:53:51 +01:00
Merge pull request #2206 from matrix-org/rav/one_time_key_upload_change_sig
Allow clients to upload one-time-keys with new sigs
This commit is contained in:
commit
5331cd150a
4 changed files with 238 additions and 35 deletions
|
@ -440,6 +440,16 @@ class FederationServer(FederationBase):
|
||||||
key_id: json.loads(json_bytes)
|
key_id: json.loads(json_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Claimed one-time-keys: %s",
|
||||||
|
",".join((
|
||||||
|
"%s for %s:%s" % (key_id, user_id, device_id)
|
||||||
|
for user_id, user_keys in json_result.iteritems()
|
||||||
|
for device_id, device_keys in user_keys.iteritems()
|
||||||
|
for key_id, _ in device_keys.iteritems()
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({"one_time_keys": json_result})
|
defer.returnValue({"one_time_keys": json_result})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -21,7 +21,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError, CodeMessageException
|
from synapse.api.errors import SynapseError, CodeMessageException
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -145,7 +145,7 @@ class E2eKeysHandler(object):
|
||||||
"status": 503, "message": e.message
|
"status": 503, "message": e.message
|
||||||
}
|
}
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults([
|
yield make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(do_remote_query)(destination)
|
preserve_fn(do_remote_query)(destination)
|
||||||
for destination in remote_queries_not_in_cache
|
for destination in remote_queries_not_in_cache
|
||||||
]))
|
]))
|
||||||
|
@ -257,11 +257,21 @@ class E2eKeysHandler(object):
|
||||||
"status": 503, "message": e.message
|
"status": 503, "message": e.message
|
||||||
}
|
}
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults([
|
yield make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(claim_client_keys)(destination)
|
preserve_fn(claim_client_keys)(destination)
|
||||||
for destination in remote_queries
|
for destination in remote_queries
|
||||||
]))
|
]))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Claimed one-time-keys: %s",
|
||||||
|
",".join((
|
||||||
|
"%s for %s:%s" % (key_id, user_id, device_id)
|
||||||
|
for user_id, user_keys in json_result.iteritems()
|
||||||
|
for device_id, device_keys in user_keys.iteritems()
|
||||||
|
for key_id, _ in device_keys.iteritems()
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
|
||||||
defer.returnValue({
|
defer.returnValue({
|
||||||
"one_time_keys": json_result,
|
"one_time_keys": json_result,
|
||||||
"failures": failures
|
"failures": failures
|
||||||
|
@ -288,19 +298,8 @@ class E2eKeysHandler(object):
|
||||||
|
|
||||||
one_time_keys = keys.get("one_time_keys", None)
|
one_time_keys = keys.get("one_time_keys", None)
|
||||||
if one_time_keys:
|
if one_time_keys:
|
||||||
logger.info(
|
yield self._upload_one_time_keys_for_user(
|
||||||
"Adding %d one_time_keys for device %r for user %r at %d",
|
user_id, device_id, time_now, one_time_keys,
|
||||||
len(one_time_keys), device_id, user_id, time_now
|
|
||||||
)
|
|
||||||
key_list = []
|
|
||||||
for key_id, key_json in one_time_keys.items():
|
|
||||||
algorithm, key_id = key_id.split(":")
|
|
||||||
key_list.append((
|
|
||||||
algorithm, key_id, encode_canonical_json(key_json)
|
|
||||||
))
|
|
||||||
|
|
||||||
yield self.store.add_e2e_one_time_keys(
|
|
||||||
user_id, device_id, time_now, key_list
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# the device should have been registered already, but it may have been
|
# the device should have been registered already, but it may have been
|
||||||
|
@ -313,3 +312,58 @@ class E2eKeysHandler(object):
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
|
||||||
defer.returnValue({"one_time_key_counts": result})
|
defer.returnValue({"one_time_key_counts": result})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
|
||||||
|
one_time_keys):
|
||||||
|
logger.info(
|
||||||
|
"Adding one_time_keys %r for device %r for user %r at %d",
|
||||||
|
one_time_keys.keys(), device_id, user_id, time_now,
|
||||||
|
)
|
||||||
|
|
||||||
|
# make a list of (alg, id, key) tuples
|
||||||
|
key_list = []
|
||||||
|
for key_id, key_obj in one_time_keys.items():
|
||||||
|
algorithm, key_id = key_id.split(":")
|
||||||
|
key_list.append((
|
||||||
|
algorithm, key_id, key_obj
|
||||||
|
))
|
||||||
|
|
||||||
|
# First we check if we have already persisted any of the keys.
|
||||||
|
existing_key_map = yield self.store.get_e2e_one_time_keys(
|
||||||
|
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||||
|
)
|
||||||
|
|
||||||
|
new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
|
||||||
|
for algorithm, key_id, key in key_list:
|
||||||
|
ex_json = existing_key_map.get((algorithm, key_id), None)
|
||||||
|
if ex_json:
|
||||||
|
if not _one_time_keys_match(ex_json, key):
|
||||||
|
raise SynapseError(
|
||||||
|
400,
|
||||||
|
("One time key %s:%s already exists. "
|
||||||
|
"Old key: %s; new key: %r") %
|
||||||
|
(algorithm, key_id, ex_json, key)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_keys.append((algorithm, key_id, encode_canonical_json(key)))
|
||||||
|
|
||||||
|
yield self.store.add_e2e_one_time_keys(
|
||||||
|
user_id, device_id, time_now, new_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _one_time_keys_match(old_key_json, new_key):
|
||||||
|
old_key = json.loads(old_key_json)
|
||||||
|
|
||||||
|
# if either is a string rather than an object, they must match exactly
|
||||||
|
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
|
||||||
|
return old_key == new_key
|
||||||
|
|
||||||
|
# otherwise, we strip off the 'signatures' if any, because it's legitimate
|
||||||
|
# for different upload attempts to have different signatures.
|
||||||
|
old_key.pop("signatures", None)
|
||||||
|
new_key_copy = dict(new_key)
|
||||||
|
new_key_copy.pop("signatures", None)
|
||||||
|
|
||||||
|
return old_key == new_key_copy
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.errors import SynapseError
|
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
|
@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
|
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
|
||||||
"""Insert some new one time keys for a device.
|
"""Retrieve a number of one-time keys for a user
|
||||||
|
|
||||||
Checks if any of the keys are already inserted, if they are then check
|
Args:
|
||||||
if they match. If they don't then we raise an error.
|
user_id(str): id of user to get keys for
|
||||||
|
device_id(str): id of device to get keys for
|
||||||
|
key_ids(list[str]): list of key ids (excluding algorithm) to
|
||||||
|
retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
deferred resolving to Dict[(str, str), str]: map from (algorithm,
|
||||||
|
key_id) to json string for key
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# First we check if we have already persisted any of the keys.
|
|
||||||
rows = yield self._simple_select_many_batch(
|
rows = yield self._simple_select_many_batch(
|
||||||
table="e2e_one_time_keys_json",
|
table="e2e_one_time_keys_json",
|
||||||
column="key_id",
|
column="key_id",
|
||||||
iterable=[key_id for _, key_id, _ in key_list],
|
iterable=key_ids,
|
||||||
retcols=("algorithm", "key_id", "key_json",),
|
retcols=("algorithm", "key_id", "key_json",),
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||||
desc="add_e2e_one_time_keys_check",
|
desc="add_e2e_one_time_keys_check",
|
||||||
)
|
)
|
||||||
|
|
||||||
existing_key_map = {
|
defer.returnValue({
|
||||||
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
|
||||||
}
|
})
|
||||||
|
|
||||||
new_keys = [] # Keys that we need to insert
|
@defer.inlineCallbacks
|
||||||
for algorithm, key_id, json_bytes in key_list:
|
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
|
||||||
ex_bytes = existing_key_map.get((algorithm, key_id), None)
|
"""Insert some new one time keys for a device. Errors if any of the
|
||||||
if ex_bytes:
|
keys already exist.
|
||||||
if json_bytes != ex_bytes:
|
|
||||||
raise SynapseError(
|
Args:
|
||||||
400, "One time key with key_id %r already exists" % (key_id,)
|
user_id(str): id of user to get keys for
|
||||||
)
|
device_id(str): id of device to get keys for
|
||||||
else:
|
time_now(long): insertion time to record (ms since epoch)
|
||||||
new_keys.append((algorithm, key_id, json_bytes))
|
new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
|
||||||
|
(algorithm, key_id, key json)
|
||||||
|
"""
|
||||||
|
|
||||||
def _add_e2e_one_time_keys(txn):
|
def _add_e2e_one_time_keys(txn):
|
||||||
# We are protected from race between lookup and insertion due to
|
# We are protected from race between lookup and insertion due to
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import mock
|
import mock
|
||||||
|
from synapse.api import errors
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
import synapse.api.errors
|
import synapse.api.errors
|
||||||
|
@ -44,3 +45,134 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||||
local_user = "@boris:" + self.hs.hostname
|
local_user = "@boris:" + self.hs.hostname
|
||||||
res = yield self.handler.query_local_devices({local_user: None})
|
res = yield self.handler.query_local_devices({local_user: None})
|
||||||
self.assertDictEqual(res, {local_user: {}})
|
self.assertDictEqual(res, {local_user: {}})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_reupload_one_time_keys(self):
|
||||||
|
"""we should be able to re-upload the same keys"""
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
keys = {
|
||||||
|
"alg1:k1": "key1",
|
||||||
|
"alg2:k2": {
|
||||||
|
"key": "key2",
|
||||||
|
"signatures": {"k1": "sig1"}
|
||||||
|
},
|
||||||
|
"alg2:k3": {
|
||||||
|
"key": "key3",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res = yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": keys},
|
||||||
|
)
|
||||||
|
self.assertDictEqual(res, {
|
||||||
|
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||||
|
})
|
||||||
|
|
||||||
|
# we should be able to change the signature without a problem
|
||||||
|
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
||||||
|
res = yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": keys},
|
||||||
|
)
|
||||||
|
self.assertDictEqual(res, {
|
||||||
|
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||||
|
})
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_change_one_time_keys(self):
|
||||||
|
"""attempts to change one-time-keys should be rejected"""
|
||||||
|
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
keys = {
|
||||||
|
"alg1:k1": "key1",
|
||||||
|
"alg2:k2": {
|
||||||
|
"key": "key2",
|
||||||
|
"signatures": {"k1": "sig1"}
|
||||||
|
},
|
||||||
|
"alg2:k3": {
|
||||||
|
"key": "key3",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res = yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": keys},
|
||||||
|
)
|
||||||
|
self.assertDictEqual(res, {
|
||||||
|
"one_time_key_counts": {"alg1": 1, "alg2": 2}
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
|
||||||
|
)
|
||||||
|
self.fail("No error when changing string key")
|
||||||
|
except errors.SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
|
||||||
|
)
|
||||||
|
self.fail("No error when replacing dict key with string")
|
||||||
|
except errors.SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {
|
||||||
|
"one_time_keys": {"alg1:k1": {"key": "key"}}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.fail("No error when replacing string key with dict")
|
||||||
|
except errors.SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {
|
||||||
|
"one_time_keys": {
|
||||||
|
"alg2:k2": {
|
||||||
|
"key": "key3",
|
||||||
|
"signatures": {"k1": "sig1"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
self.fail("No error when replacing dict key")
|
||||||
|
except errors.SynapseError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@unittest.DEBUG
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_claim_one_time_key(self):
|
||||||
|
local_user = "@boris:" + self.hs.hostname
|
||||||
|
device_id = "xyz"
|
||||||
|
keys = {
|
||||||
|
"alg1:k1": "key1",
|
||||||
|
}
|
||||||
|
|
||||||
|
res = yield self.handler.upload_keys_for_user(
|
||||||
|
local_user, device_id, {"one_time_keys": keys},
|
||||||
|
)
|
||||||
|
self.assertDictEqual(res, {
|
||||||
|
"one_time_key_counts": {"alg1": 1}
|
||||||
|
})
|
||||||
|
|
||||||
|
res2 = yield self.handler.claim_one_time_keys({
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {
|
||||||
|
device_id: "alg1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, timeout=None)
|
||||||
|
self.assertEqual(res2, {
|
||||||
|
"failures": {},
|
||||||
|
"one_time_keys": {
|
||||||
|
local_user: {
|
||||||
|
device_id: {
|
||||||
|
"alg1:k1": "key1"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
Loading…
Reference in a new issue