0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-09 22:28:55 +02:00

Convert account data, device inbox, and censor events databases to async/await (#8063)

This commit is contained in:
Patrick Cloke 2020-08-12 09:29:06 -04:00 committed by GitHub
parent a3a59bab7b
commit d68e10f308
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 99 additions and 87 deletions

1
changelog.d/8063.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -16,15 +16,16 @@
import abc
import logging
from typing import List, Tuple
from typing import List, Optional, Tuple
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
@ -97,13 +98,15 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_user", get_account_data_for_user_txn
)
@cachedInlineCallbacks(num_args=2, max_entries=5000)
def get_global_account_data_by_type_for_user(self, data_type, user_id):
@cached(num_args=2, max_entries=5000)
async def get_global_account_data_by_type_for_user(
self, data_type: str, user_id: str
) -> Optional[JsonDict]:
"""
Returns:
Deferred: A dict
The account data.
"""
result = yield self.db_pool.simple_select_one_onecol(
result = await self.db_pool.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@ -280,9 +283,11 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
@cached(num_args=2, cache_context=True, max_entries=5000)
async def is_ignored_by(
self, ignored_user_id: str, ignorer_user_id: str, cache_context: _CacheContext
) -> bool:
ignored_account_data = await self.get_global_account_data_by_type_for_user(
"m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate,
@ -307,24 +312,27 @@ class AccountDataStore(AccountDataWorkerStore):
super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
Returns:
A deferred int.
The maximum stream ID.
"""
return self._account_data_id_gen.get_current_token()
@defer.inlineCallbacks
def add_account_data_to_room(self, user_id, room_id, account_data_type, content):
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
room_id(str): The room to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
room_id: The room to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
The maximum stream ID.
"""
content_json = json_encoder.encode(content)
@ -332,7 +340,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@ -350,7 +358,7 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
yield self._update_max_stream_id(next_id)
await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@ -359,18 +367,20 @@ class AccountDataStore(AccountDataWorkerStore):
(user_id, room_id, account_data_type), content
)
result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
async def add_account_data_for_user(
self, user_id: str, account_data_type: str, content: JsonDict
) -> int:
"""Add some account_data to a room for a user.
Args:
user_id(str): The user to add a tag for.
account_data_type(str): The type of account_data to add.
content(dict): A json object to associate with the tag.
user_id: The user to add a tag for.
account_data_type: The type of account_data to add.
content: A json object to associate with the tag.
Returns:
A deferred that completes once the account_data has been added.
The maximum stream ID.
"""
content_json = json_encoder.encode(content)
@ -378,7 +388,7 @@ class AccountDataStore(AccountDataWorkerStore):
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
yield self.db_pool.simple_upsert(
await self.db_pool.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@ -396,7 +406,7 @@ class AccountDataStore(AccountDataWorkerStore):
# Note: This is only here for backwards compat to allow admins to
# roll back to a previous Synapse version. Next time we update the
# database version we can remove this table.
yield self._update_max_stream_id(next_id)
await self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
@ -404,14 +414,13 @@ class AccountDataStore(AccountDataWorkerStore):
(account_data_type, user_id)
)
result = self._account_data_id_gen.get_current_token()
return result
return self._account_data_id_gen.get_current_token()
def _update_max_stream_id(self, next_id):
def _update_max_stream_id(self, next_id: int):
"""Update the max stream_id
Args:
next_id(int): The the revision to advance to.
next_id: The the revision to advance to.
"""
# Note: This is only here for backwards compat to allow admins to

View file

@ -16,8 +16,6 @@
import logging
from typing import TYPE_CHECKING
from twisted.internet import defer
from synapse.events.utils import prune_event_dict
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
@ -148,17 +146,16 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
updatevalues={"json": pruned_json},
)
@defer.inlineCallbacks
def expire_event(self, event_id):
async def expire_event(self, event_id: str) -> None:
"""Retrieve and expire an event that has expired, and delete its associated
expiry timestamp. If the event can't be retrieved, delete its associated
timestamp so we don't try to expire it again in the future.
Args:
event_id (str): The ID of the event to delete.
event_id: The ID of the event to delete.
"""
# Try to retrieve the event's content from the database or the event cache.
event = yield self.get_event(event_id)
event = await self.get_event(event_id)
def delete_expired_event_txn(txn):
# Delete the expiry timestamp associated with this event from the database.
@ -193,7 +190,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
txn, "_get_event_cache", (event.event_id,)
)
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_expired_event", delete_expired_event_txn
)

View file

@ -16,8 +16,6 @@
import logging
from typing import List, Tuple
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
@ -31,24 +29,31 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
async def get_new_messages_for_device(
self,
user_id: str,
device_id: str,
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
) -> Tuple[List[dict], int]:
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
user_id: The recipient user_id.
device_id: The recipient device_id.
last_stream_id: The last stream ID checked.
current_stream_id: The current position of the to device
message stream.
limit: The maximum number of messages to retrieve.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
A list of messages for the device and where in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
return ([], current_stream_id)
def get_new_messages_for_device_txn(txn):
sql = (
@ -69,20 +74,22 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@trace
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
async def delete_messages_for_device(
self, user_id: str, device_id: str, up_to_stream_id: int
) -> int:
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
user_id: The recipient user_id.
device_id: The recipient device_id.
up_to_stream_id: Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
The number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
@ -109,7 +116,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.db_pool.runInteraction(
count = await self.db_pool.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
@ -128,9 +135,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return count
@trace
def get_new_device_msgs_for_remote(
async def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
):
) -> Tuple[List[dict], int]:
"""
Args:
destination(str): The name of the remote server.
@ -139,8 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
A list of messages for the device and where in the stream the messages got to.
"""
set_tag("destination", destination)
@ -153,11 +159,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
if not has_changed or last_stream_id == current_stream_id:
log_kv({"message": "No new messages in stream"})
return defer.succeed(([], current_stream_id))
return ([], current_stream_id)
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
return defer.succeed(([], last_stream_id))
return ([], last_stream_id)
@trace
def get_new_messages_for_remote_destination_txn(txn):
@ -178,7 +184,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = current_stream_id
return messages, stream_pos
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
@ -290,16 +296,15 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
async def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.db_pool.runWithConnection(reindex_txn)
await self.db_pool.runWithConnection(reindex_txn)
yield self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
@ -320,21 +325,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
)
@trace
@defer.inlineCallbacks
def add_messages_to_device_inbox(
self, local_messages_by_user_then_device, remote_messages_by_destination
):
async def add_messages_to_device_inbox(
self,
local_messages_by_user_then_device: dict,
remote_messages_by_destination: dict,
) -> int:
"""Used to send messages from this server.
Args:
sender_user_id(str): The ID of the user sending these messages.
local_messages_by_user_and_device(dict):
local_messages_by_user_and_device:
Dictionary of user_id to device_id to message.
remote_messages_by_destination(dict):
remote_messages_by_destination:
Dictionary of destination server_name to the EDU JSON to send.
Returns:
A deferred stream_id that resolves when the messages have been
inserted.
The new stream_id.
"""
def add_messages_txn(txn, now_ms, stream_id):
@ -359,7 +364,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@ -371,10 +376,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
self, origin, message_id, local_messages_by_user_then_device
):
async def add_messages_from_remote_to_device_inbox(
self, origin: str, message_id: str, local_messages_by_user_then_device: dict
) -> int:
def add_messages_txn(txn, now_ms, stream_id):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
@ -409,7 +413,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,

View file

@ -24,6 +24,7 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import register_federation_servlets
@ -151,7 +152,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_current_state_deltas.return_value = (0, None)
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: defer.succeed(
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
([], 0)
)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None