forked from MirrorHub/synapse
Convert devices database to async/await. (#8069)
This commit is contained in:
parent
5dd73d029e
commit
5ecc8b5825
5 changed files with 220 additions and 176 deletions
1
changelog.d/8069.misc
Normal file
1
changelog.d/8069.misc
Normal file
|
@ -0,0 +1 @@
|
|||
Convert various parts of the codebase to async/await.
|
|
@ -15,9 +15,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List, Optional, Set, Tuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
from synapse.logging.opentracing import (
|
||||
|
@ -33,14 +31,9 @@ from synapse.storage.database import (
|
|||
LoggingTransaction,
|
||||
make_tuple_comparison_clause,
|
||||
)
|
||||
from synapse.types import Collection, get_verify_key_from_cross_signing_key
|
||||
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import (
|
||||
Cache,
|
||||
cached,
|
||||
cachedInlineCallbacks,
|
||||
cachedList,
|
||||
)
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.stringutils import shortstr
|
||||
|
||||
|
@ -54,13 +47,13 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
|||
|
||||
|
||||
class DeviceWorkerStore(SQLBaseStore):
|
||||
def get_device(self, user_id, device_id):
|
||||
def get_device(self, user_id: str, device_id: str):
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to retrieve
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to retrieve
|
||||
Returns:
|
||||
defer.Deferred for a dict containing the device information
|
||||
Raises:
|
||||
|
@ -73,19 +66,17 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
desc="get_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
async def get_devices_by_user(self, user_id: str) -> Dict[str, Dict[str, str]]:
|
||||
"""Retrieve all of a user's registered devices. Only returns devices
|
||||
that are not marked as hidden.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
user_id:
|
||||
Returns:
|
||||
defer.Deferred: resolves to a dict from device_id to a dict
|
||||
containing "device_id", "user_id" and "display_name" for each
|
||||
device.
|
||||
A mapping from device_id to a dict containing "device_id", "user_id"
|
||||
and "display_name" for each device.
|
||||
"""
|
||||
devices = yield self.db_pool.simple_select_list(
|
||||
devices = await self.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
|
@ -95,19 +86,20 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return {d["device_id"]: d for d in devices}
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
|
||||
async def get_device_updates_by_remote(
|
||||
self, destination: str, from_stream_id: int, limit: int
|
||||
) -> Tuple[int, List[Tuple[str, dict]]]:
|
||||
"""Get a stream of device updates to send to the given remote server.
|
||||
|
||||
Args:
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
limit (int): Maximum number of device updates to return
|
||||
destination: The host the device updates are intended for
|
||||
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||
limit: Maximum number of device updates to return
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[int, list[tuple[string,dict]]]]:
|
||||
current stream id (ie, the stream id of the last update included in the
|
||||
response), and the list of updates, where each update is a pair of EDU
|
||||
type and EDU contents
|
||||
A mapping from the current stream id (ie, the stream id of the last
|
||||
update included in the response), and the list of updates, where
|
||||
each update is a pair of EDU type and EDU contents.
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
|
@ -117,7 +109,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
if not has_changed:
|
||||
return now_stream_id, []
|
||||
|
||||
updates = yield self.db_pool.runInteraction(
|
||||
updates = await self.db_pool.runInteraction(
|
||||
"get_device_updates_by_remote",
|
||||
self._get_device_updates_by_remote_txn,
|
||||
destination,
|
||||
|
@ -136,9 +128,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
master_key_by_user = {}
|
||||
self_signing_key_by_user = {}
|
||||
for user in users:
|
||||
cross_signing_key = yield defer.ensureDeferred(
|
||||
self.get_e2e_cross_signing_key(user, "master")
|
||||
)
|
||||
cross_signing_key = await self.get_e2e_cross_signing_key(user, "master")
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
|
@ -151,8 +141,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
cross_signing_key = yield defer.ensureDeferred(
|
||||
self.get_e2e_cross_signing_key(user, "self_signing")
|
||||
cross_signing_key = await self.get_e2e_cross_signing_key(
|
||||
user, "self_signing"
|
||||
)
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
|
@ -202,7 +192,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
if update_stream_id > previous_update_stream_id:
|
||||
query_map[key] = (update_stream_id, update_context)
|
||||
|
||||
results = yield self._get_device_update_edus_by_remote(
|
||||
results = await self._get_device_update_edus_by_remote(
|
||||
destination, from_stream_id, query_map
|
||||
)
|
||||
|
||||
|
@ -215,16 +205,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return now_stream_id, results
|
||||
|
||||
def _get_device_updates_by_remote_txn(
|
||||
self, txn, destination, from_stream_id, now_stream_id, limit
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
destination: str,
|
||||
from_stream_id: int,
|
||||
now_stream_id: int,
|
||||
limit: int,
|
||||
):
|
||||
"""Return device update information for a given remote destination
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to execute
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
now_stream_id (int): The maximum stream_id to filter updates by, inclusive
|
||||
limit (int): Maximum number of device updates to return
|
||||
txn: The transaction to execute
|
||||
destination: The host the device updates are intended for
|
||||
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||
now_stream_id: The maximum stream_id to filter updates by, inclusive
|
||||
limit: Maximum number of device updates to return
|
||||
|
||||
Returns:
|
||||
List: List of device updates
|
||||
|
@ -240,23 +235,26 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return list(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
|
||||
async def _get_device_update_edus_by_remote(
|
||||
self,
|
||||
destination: str,
|
||||
from_stream_id: int,
|
||||
query_map: Dict[Tuple[str, str], Tuple[int, Optional[str]]],
|
||||
) -> List[Tuple[str, dict]]:
|
||||
"""Returns a list of device update EDUs as well as E2EE keys
|
||||
|
||||
Args:
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
destination: The host the device updates are intended for
|
||||
from_stream_id: The minimum stream_id to filter updates by, exclusive
|
||||
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
||||
user_id/device_id to update stream_id and the relevant json-encoded
|
||||
opentracing context
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of objects representing an device update EDU
|
||||
|
||||
List of objects representing an device update EDU
|
||||
"""
|
||||
devices = (
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_get_e2e_device_keys_txn",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_map.keys(),
|
||||
|
@ -271,7 +269,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
for user_id, user_devices in devices.items():
|
||||
# The prev_id for the first row is always the last row before
|
||||
# `from_stream_id`
|
||||
prev_id = yield self._get_last_device_update_for_remote_user(
|
||||
prev_id = await self._get_last_device_update_for_remote_user(
|
||||
destination, user_id, from_stream_id
|
||||
)
|
||||
|
||||
|
@ -315,7 +313,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
return results
|
||||
|
||||
def _get_last_device_update_for_remote_user(
|
||||
self, destination, user_id, from_stream_id
|
||||
self, destination: str, user_id: str, from_stream_id: int
|
||||
):
|
||||
def f(txn):
|
||||
prev_sent_id_sql = """
|
||||
|
@ -329,7 +327,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
|
@ -339,7 +337,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
stream_id,
|
||||
)
|
||||
|
||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||
def _mark_as_sent_devices_by_remote_txn(
|
||||
self, txn: LoggingTransaction, destination: str, stream_id: int
|
||||
) -> None:
|
||||
# We update the device_lists_outbound_last_success with the successfully
|
||||
# poked users.
|
||||
sql = """
|
||||
|
@ -367,17 +367,21 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
"""
|
||||
txn.execute(sql, (destination, stream_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_user_signature_change_to_streams(self, from_user_id, user_ids):
|
||||
async def add_user_signature_change_to_streams(
|
||||
self, from_user_id: str, user_ids: List[str]
|
||||
) -> int:
|
||||
"""Persist that a user has made new signatures
|
||||
|
||||
Args:
|
||||
from_user_id (str): the user who made the signatures
|
||||
user_ids (list[str]): the users who were signed
|
||||
from_user_id: the user who made the signatures
|
||||
user_ids: the users who were signed
|
||||
|
||||
Returns:
|
||||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
from_user_id,
|
||||
|
@ -386,7 +390,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
)
|
||||
return stream_id
|
||||
|
||||
def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
|
||||
def _add_user_signature_change_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
from_user_id: str,
|
||||
user_ids: List[str],
|
||||
stream_id: int,
|
||||
) -> None:
|
||||
txn.call_after(
|
||||
self._user_signature_stream_cache.entity_has_changed,
|
||||
from_user_id,
|
||||
|
@ -402,29 +412,30 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
},
|
||||
)
|
||||
|
||||
def get_device_stream_token(self):
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_user_devices_from_cache(self, query_list):
|
||||
async def get_user_devices_from_cache(
|
||||
self, query_list: List[Tuple[str, str]]
|
||||
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get the devices (and keys if any) for remote users from the cache.
|
||||
|
||||
Args:
|
||||
query_list(list): List of (user_id, device_ids), if device_ids is
|
||||
query_list: List of (user_id, device_ids), if device_ids is
|
||||
falsey then return all device ids for that user.
|
||||
|
||||
Returns:
|
||||
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
|
||||
a set of user_ids and results_map is a mapping of
|
||||
user_id -> device_id -> device_info
|
||||
A tuple of (user_ids_not_in_cache, results_map), where
|
||||
user_ids_not_in_cache is a set of user_ids and results_map is a
|
||||
mapping of user_id -> device_id -> device_info.
|
||||
"""
|
||||
user_ids = {user_id for user_id, _ in query_list}
|
||||
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||
|
||||
# We go and check if any of the users need to have their device lists
|
||||
# resynced. If they do then we remove them from the cached list.
|
||||
users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
|
||||
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
|
||||
user_ids
|
||||
)
|
||||
user_ids_in_cache = {
|
||||
|
@ -438,19 +449,19 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
continue
|
||||
|
||||
if device_id:
|
||||
device = yield self._get_cached_user_device(user_id, device_id)
|
||||
device = await self._get_cached_user_device(user_id, device_id)
|
||||
results.setdefault(user_id, {})[device_id] = device
|
||||
else:
|
||||
results[user_id] = yield self.get_cached_devices_for_user(user_id)
|
||||
results[user_id] = await self.get_cached_devices_for_user(user_id)
|
||||
|
||||
set_tag("in_cache", results)
|
||||
set_tag("not_in_cache", user_ids_not_in_cache)
|
||||
|
||||
return user_ids_not_in_cache, results
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, tree=True)
|
||||
def _get_cached_user_device(self, user_id, device_id):
|
||||
content = yield self.db_pool.simple_select_one_onecol(
|
||||
@cached(num_args=2, tree=True)
|
||||
async def _get_cached_user_device(self, user_id: str, device_id: str) -> JsonDict:
|
||||
content = await self.db_pool.simple_select_one_onecol(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcol="content",
|
||||
|
@ -458,9 +469,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
)
|
||||
return db_to_json(content)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def get_cached_devices_for_user(self, user_id):
|
||||
devices = yield self.db_pool.simple_select_list(
|
||||
@cached()
|
||||
async def get_cached_devices_for_user(self, user_id: str) -> Dict[str, JsonDict]:
|
||||
devices = await self.db_pool.simple_select_list(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("device_id", "content"),
|
||||
|
@ -470,11 +481,11 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||
}
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id):
|
||||
def get_devices_with_keys_by_user(self, user_id: str):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
(stream_id, devices)
|
||||
Deferred which resolves to (stream_id, devices)
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
|
@ -482,7 +493,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||
def _get_devices_with_keys_by_user_txn(
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
|
@ -515,17 +528,18 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return now_stream_id, []
|
||||
|
||||
def get_users_whose_devices_changed(self, from_key, user_ids):
|
||||
async def get_users_whose_devices_changed(
|
||||
self, from_key: str, user_ids: Iterable[str]
|
||||
) -> Set[str]:
|
||||
"""Get set of users whose devices have changed since `from_key` that
|
||||
are in the given list of user_ids.
|
||||
|
||||
Args:
|
||||
from_key (str): The device lists stream token
|
||||
user_ids (Iterable[str])
|
||||
from_key: The device lists stream token
|
||||
user_ids: The user IDs to query for devices.
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: The set of user_ids whose devices have changed
|
||||
since `from_key`
|
||||
The set of user_ids whose devices have changed since `from_key`
|
||||
"""
|
||||
from_key = int(from_key)
|
||||
|
||||
|
@ -536,7 +550,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
if not to_check:
|
||||
return defer.succeed(set())
|
||||
return set()
|
||||
|
||||
def _get_users_whose_devices_changed_txn(txn):
|
||||
changes = set()
|
||||
|
@ -556,18 +570,22 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return changes
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_whose_signatures_changed(self, user_id, from_key):
|
||||
async def get_users_whose_signatures_changed(
|
||||
self, user_id: str, from_key: str
|
||||
) -> Set[str]:
|
||||
"""Get the users who have new cross-signing signatures made by `user_id` since
|
||||
`from_key`.
|
||||
|
||||
Args:
|
||||
user_id (str): the user who made the signatures
|
||||
from_key (str): The device lists stream token
|
||||
user_id: the user who made the signatures
|
||||
from_key: The device lists stream token
|
||||
|
||||
Returns:
|
||||
A set of user IDs with updated signatures.
|
||||
"""
|
||||
from_key = int(from_key)
|
||||
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
|
||||
|
@ -575,7 +593,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
SELECT DISTINCT user_ids FROM user_signature_stream
|
||||
WHERE from_user_id = ? AND stream_id > ?
|
||||
"""
|
||||
rows = yield self.db_pool.execute(
|
||||
rows = await self.db_pool.execute(
|
||||
"get_users_whose_signatures_changed", None, sql, user_id, from_key
|
||||
)
|
||||
return {user for row in rows for user in db_to_json(row[0])}
|
||||
|
@ -638,7 +656,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def get_device_list_last_stream_id_for_remote(self, user_id):
|
||||
def get_device_list_last_stream_id_for_remote(self, user_id: str):
|
||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||
got any information for them.
|
||||
"""
|
||||
|
@ -655,7 +673,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
list_name="user_ids",
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
|
@ -669,8 +687,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
|
||||
return results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_ids_requiring_device_list_resync(
|
||||
async def get_user_ids_requiring_device_list_resync(
|
||||
self, user_ids: Optional[Collection[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""Given a list of remote users return the list of users that we
|
||||
|
@ -681,7 +698,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
The IDs of users whose device lists need resync.
|
||||
"""
|
||||
if user_ids:
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_resync",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
|
@ -689,7 +706,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
desc="get_user_ids_requiring_device_list_resync_with_iterable",
|
||||
)
|
||||
else:
|
||||
rows = yield self.db_pool.simple_select_list(
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="device_lists_remote_resync",
|
||||
keyvalues=None,
|
||||
retcols=("user_id",),
|
||||
|
@ -710,7 +727,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
|||
desc="make_remote_user_device_cache_as_stale",
|
||||
)
|
||||
|
||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
|
||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
|
||||
"""Mark that we no longer track device lists for remote user.
|
||||
"""
|
||||
|
||||
|
@ -779,16 +796,15 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
|||
"drop_device_lists_outbound_last_success_non_unique_idx",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
|
||||
txn.close()
|
||||
|
||||
yield self.db_pool.runWithConnection(f)
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.runWithConnection(f)
|
||||
await self.db_pool.updates._end_background_update(
|
||||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
|
||||
)
|
||||
return 1
|
||||
|
@ -868,18 +884,20 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
|
||||
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_device(self, user_id, device_id, initial_device_display_name):
|
||||
async def store_device(
|
||||
self, user_id: str, device_id: str, initial_device_display_name: str
|
||||
) -> bool:
|
||||
"""Ensure the given device is known; add it to the store if not
|
||||
|
||||
Args:
|
||||
user_id (str): id of user associated with the device
|
||||
device_id (str): id of device
|
||||
initial_device_display_name (str): initial displayname of the
|
||||
device. Ignored if device exists.
|
||||
user_id: id of user associated with the device
|
||||
device_id: id of device
|
||||
initial_device_display_name: initial displayname of the device.
|
||||
Ignored if device exists.
|
||||
|
||||
Returns:
|
||||
defer.Deferred: boolean whether the device was inserted or an
|
||||
existing device existed with that ID.
|
||||
Whether the device was inserted or an existing device existed with that ID.
|
||||
|
||||
Raises:
|
||||
StoreError: if the device is already in use
|
||||
"""
|
||||
|
@ -888,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
return False
|
||||
|
||||
try:
|
||||
inserted = yield self.db_pool.simple_insert(
|
||||
inserted = await self.db_pool.simple_insert(
|
||||
"devices",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
|
@ -902,7 +920,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
if not inserted:
|
||||
# if the device already exists, check if it's a real device, or
|
||||
# if the device ID is reserved by something else
|
||||
hidden = yield self.db_pool.simple_select_one_onecol(
|
||||
hidden = await self.db_pool.simple_select_one_onecol(
|
||||
"devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcol="hidden",
|
||||
|
@ -927,17 +945,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
raise StoreError(500, "Problem storing device.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
async def delete_device(self, user_id: str, device_id: str) -> None:
|
||||
"""Delete a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to delete
|
||||
"""
|
||||
yield self.db_pool.simple_delete_one(
|
||||
await self.db_pool.simple_delete_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
desc="delete_device",
|
||||
|
@ -945,17 +960,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
|
||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_devices(self, user_id, device_ids):
|
||||
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
||||
"""Deletes several devices.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the devices
|
||||
device_ids (list): The IDs of the devices to delete
|
||||
Returns:
|
||||
defer.Deferred
|
||||
user_id: The ID of the user which owns the devices
|
||||
device_ids: The IDs of the devices to delete
|
||||
"""
|
||||
yield self.db_pool.simple_delete_many(
|
||||
await self.db_pool.simple_delete_many(
|
||||
table="devices",
|
||||
column="device_id",
|
||||
iterable=device_ids,
|
||||
|
@ -965,26 +977,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
for device_id in device_ids:
|
||||
self.device_id_exists_cache.invalidate((user_id, device_id))
|
||||
|
||||
def update_device(self, user_id, device_id, new_display_name=None):
|
||||
async def update_device(
|
||||
self, user_id: str, device_id: str, new_display_name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Update a device. Only updates the device if it is not marked as
|
||||
hidden.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to update
|
||||
new_display_name (str|None): new displayname for device; None
|
||||
to leave unchanged
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to update
|
||||
new_display_name: new displayname for device; None to leave unchanged
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
Returns:
|
||||
defer.Deferred
|
||||
"""
|
||||
updates = {}
|
||||
if new_display_name is not None:
|
||||
updates["display_name"] = new_display_name
|
||||
if not updates:
|
||||
return defer.succeed(None)
|
||||
return self.db_pool.simple_update_one(
|
||||
return None
|
||||
await self.db_pool.simple_update_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
updatevalues=updates,
|
||||
|
@ -992,7 +1003,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
def update_remote_device_list_cache_entry(
|
||||
self, user_id, device_id, content, stream_id
|
||||
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
|
||||
):
|
||||
"""Updates a single device in the cache of a remote user's devicelist.
|
||||
|
||||
|
@ -1000,10 +1011,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
device list.
|
||||
|
||||
Args:
|
||||
user_id (str): User to update device list for
|
||||
device_id (str): ID of decivice being updated
|
||||
content (dict): new data on this device
|
||||
stream_id (int): the version of the device list
|
||||
user_id: User to update device list for
|
||||
device_id: ID of decivice being updated
|
||||
content: new data on this device
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
|
@ -1018,8 +1029,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
def _update_remote_device_list_cache_entry_txn(
|
||||
self, txn, user_id, device_id, content, stream_id
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
content: JsonDict,
|
||||
stream_id: int,
|
||||
) -> None:
|
||||
if content.get("deleted"):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
|
@ -1055,16 +1071,18 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
lock=False,
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache(self, user_id, devices, stream_id):
|
||||
def update_remote_device_list_cache(
|
||||
self, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
"""Replace the entire cache of the remote user's devices.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
device list.
|
||||
|
||||
Args:
|
||||
user_id (str): User to update device list for
|
||||
devices (list[dict]): list of device objects supplied over federation
|
||||
stream_id (int): the version of the device list
|
||||
user_id: User to update device list for
|
||||
devices: list of device objects supplied over federation
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
|
@ -1077,7 +1095,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
stream_id,
|
||||
)
|
||||
|
||||
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
|
||||
def _update_remote_device_list_cache_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
@ -1118,8 +1138,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
||||
async def add_device_change_to_streams(
|
||||
self, user_id: str, device_ids: Collection[str], hosts: List[str]
|
||||
):
|
||||
"""Persist that a user's devices have been updated, and which hosts
|
||||
(if any) should be poked.
|
||||
"""
|
||||
|
@ -1127,7 +1148,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
return
|
||||
|
||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
self._add_device_change_to_stream_txn,
|
||||
user_id,
|
||||
|
@ -1142,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
with self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_outbound_poke_to_stream",
|
||||
self._add_device_outbound_poke_to_stream_txn,
|
||||
user_id,
|
||||
|
@ -1187,7 +1208,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
)
|
||||
|
||||
def _add_device_outbound_poke_to_stream_txn(
|
||||
self, txn, user_id, device_ids, hosts, stream_ids, context,
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
user_id: str,
|
||||
device_ids: Collection[str],
|
||||
hosts: List[str],
|
||||
stream_ids: List[str],
|
||||
context: Dict[str, str],
|
||||
):
|
||||
for host in hosts:
|
||||
txn.call_after(
|
||||
|
@ -1219,7 +1246,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
|||
],
|
||||
)
|
||||
|
||||
def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
|
||||
def _prune_old_outbound_device_pokes(self, prune_age: int = 24 * 60 * 60 * 1000):
|
||||
"""Delete old entries out of the device_lists_outbound_pokes to ensure
|
||||
that we don't fill up due to dead servers.
|
||||
|
||||
|
|
|
@ -116,7 +116,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
|||
retry_timings_res
|
||||
)
|
||||
|
||||
self.datastore.get_device_updates_by_remote.return_value = defer.succeed(
|
||||
self.datastore.get_device_updates_by_remote.side_effect = lambda destination, from_stream_id, limit: make_awaitable(
|
||||
(0, [])
|
||||
)
|
||||
|
||||
|
|
|
@ -34,7 +34,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_store_new_device(self):
|
||||
yield self.store.store_device("user_id", "device_id", "display_name")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device_id", "display_name")
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertDictContainsSubset(
|
||||
|
@ -48,11 +50,17 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_devices_by_user(self):
|
||||
yield self.store.store_device("user_id", "device1", "display_name 1")
|
||||
yield self.store.store_device("user_id", "device2", "display_name 2")
|
||||
yield self.store.store_device("user_id2", "device3", "display_name 3")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device1", "display_name 1")
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device2", "display_name 2")
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id2", "device3", "display_name 3")
|
||||
)
|
||||
|
||||
res = yield self.store.get_devices_by_user("user_id")
|
||||
res = yield defer.ensureDeferred(self.store.get_devices_by_user("user_id"))
|
||||
self.assertEqual(2, len(res.keys()))
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
|
@ -76,13 +84,13 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
device_ids = ["device_id1", "device_id2"]
|
||||
|
||||
# Add two device updates with a single stream_id
|
||||
yield self.store.add_device_change_to_streams(
|
||||
"user_id", device_ids, ["somehost"]
|
||||
yield defer.ensureDeferred(
|
||||
self.store.add_device_change_to_streams("user_id", device_ids, ["somehost"])
|
||||
)
|
||||
|
||||
# Get all device updates ever meant for this remote
|
||||
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
|
||||
"somehost", -1, limit=100
|
||||
now_stream_id, device_updates = yield defer.ensureDeferred(
|
||||
self.store.get_device_updates_by_remote("somehost", -1, limit=100)
|
||||
)
|
||||
|
||||
# Check original device_ids are contained within these updates
|
||||
|
@ -99,20 +107,24 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def test_update_device(self):
|
||||
yield self.store.store_device("user_id", "device_id", "display_name 1")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user_id", "device_id", "display_name 1")
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do a no-op first
|
||||
yield self.store.update_device("user_id", "device_id")
|
||||
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do the update
|
||||
yield self.store.update_device(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.update_device(
|
||||
"user_id", "device_id", new_display_name="display_name 2"
|
||||
)
|
||||
)
|
||||
|
||||
# check it worked
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
|
@ -121,7 +133,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
|||
@defer.inlineCallbacks
|
||||
def test_update_unknown_device(self):
|
||||
with self.assertRaises(synapse.api.errors.StoreError) as cm:
|
||||
yield self.store.update_device(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.update_device(
|
||||
"user_id", "unknown_device_id", new_display_name="display_name 2"
|
||||
)
|
||||
)
|
||||
self.assertEqual(404, cm.exception.code)
|
||||
|
|
|
@ -30,7 +30,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||
now = 1470174257070
|
||||
json = {"key": "value"}
|
||||
|
||||
yield self.store.store_device("user", "device", None)
|
||||
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
|
||||
|
||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
|
||||
|
@ -47,7 +47,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||
now = 1470174257070
|
||||
json = {"key": "value"}
|
||||
|
||||
yield self.store.store_device("user", "device", None)
|
||||
yield defer.ensureDeferred(self.store.store_device("user", "device", None))
|
||||
|
||||
changed = yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
self.assertTrue(changed)
|
||||
|
@ -63,7 +63,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||
json = {"key": "value"}
|
||||
|
||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
yield self.store.store_device("user", "device", "display_name")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.store_device("user", "device", "display_name")
|
||||
)
|
||||
|
||||
res = yield defer.ensureDeferred(
|
||||
self.store.get_e2e_device_keys((("user", "device"),))
|
||||
|
@ -79,10 +81,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
|||
def test_multiple_devices(self):
|
||||
now = 1470174257070
|
||||
|
||||
yield self.store.store_device("user1", "device1", None)
|
||||
yield self.store.store_device("user1", "device2", None)
|
||||
yield self.store.store_device("user2", "device1", None)
|
||||
yield self.store.store_device("user2", "device2", None)
|
||||
yield defer.ensureDeferred(self.store.store_device("user1", "device1", None))
|
||||
yield defer.ensureDeferred(self.store.store_device("user1", "device2", None))
|
||||
yield defer.ensureDeferred(self.store.store_device("user2", "device1", None))
|
||||
yield defer.ensureDeferred(self.store.store_device("user2", "device2", None))
|
||||
|
||||
yield self.store.set_e2e_device_keys("user1", "device1", now, {"key": "json11"})
|
||||
yield self.store.set_e2e_device_keys("user1", "device2", now, {"key": "json12"})
|
||||
|
|
Loading…
Reference in a new issue