Merge pull request #1861 from matrix-org/erikj/device_list_fixes

Device List fixes
This commit is contained in:
Erik Johnston 2017-01-30 17:56:19 +00:00 committed by GitHub
commit 4c9812f5da
6 changed files with 63 additions and 25 deletions

View file

@ -203,7 +203,7 @@ class DeviceHandler(BaseHandler):
hosts = set() hosts = set()
if self.hs.is_mine_id(user_id): if self.hs.is_mine_id(user_id):
for room_id in room_ids: for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id) users = yield self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users) hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name) hosts.discard(self.server_name)

View file

@ -194,7 +194,7 @@ class E2eKeysHandler(object):
# "unsigned" section # "unsigned" section
for user_id, device_keys in results.items(): for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items(): for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"]) r = dict(device_info["keys"])
r["unsigned"] = {} r["unsigned"] = {}
display_name = device_info["device_display_name"] display_name = device_info["device_display_name"]
if display_name is not None: if display_name is not None:
@ -287,10 +287,11 @@ class E2eKeysHandler(object):
device_id, user_id, time_now device_id, user_id, time_now
) )
# TODO: Sign the JSON with the server key # TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys( changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, user_id, device_id, time_now, device_keys,
encode_canonical_json(device_keys)
) )
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id]) yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)

View file

@ -164,6 +164,7 @@ class DeviceStore(SQLBaseStore):
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
}, },
desc="mark_remote_user_device_list_as_unsubscribed",
) )
def update_remote_device_list_cache_entry(self, user_id, device_id, content, def update_remote_device_list_cache_entry(self, user_id, device_id, content,
@ -463,7 +464,7 @@ class DeviceStore(SQLBaseStore):
SELECT user_id FROM device_lists_stream WHERE stream_id > ? SELECT user_id FROM device_lists_stream WHERE stream_id > ?
""" """
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row["user_id"] for row in rows)) defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key): def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the """Return a list of `(stream_id, user_id, destination)` which is the

View file

@ -14,12 +14,35 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from canonicaljson import encode_canonical_json
import ujson as json
from ._base import SQLBaseStore from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore): class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
return self._simple_upsert( """Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
new_key_json = encode_canonical_json(device_keys)
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json", table="e2e_device_keys_json",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
@ -27,10 +50,17 @@ class EndToEndKeyStore(SQLBaseStore):
}, },
values={ values={
"ts_added_ms": time_now, "ts_added_ms": time_now,
"key_json": json_bytes, "key_json": new_key_json,
} }
) )
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
@defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False): def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
@ -42,13 +72,19 @@ class EndToEndKeyStore(SQLBaseStore):
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
""" """
if not query_list: if not query_list:
return {} defer.returnValue({})
return self.runInteraction( results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices, query_list, include_all_devices,
) )
for user_id, device_keys in results.iteritems():
for device_id, device_info in device_keys.iteritems():
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = [] query_clauses = []
query_params = [] query_params = []

View file

@ -131,7 +131,7 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering) yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000) @cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id): def get_users_in_room(self, room_id):
def f(txn): def f(txn):

View file

@ -33,7 +33,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_key_without_device_name(self): def test_key_without_device_name(self):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = {"key": "value"}
yield self.store.store_device( yield self.store.store_device(
"user", "device", None "user", "device", None
@ -47,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset({
"key_json": json, "keys": json,
"device_display_name": None, "device_display_name": None,
}, dev) }, dev)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_key_with_device_name(self): def test_get_key_with_device_name(self):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = {"key": "value"}
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user", "device", now, json) "user", "device", now, json)
@ -67,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"]) self.assertIn("device", res["user"])
dev = res["user"]["device"] dev = res["user"]["device"]
self.assertDictContainsSubset({ self.assertDictContainsSubset({
"key_json": json, "keys": json,
"device_display_name": "display_name", "device_display_name": "display_name",
}, dev) }, dev)