Convert devices database to async/await. (#8069)

This commit is contained in:
Patrick Cloke 2020-08-12 10:51:42 -04:00 committed by GitHub
parent 5dd73d029e
commit 5ecc8b5825
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 220 additions and 176 deletions

1
changelog.d/8069.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List, Optional, Set, Tuple from typing import Dict, Iterable, List, Optional, Set, Tuple
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import ( from synapse.logging.opentracing import (
@ -33,14 +31,9 @@ from synapse.storage.database import (
LoggingTransaction, LoggingTransaction,
make_tuple_comparison_clause, make_tuple_comparison_clause,
) )
from synapse.types import Collection, get_verify_key_from_cross_signing_key from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import ( from synapse.util.caches.descriptors import Cache, cached, cachedList
Cache,
cached,
cachedInlineCallbacks,
cachedList,
)
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore): class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id): def get_device(self, user_id: str, device_id: str):
"""Retrieve a device. Only returns devices that are not marked as """Retrieve a device. Only returns devices that are not marked as
hidden. hidden.
Args: Args:
user_id (str): The ID of the user which owns the device user_id: The ID of the user which owns the device
device_id (str): The ID of the device to retrieve device_id: The ID of the device to retrieve
Returns: Returns:
defer.Deferred for a dict containing the device information defer.Deferred for a dict containing the device information
Raises: Raises:
@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_device", desc="get_device",
) )
@defer.inlineCallbacks async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices. Only returns devices """Retrieve all of a user's registered devices. Only returns devices
that are not marked as hidden. that are not marked as hidden.
Args: Args:
user_id (str): user_id:
Returns: Returns:
defer.Deferred: resolves to a dict from device_id to a dict A mapping from device_id to a dict containing "device_id", "user_id"
containing "device_id", "user_id" and "display_name" for each and "display_name" for each device.
device.
""" """
devices = yield self.db_pool.simple_select_list( devices = await self.db_pool.simple_select_list(
table="devices", table="devices",
keyvalues={"user_id": user_id, "hidden": False}, keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"), retcols=("user_id", "device_id", "display_name"),
@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices} return {d["device_id"]: d for d in devices}
@trace @trace
@defer.inlineCallbacks async def get_device_updates_by_remote(
def get_device_updates_by_remote(self, destination, from_stream_id, limit): self, destination: str, from_stream_id: int, limit: int
) -> Tuple[int, List[Tuple[str, dict]]]:
"""Get a stream of device updates to send to the given remote server. """Get a stream of device updates to send to the given remote server.
Args: Args:
destination (str): The host the device updates are intended for destination: The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive from_stream_id: The minimum stream_id to filter updates by, exclusive
limit (int): Maximum number of device updates to return limit: Maximum number of device updates to return
Returns: Returns:
Deferred[tuple[int, list[tuple[string,dict]]]]: A mapping from the current stream id (ie, the stream id of the last
current stream id (ie, the stream id of the last update included in the update included in the response), and the list of updates, where
response), and the list of updates, where each update is a pair of EDU each update is a pair of EDU type and EDU contents.
type and EDU contents
""" """
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed: if not has_changed:
return now_stream_id, [] return now_stream_id, []
updates = yield self.db_pool.runInteraction( updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote", "get_device_updates_by_remote",
self._get_device_updates_by_remote_txn, self._get_device_updates_by_remote_txn,
destination, destination,
@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
master_key_by_user = {} master_key_by_user = {}
self_signing_key_by_user = {} self_signing_key_by_user = {}
for user in users: for user in users:
cross_signing_key = yield defer.ensureDeferred( cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
self.get_e2e_cross_signing_key(user, "master")
)
if cross_signing_key: if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key( key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key cross_signing_key
@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version, "device_id": verify_key.version,
} }
cross_signing_key = yield defer.ensureDeferred( cross_signing_key = await self.get_e2e_cross_signing_key(
self.get_e2e_cross_signing_key(user, "self_signing") user, "self_signing"
) )
if cross_signing_key: if cross_signing_key:
key_id, verify_key = get_verify_key_from_cross_signing_key( key_id, verify_key = get_verify_key_from_cross_signing_key(
@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id: if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context) query_map[key] = (update_stream_id, update_context)
results = yield self._get_device_update_edus_by_remote( results = await self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map destination, from_stream_id, query_map
) )
@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, results return now_stream_id, results
def _get_device_updates_by_remote_txn( def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit self,
txn: LoggingTransaction,
destination: str,
from_stream_id: int,
now_stream_id: int,
limit: int,
): ):
"""Return device update information for a given remote destination """Return device update information for a given remote destination
Args: Args:
txn (LoggingTransaction): The transaction to execute txn: The transaction to execute
destination (str): The host the device updates are intended for destination: The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive from_stream_id: The minimum stream_id to filter updates by, exclusive
now_stream_id (int): The maximum stream_id to filter updates by, inclusive now_stream_id: The maximum stream_id to filter updates by, inclusive
limit (int): Maximum number of device updates to return limit: Maximum number of device updates to return
Returns: Returns:
List: List of device updates List: List of device updates
@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn) return list(txn)
@defer.inlineCallbacks async def _get_device_update_edus_by_remote(
def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map): self,
destination: str,
from_stream_id: int,
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
) -> List[Tuple[str, dict]]:
"""Returns a list of device update EDUs as well as E2EE keys """Returns a list of device update EDUs as well as E2EE keys
Args: Args:
destination (str): The host the device updates are intended for destination: The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive from_stream_id: The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevant json-encoded user_id/device_id to update stream_id and the relevant json-encoded
opentracing context opentracing context
Returns: Returns:
List[Dict]: List of objects representing an device update EDU List of objects representing an device update EDU
""" """
devices = ( devices = (
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_get_e2e_device_keys_txn", "_get_e2e_device_keys_txn",
self._get_e2e_device_keys_txn, self._get_e2e_device_keys_txn,
query_map.keys(), query_map.keys(),
@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
for user_id, user_devices in devices.items(): for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before # The prev_id for the first row is always the last row before
# `from_stream_id` # `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user( prev_id = await self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id destination, user_id, from_stream_id
) )
@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results return results
def _get_last_device_update_for_remote_user( def _get_last_device_update_for_remote_user(
self, destination, user_id, from_stream_id self, destination: str, user_id: str, from_stream_id: int
): ):
def f(txn): def f(txn):
prev_sent_id_sql = """ prev_sent_id_sql = """
@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id): def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
"""Mark that updates have successfully been sent to the destination. """Mark that updates have successfully been sent to the destination.
""" """
return self.db_pool.runInteraction( return self.db_pool.runInteraction(
@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
stream_id, stream_id,
) )
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): def _mark_as_sent_devices_by_remote_txn(
self, txn: LoggingTransaction, destination: str, stream_id: int
) -> None:
# We update the device_lists_outbound_last_success with the successfully # We update the device_lists_outbound_last_success with the successfully
# poked users. # poked users.
sql = """ sql = """
@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
""" """
txn.execute(sql, (destination, stream_id)) txn.execute(sql, (destination, stream_id))
@defer.inlineCallbacks async def add_user_signature_change_to_streams(
def add_user_signature_change_to_streams(self, from_user_id, user_ids): self, from_user_id: str, user_ids: List[str]
) -> int:
"""Persist that a user has made new signatures """Persist that a user has made new signatures
Args: Args:
from_user_id (str): the user who made the signatures from_user_id: the user who made the signatures
user_ids (list[str]): the users who were signed user_ids: the users who were signed
Returns:
THe new stream ID.
""" """
with self._device_list_id_gen.get_next() as stream_id: with self._device_list_id_gen.get_next() as stream_id:
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_user_sig_change_to_streams", "add_user_sig_change_to_streams",
self._add_user_signature_change_txn, self._add_user_signature_change_txn,
from_user_id, from_user_id,
@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
) )
return stream_id return stream_id
def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): def _add_user_signature_change_txn(
self,
txn: LoggingTransaction,
from_user_id: str,
user_ids: List[str],
stream_id: int,
) -> None:
txn.call_after( txn.call_after(
self._user_signature_stream_cache.entity_has_changed, self._user_signature_stream_cache.entity_has_changed,
from_user_id, from_user_id,
@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
}, },
) )
def get_device_stream_token(self): def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token() return self._device_list_id_gen.get_current_token()
@trace @trace
@defer.inlineCallbacks async def get_user_devices_from_cache(
def get_user_devices_from_cache(self, query_list): self, query_list: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache. """Get the devices (and keys if any) for remote users from the cache.
Args: Args:
query_list(list): List of (user_id, device_ids), if device_ids is query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user. falsey then return all device ids for that user.
Returns: Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is A tuple of (user_ids_not_in_cache, results_map), where
a set of user_ids and results_map is a mapping of user_ids_not_in_cache is a set of user_ids and results_map is a
user_id -> device_id -> device_info mapping of user_id -> device_id -> device_info.
""" """
user_ids = {user_id for user_id, _ in query_list} user_ids = {user_id for user_id, _ in query_list}
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
# We go and check if any of the users need to have their device lists # We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list. # resynced. If they do then we remove them from the cached list.
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync( users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids user_ids
) )
user_ids_in_cache = { user_ids_in_cache = {
@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
continue continue
if device_id: if device_id:
device = yield self._get_cached_user_device(user_id, device_id) device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device results.setdefault(user_id, {})[device_id] = device
else: else:
results[user_id] = yield self.get_cached_devices_for_user(user_id) results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results) set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache) set_tag("not_in_cache", user_ids_not_in_cache)
return user_ids_not_in_cache, results return user_ids_not_in_cache, results
@cachedInlineCallbacks(num_args=2, tree=True) @cached(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id): async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
content = yield self.db_pool.simple_select_one_onecol( content = await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content", retcol="content",
@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
) )
return db_to_json(content) return db_to_json(content)
@cachedInlineCallbacks() @cached()
def get_cached_devices_for_user(self, user_id): async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
devices = yield self.db_pool.simple_select_list( devices = await self.db_pool.simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("device_id", "content"), retcols=("device_id", "content"),
@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices device["device_id"]: db_to_json(device["content"]) for device in devices
} }
def get_devices_with_keys_by_user(self, user_id): def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user """Get all devices (with any device keys) for a user
Returns: Returns:
(stream_id, devices) Deferred which resolves to (stream_id, devices)
""" """
return self.db_pool.runInteraction( return self.db_pool.runInteraction(
"get_devices_with_keys_by_user", "get_devices_with_keys_by_user",
@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
user_id, user_id,
) )
def _get_devices_with_keys_by_user_txn(self, txn, user_id): def _get_devices_with_keys_by_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token() now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn( devices = self._get_e2e_device_keys_txn(
@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, [] return now_stream_id, []
def get_users_whose_devices_changed(self, from_key, user_ids): async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that """Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids. are in the given list of user_ids.
Args: Args:
from_key (str): The device lists stream token from_key: The device lists stream token
user_ids (Iterable[str]) user_ids: The user IDs to query for devices.
Returns: Returns:
Deferred[set[str]]: The set of user_ids whose devices have changed The set of user_ids whose devices have changed since `from_key`
since `from_key`
""" """
from_key = int(from_key) from_key = int(from_key)
@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
if not to_check: if not to_check:
return defer.succeed(set()) return set()
def _get_users_whose_devices_changed_txn(txn): def _get_users_whose_devices_changed_txn(txn):
changes = set() changes = set()
@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
return changes return changes
return self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
) )
@defer.inlineCallbacks async def get_users_whose_signatures_changed(
def get_users_whose_signatures_changed(self, user_id, from_key): self, user_id: str, from_key: str
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since """Get the users who have new cross-signing signatures made by `user_id` since
`from_key`. `from_key`.
Args: Args:
user_id (str): the user who made the signatures user_id: the user who made the signatures
from_key (str): The device lists stream token from_key: The device lists stream token
Returns:
A set of user IDs with updated signatures.
""" """
from_key = int(from_key) from_key = int(from_key)
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ? WHERE from_user_id = ? AND stream_id > ?
""" """
rows = yield self.db_pool.execute( rows = await self.db_pool.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key "get_users_whose_signatures_changed", None, sql, user_id, from_key
) )
return {user for row in rows for user in db_to_json(row[0])} return {user for row in rows for user in db_to_json(row[0])}
@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
@cached(max_entries=10000) @cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id): def get_device_list_last_stream_id_for_remote(self, user_id: str):
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
""" """
@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
list_name="user_ids", list_name="user_ids",
inlineCallbacks=True, inlineCallbacks=True,
) )
def get_device_list_last_stream_id_for_remotes(self, user_ids): def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch( rows = yield self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
column="user_id", column="user_id",
@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
return results return results
@defer.inlineCallbacks async def get_user_ids_requiring_device_list_resync(
def get_user_ids_requiring_device_list_resync(
self, user_ids: Optional[Collection[str]] = None, self, user_ids: Optional[Collection[str]] = None,
) -> Set[str]: ) -> Set[str]:
"""Given a list of remote users return the list of users that we """Given a list of remote users return the list of users that we
@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
The IDs of users whose device lists need resync. The IDs of users whose device lists need resync.
""" """
if user_ids: if user_ids:
rows = yield self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_resync", table="device_lists_remote_resync",
column="user_id", column="user_id",
iterable=user_ids, iterable=user_ids,
@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_user_ids_requiring_device_list_resync_with_iterable", desc="get_user_ids_requiring_device_list_resync_with_iterable",
) )
else: else:
rows = yield self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="device_lists_remote_resync", table="device_lists_remote_resync",
keyvalues=None, keyvalues=None,
retcols=("user_id",), retcols=("user_id",),
@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale", desc="make_remote_user_device_cache_as_stale",
) )
def mark_remote_user_device_list_as_unsubscribed(self, user_id): def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
"""Mark that we no longer track device lists for remote user. """Mark that we no longer track device lists for remote user.
""" """
@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
"drop_device_lists_outbound_last_success_non_unique_idx", "drop_device_lists_outbound_last_success_non_unique_idx",
) )
@defer.inlineCallbacks async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn): def f(conn):
txn = conn.cursor() txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id") txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close() txn.close()
yield self.db_pool.runWithConnection(f) await self.db_pool.runWithConnection(f)
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
) )
return 1 return 1
@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000) self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
@defer.inlineCallbacks async def store_device(
def store_device(self, user_id, device_id, initial_device_display_name): self, user_id: str, device_id: str, initial_device_display_name: str
) -> bool:
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
Args: Args:
user_id (str): id of user associated with the device user_id: id of user associated with the device
device_id (str): id of device device_id: id of device
initial_device_display_name (str): initial displayname of the initial_device_display_name: initial displayname of the device.
device. Ignored if device exists. Ignored if device exists.
Returns: Returns:
defer.Deferred: boolean whether the device was inserted or an Whether the device was inserted or an existing device existed with that ID.
existing device existed with that ID.
Raises: Raises:
StoreError: if the device is already in use StoreError: if the device is already in use
""" """
@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False return False
try: try:
inserted = yield self.db_pool.simple_insert( inserted = await self.db_pool.simple_insert(
"devices", "devices",
values={ values={
"user_id": user_id, "user_id": user_id,
@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted: if not inserted:
# if the device already exists, check if it's a real device, or # if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else # if the device ID is reserved by something else
hidden = yield self.db_pool.simple_select_one_onecol( hidden = await self.db_pool.simple_select_one_onecol(
"devices", "devices",
keyvalues={"user_id": user_id, "device_id": device_id}, keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden", retcol="hidden",
@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
raise StoreError(500, "Problem storing device.") raise StoreError(500, "Problem storing device.")
@defer.inlineCallbacks async def delete_device(self, user_id: str, device_id: str) -> None:
def delete_device(self, user_id, device_id):
"""Delete a device. """Delete a device.
Args: Args:
user_id (str): The ID of the user which owns the device user_id: The ID of the user which owns the device
device_id (str): The ID of the device to delete device_id: The ID of the device to delete
Returns:
defer.Deferred
""" """
yield self.db_pool.simple_delete_one( await self.db_pool.simple_delete_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device", desc="delete_device",
@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.device_id_exists_cache.invalidate((user_id, device_id)) self.device_id_exists_cache.invalidate((user_id, device_id))
@defer.inlineCallbacks async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
def delete_devices(self, user_id, device_ids):
"""Deletes several devices. """Deletes several devices.
Args: Args:
user_id (str): The ID of the user which owns the devices user_id: The ID of the user which owns the devices
device_ids (list): The IDs of the devices to delete device_ids: The IDs of the devices to delete
Returns:
defer.Deferred
""" """
yield self.db_pool.simple_delete_many( await self.db_pool.simple_delete_many(
table="devices", table="devices",
column="device_id", column="device_id",
iterable=device_ids, iterable=device_ids,
@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
for device_id in device_ids: for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id)) self.device_id_exists_cache.invalidate((user_id, device_id))
def update_device(self, user_id, device_id, new_display_name=None): async def update_device(
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
) -> None:
"""Update a device. Only updates the device if it is not marked as """Update a device. Only updates the device if it is not marked as
hidden. hidden.
Args: Args:
user_id (str): The ID of the user which owns the device user_id: The ID of the user which owns the device
device_id (str): The ID of the device to update device_id: The ID of the device to update
new_display_name (str|None): new displayname for device; None new_display_name: new displayname for device; None to leave unchanged
to leave unchanged
Raises: Raises:
StoreError: if the device is not found StoreError: if the device is not found
Returns:
defer.Deferred
""" """
updates = {} updates = {}
if new_display_name is not None: if new_display_name is not None:
updates["display_name"] = new_display_name updates["display_name"] = new_display_name
if not updates: if not updates:
return defer.succeed(None) return None
return self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="devices", table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates, updatevalues=updates,
@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
def update_remote_device_list_cache_entry( def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id self, user_id: str, device_id: str, content: JsonDict, stream_id: int
): ):
"""Updates a single device in the cache of a remote user's devicelist. """Updates a single device in the cache of a remote user's devicelist.
@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
device list. device list.
Args: Args:
user_id (str): User to update device list for user_id: User to update device list for
device_id (str): ID of decivice being updated device_id: ID of decivice being updated
content (dict): new data on this device content: new data on this device
stream_id (int): the version of the device list stream_id: the version of the device list
Returns: Returns:
Deferred[None] Deferred[None]
@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
def _update_remote_device_list_cache_entry_txn( def _update_remote_device_list_cache_entry_txn(
self, txn, user_id, device_id, content, stream_id self,
): txn: LoggingTransaction,
user_id: str,
device_id: str,
content: JsonDict,
stream_id: int,
) -> None:
if content.get("deleted"): if content.get("deleted"):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
lock=False, lock=False,
) )
def update_remote_device_list_cache(self, user_id, devices, stream_id): def update_remote_device_list_cache(
self, user_id: str, devices: List[dict], stream_id: int
):
"""Replace the entire cache of the remote user's devices. """Replace the entire cache of the remote user's devices.
Note: assumes that we are the only thread that can be updating this user's Note: assumes that we are the only thread that can be updating this user's
device list. device list.
Args: Args:
user_id (str): User to update device list for user_id: User to update device list for
devices (list[dict]): list of device objects supplied over federation devices: list of device objects supplied over federation
stream_id (int): the version of the device list stream_id: the version of the device list
Returns: Returns:
Deferred[None] Deferred[None]
@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_id, stream_id,
) )
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id): def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
):
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
) )
@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id}, txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
) )
@defer.inlineCallbacks async def add_device_change_to_streams(
def add_device_change_to_streams(self, user_id, device_ids, hosts): self, user_id: str, device_ids: Collection[str], hosts: List[str]
):
"""Persist that a user's devices have been updated, and which hosts """Persist that a user's devices have been updated, and which hosts
(if any) should be poked. (if any) should be poked.
""" """
@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids: with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_device_change_to_stream", "add_device_change_to_stream",
self._add_device_change_to_stream_txn, self._add_device_change_to_stream_txn,
user_id, user_id,
@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
with self._device_list_id_gen.get_next_mult( with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids) len(hosts) * len(device_ids)
) as stream_ids: ) as stream_ids:
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"add_device_outbound_poke_to_stream", "add_device_outbound_poke_to_stream",
self._add_device_outbound_poke_to_stream_txn, self._add_device_outbound_poke_to_stream_txn,
user_id, user_id,
@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
def _add_device_outbound_poke_to_stream_txn( def _add_device_outbound_poke_to_stream_txn(
self, txn, user_id, device_ids, hosts, stream_ids, context, self,
txn: LoggingTransaction,
user_id: str,
device_ids: Collection[str],
hosts: List[str],
stream_ids: List[str],
context: Dict[str, str],
): ):
for host in hosts: for host in hosts:
txn.call_after( txn.call_after(
@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
], ],
) )
def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000): def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure """Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. that we don't fill up due to dead servers.

View file

@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
retry_timings_res retry_timings_res
) )
self.datastore.get_device_updates_by_remote.return_value = defer.succeed( self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
(0, []) (0, [])
) )

View file

@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_store_new_device(self): def test_store_new_device(self):
yield self.store.store_device("user_id", "device_id", "display_name") yield defer.ensureDeferred(
self.store.store_device("user_id", "device_id", "display_name")
)
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertDictContainsSubset( self.assertDictContainsSubset(
@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_devices_by_user(self): def test_get_devices_by_user(self):
yield self.store.store_device("user_id", "device1", "display_name 1") yield defer.ensureDeferred(
yield self.store.store_device("user_id", "device2", "display_name 2") self.store.store_device("user_id", "device1", "display_name 1")
yield self.store.store_device("user_id2", "device3", "display_name 3") )
yield defer.ensureDeferred(
self.store.store_device("user_id", "device2", "display_name 2")
)
yield defer.ensureDeferred(
self.store.store_device("user_id2", "device3", "display_name 3")
)
res = yield self.store.get_devices_by_user("user_id") res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
self.assertEqual(2, len(res.keys())) self.assertEqual(2, len(res.keys()))
self.assertDictContainsSubset( self.assertDictContainsSubset(
{ {
@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
device_ids = ["device_id1", "device_id2"] device_ids = ["device_id1", "device_id2"]
# Add two device updates with a single stream_id # Add two device updates with a single stream_id
yield self.store.add_device_change_to_streams( yield defer.ensureDeferred(
"user_id", device_ids, ["somehost"] self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
) )
# Get all device updates ever meant for this remote # Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote( now_stream_id, device_updates = yield defer.ensureDeferred(
"somehost", -1, limit=100 self.store.get_device_updates_by_remote("somehost", -1, limit=100)
) )
# Check original device_ids are contained within these updates # Check original device_ids are contained within these updates
@ -99,19 +107,23 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_device(self): def test_update_device(self):
yield self.store.store_device("user_id", "device_id", "display_name 1") yield defer.ensureDeferred(
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do a no-op first # do a no-op first
yield self.store.update_device("user_id", "device_id") yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id") res = yield self.store.get_device("user_id", "device_id")
self.assertEqual("display_name 1", res["display_name"]) self.assertEqual("display_name 1", res["display_name"])
# do the update # do the update
yield self.store.update_device( yield defer.ensureDeferred(
"user_id", "device_id", new_display_name="display_name 2" self.store.update_device(
"user_id", "device_id", new_display_name="display_name 2"
)
) )
# check it worked # check it worked
@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_update_unknown_device(self): def test_update_unknown_device(self):
with self.assertRaises(synapse.api.errors.StoreError) as cm: with self.assertRaises(synapse.api.errors.StoreError) as cm:
yield self.store.update_device( yield defer.ensureDeferred(
"user_id", "unknown_device_id", new_display_name="display_name 2" self.store.update_device(
"user_id", "unknown_device_id", new_display_name="display_name 2"
)
) )
self.assertEqual(404, cm.exception.code) self.assertEqual(404, cm.exception.code)

View file

@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield self.store.store_device("user", "device", None) yield defer.ensureDeferred(self.store.store_device("user", "device", None))
yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.set_e2e_device_keys("user", "device", now, json)
@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = {"key": "value"} json = {"key": "value"}
yield self.store.store_device("user", "device", None) yield defer.ensureDeferred(self.store.store_device("user", "device", None))
changed = yield self.store.set_e2e_device_keys("user", "device", now, json) changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
self.assertTrue(changed) self.assertTrue(changed)
@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
json = {"key": "value"} json = {"key": "value"}
yield self.store.set_e2e_device_keys("user", "device", now, json) yield self.store.set_e2e_device_keys("user", "device", now, json)
yield self.store.store_device("user", "device", "display_name") yield defer.ensureDeferred(
self.store.store_device("user", "device", "display_name")
)
res = yield defer.ensureDeferred( res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user", "device"),)) self.store.get_e2e_device_keys((("user", "device"),))
@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield self.store.store_device("user1", "device1", None) yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
yield self.store.store_device("user1", "device2", None) yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
yield self.store.store_device("user2", "device1", None) yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
yield self.store.store_device("user2", "device2", None) yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"}) yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"}) yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})