Convert the device message and pagination handlers to async/await. (#7678)

Patrick Cloke 2020-06-16 08:06:17 -04:00 committed by GitHub
parent 03619324fc
commit 98c4e35e3c
3 changed files with 19 additions and 31 deletions

changelog.d/7678.misc Normal file

@@ -0,0 +1 @@
+Convert the device message and pagination handlers to async/await.
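
The pattern applied throughout this change is the usual Twisted-to-native-coroutine conversion: drop @defer.inlineCallbacks, declare the method with async def, and replace each yield of a Deferred with await. Below is a minimal before/after sketch of that pattern; the ExampleHandler* and FakeStore names are invented for illustration and do not appear in this commit.

from twisted.internet import defer


class FakeStore:
    """Hypothetical stand-in for the real data store, used only in this sketch."""

    def get_rooms_for_user(self, user_id):
        # Returns an already-fired Deferred, which both styles below can consume.
        return defer.succeed(["!a:example.com", "!b:example.com"])


class ExampleHandlerBefore:
    """Old style: a generator decorated with @defer.inlineCallbacks."""

    def __init__(self, store):
        self.store = store

    @defer.inlineCallbacks
    def count_rooms(self, user_id):
        rooms = yield self.store.get_rooms_for_user(user_id)
        return len(rooms)


class ExampleHandlerAfter:
    """New style: a native coroutine that awaits the Deferred directly."""

    def __init__(self, store):
        self.store = store

    async def count_rooms(self, user_id):
        rooms = await self.store.get_rooms_for_user(user_id)
        return len(rooms)


# Deferred-based callers can still drive the coroutine by wrapping it, e.g.:
#     d = defer.ensureDeferred(ExampleHandlerAfter(FakeStore()).count_rooms("@u:example.com"))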


@@ -18,8 +18,6 @@ from typing import Any, Dict
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
         self._device_list_updater = hs.get_device_handler().device_list_updater
 
-    @defer.inlineCallbacks
-    def on_direct_to_device_edu(self, origin, content):
+    async def on_direct_to_device_edu(self, origin, content):
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
             }
 
             local_messages[user_id] = messages_by_device
 
-            yield self._check_for_unknown_devices(
+            await self._check_for_unknown_devices(
                 message_type, sender_user_id, by_device
             )
 
-        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
         )
@@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
             "to_device_key", stream_id, users=local_messages.keys()
         )
 
-    @defer.inlineCallbacks
-    def _check_for_unknown_devices(
+    async def _check_for_unknown_devices(
         self,
         message_type: str,
         sender_user_id: str,
         by_device: Dict[str, Dict[str, Any]],
     ):
-        """Checks inbound device messages for unkown remote devices, and if
+        """Checks inbound device messages for unknown remote devices, and if
         found marks the remote cache for the user as stale.
         """
@@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
             requesting_device_ids.add(device_id)
 
         # Check if we are tracking the devices of the remote user.
-        room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+        room_ids = await self.store.get_rooms_for_user(sender_user_id)
         if not room_ids:
             logger.info(
                 "Received device message from remote device we don't"
@@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
         # If we are tracking check that we know about the sending
         # devices.
-        cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+        cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
         unknown_devices = requesting_device_ids - set(cached_devices)
         if unknown_devices:
@@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
                 sender_user_id,
                 unknown_devices,
             )
-            yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+            await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
             run_in_background(
                 self._device_list_updater.user_device_resync, sender_user_id
             )
 
-    @defer.inlineCallbacks
-    def send_device_message(self, sender_user_id, message_type, messages):
+    async def send_device_message(self, sender_user_id, message_type, messages):
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
@@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
             }
 
         log_kv({"local_messages": local_messages})
-        stream_id = yield self.store.add_messages_to_device_inbox(
+        stream_id = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
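
In general, code that still works with Deferreds can drive one of these converted coroutine methods by wrapping it with defer.ensureDeferred, while async code simply awaits it. A rough sketch of that bridge follows; the handler body and the caller are placeholders written for this note, not code from this commit.

from twisted.internet import defer


async def on_direct_to_device_edu(origin, content):
    # Placeholder body standing in for the converted handler method.
    return "%s/%s" % (origin, content["message_id"])


def legacy_deferred_caller(origin, content):
    # Deferred-style code wraps the coroutine and chains callbacks as before.
    d = defer.ensureDeferred(on_direct_to_device_edu(origin, content))
    d.addCallback(lambda result: print("handled", result))
    return d


# Async callers just await it:
#     result = await on_direct_to_device_edu("example.com", {"message_id": "abc"})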


@@ -15,7 +15,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
@@ -97,8 +96,7 @@ class PaginationHandler(object):
                 job["longest_max_lifetime"],
             )
 
-    @defer.inlineCallbacks
-    def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+    async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
         """Purge outdated events from rooms within the given retention range.
 
         If a default retention policy is defined in the server's configuration and its
@@ -137,7 +135,7 @@ class PaginationHandler(object):
             include_null,
         )
 
-        rooms = yield self.store.get_rooms_for_retention_period_in_range(
+        rooms = await self.store.get_rooms_for_retention_period_in_range(
             min_ms, max_ms, include_null
         )
@@ -165,9 +163,9 @@ class PaginationHandler(object):
             # Figure out what token we should start purging at.
             ts = self.clock.time_msec() - max_lifetime
 
-            stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+            stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
 
-            r = yield self.store.get_room_event_before_stream_ordering(
+            r = await self.store.get_room_event_before_stream_ordering(
                 room_id, stream_ordering,
             )
             if not r:
@@ -227,8 +225,7 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    @defer.inlineCallbacks
-    def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
         """Carry out a history purge on a room.
 
         Args:
@@ -237,14 +234,11 @@ class PaginationHandler(object):
             token (str): topological token to delete events before
             delete_local_events (bool): True to delete local events as well as
                 remote ones
-
-        Returns:
-            Deferred
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            with (yield self.pagination_lock.write(room_id)):
-                yield self.storage.purge_events.purge_history(
+            with await self.pagination_lock.write(room_id):
+                await self.storage.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
             logger.info("[purge] complete")
@@ -282,9 +276,7 @@ class PaginationHandler(object):
         await self.store.get_room_version_id(room_id)
 
         # first check that we have no users in this room
-        joined = await defer.maybeDeferred(
-            self.store.is_host_joined, room_id, self._server_name
-        )
+        joined = await self.store.is_host_joined(room_id, self._server_name)
 
         if joined:
             raise SynapseError(400, "Users are still joined to this room")
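
The last hunk also removes a defer.maybeDeferred wrapper: once the storage method itself returns an awaitable, it can be awaited directly. A small sketch of the two shapes, using invented stand-in functions rather than the real store methods:

from twisted.internet import defer


def is_host_joined_returning_deferred(room_id, host):
    # Stand-in for a storage method that returns a Deferred.
    return defer.succeed(False)


async def is_host_joined_async(room_id, host):
    # Stand-in for a storage method that is itself a coroutine.
    return False


async def check_room_is_empty(room_id, server_name):
    # Old shape: the callee might return a plain value or a Deferred, so it was
    # called through defer.maybeDeferred and the resulting Deferred was awaited.
    joined = await defer.maybeDeferred(
        is_host_joined_returning_deferred, room_id, server_name
    )

    # New shape: the callee is awaitable, so the wrapper is no longer needed.
    joined = await is_host_joined_async(room_id, server_name)
    return joined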