Convert the device message and pagination handlers to async/await. (#7678)

This commit is contained in:
Patrick Cloke 2020-06-16 08:06:17 -04:00 committed by GitHub
parent 03619324fc
commit 98c4e35e3c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 19 additions and 31 deletions

1
changelog.d/7678.misc Normal file
View file

@ -0,0 +1 @@
Convert the device message and pagination handlers to async/await.

View file

@ -18,8 +18,6 @@ from typing import Any, Dict
from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
self._device_list_updater = hs.get_device_handler().device_list_updater
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
async def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
}
local_messages[user_id] = messages_by_device
yield self._check_for_unknown_devices(
await self._check_for_unknown_devices(
message_type, sender_user_id, by_device
)
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
stream_id = await self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
)
@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
"to_device_key", stream_id, users=local_messages.keys()
)
@defer.inlineCallbacks
def _check_for_unknown_devices(
async def _check_for_unknown_devices(
self,
message_type: str,
sender_user_id: str,
by_device: Dict[str, Dict[str, Any]],
):
"""Checks inbound device messages for unkown remote devices, and if
"""Checks inbound device messages for unknown remote devices, and if
found marks the remote cache for the user as stale.
"""
@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
requesting_device_ids.add(device_id)
# Check if we are tracking the devices of the remote user.
room_ids = yield self.store.get_rooms_for_user(sender_user_id)
room_ids = await self.store.get_rooms_for_user(sender_user_id)
if not room_ids:
logger.info(
"Received device message from remote device we don't"
@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
# If we are tracking check that we know about the sending
# devices.
cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
unknown_devices = requesting_device_ids - set(cached_devices)
if unknown_devices:
@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
sender_user_id,
unknown_devices,
)
yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
run_in_background(
self._device_list_updater.user_device_resync, sender_user_id
)
@defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
async def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
}
log_kv({"local_messages": local_messages})
stream_id = yield self.store.add_messages_to_device_inbox(
stream_id = await self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)

View file

@ -15,7 +15,6 @@
# limitations under the License.
import logging
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
@ -97,8 +96,7 @@ class PaginationHandler(object):
job["longest_max_lifetime"],
)
@defer.inlineCallbacks
def purge_history_for_rooms_in_range(self, min_ms, max_ms):
async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
"""Purge outdated events from rooms within the given retention range.
If a default retention policy is defined in the server's configuration and its
@ -137,7 +135,7 @@ class PaginationHandler(object):
include_null,
)
rooms = yield self.store.get_rooms_for_retention_period_in_range(
rooms = await self.store.get_rooms_for_retention_period_in_range(
min_ms, max_ms, include_null
)
@ -165,9 +163,9 @@ class PaginationHandler(object):
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = yield self.store.get_room_event_before_stream_ordering(
r = await self.store.get_room_event_before_stream_ordering(
room_id, stream_ordering,
)
if not r:
@ -227,8 +225,7 @@ class PaginationHandler(object):
)
return purge_id
@defer.inlineCallbacks
def _purge_history(self, purge_id, room_id, token, delete_local_events):
async def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@ -237,14 +234,11 @@ class PaginationHandler(object):
token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
Returns:
Deferred
"""
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
yield self.storage.purge_events.purge_history(
with await self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] complete")
@ -282,9 +276,7 @@ class PaginationHandler(object):
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
joined = await defer.maybeDeferred(
self.store.is_host_joined, room_id, self._server_name
)
joined = await self.store.is_host_joined(room_id, self._server_name)
if joined:
raise SynapseError(400, "Users are still joined to this room")