Convert some of the data store to async. (#7976)

This commit is contained in:
Patrick Cloke 2020-07-30 07:20:41 -04:00 committed by GitHub
parent 3950ae51ef
commit b3a97d6dac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 190 additions and 207 deletions

1
changelog.d/7976.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -15,11 +15,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import List
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database from synapse.storage.database import Database
@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": notify_count, "highlight_count": highlight_count} return {"notify_count": notify_count, "highlight_count": highlight_count}
@defer.inlineCallbacks async def get_push_action_users_in_range(
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): self, min_stream_ordering, max_stream_ordering
):
def f(txn): def f(txn):
sql = ( sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE" "SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering)) txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn] return [r[0] for r in txn]
ret = yield self.db.runInteraction("get_push_action_users_in_range", f) ret = await self.db.runInteraction("get_push_action_users_in_range", f)
return ret return ret
@defer.inlineCallbacks async def get_unread_push_actions_for_user_in_range_for_http(
def get_unread_push_actions_for_user_in_range_for_http( self,
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 user_id: str,
): min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user, """Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher. within the given stream ordering range. Called by the httppusher.
Args: Args:
user_id (str): The user to fetch push actions for. user_id: The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return. limit: The maximum number of rows to return.
Returns: Returns:
A promise which resolves to a list of dicts with the keys "event_id", A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
"room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering. The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.db.runInteraction( after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
) )
@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.db.runInteraction( no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
) )
@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# one of the subqueries may have hit the limit. # one of the subqueries may have hit the limit.
return notifs[:limit] return notifs[:limit]
@defer.inlineCallbacks async def get_unread_push_actions_for_user_in_range_for_email(
def get_unread_push_actions_for_user_in_range_for_email( self,
self, user_id, min_stream_ordering, max_stream_ordering, limit=20 user_id: str,
): min_stream_ordering: int,
max_stream_ordering: int,
limit: int = 20,
) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user, """Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher within the given stream ordering range. Called by the emailpusher
Args: Args:
user_id (str): The user to fetch push actions for. user_id: The user to fetch push actions for.
min_stream_ordering(int): The exclusive lower bound on the min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
max_stream_ordering(int): The inclusive upper bound on the max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch. stream ordering of event push actions to fetch.
limit (int): The maximum number of rows to return. limit: The maximum number of rows to return.
Returns: Returns:
A promise which resolves to a list of dicts with the keys "event_id", A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
"room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts. The list will be ordered by descending received_ts.
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
after_read_receipt = yield self.db.runInteraction( after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
) )
@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args) txn.execute(sql, args)
return txn.fetchall() return txn.fetchall()
no_read_receipt = yield self.db.runInteraction( no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
) )
@ -461,17 +465,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"add_push_actions_to_staging", _add_push_actions_to_staging_txn "add_push_actions_to_staging", _add_push_actions_to_staging_txn
) )
@defer.inlineCallbacks async def remove_push_actions_from_staging(self, event_id: str) -> None:
def remove_push_actions_from_staging(self, event_id):
"""Called if we failed to persist the event to ensure that stale push """Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB actions don't build up in the DB
Args:
event_id (str)
""" """
try: try:
res = yield self.db.simple_delete( res = await self.db.simple_delete(
table="event_push_actions_staging", table="event_push_actions_staging",
keyvalues={"event_id": event_id}, keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging", desc="remove_push_actions_from_staging",
@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end return range_end
@defer.inlineCallbacks async def get_time_of_last_push_action_before(self, stream_ordering):
def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn): def f(txn):
sql = ( sql = (
"SELECT e.received_ts" "SELECT e.received_ts"
@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,)) txn.execute(sql, (stream_ordering,))
return txn.fetchone() return txn.fetchone()
result = yield self.db.runInteraction("get_time_of_last_push_action_before", f) result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None return result[0] if result else None
@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000 self._start_rotate_notifs, 30 * 60 * 1000
) )
@defer.inlineCallbacks async def get_push_actions_for_user(
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False self, user_id, before=None, limit=50, only_highlight=False
): ):
def f(txn): def f(txn):
@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, args) txn.execute(sql, args)
return self.db.cursor_to_dict(txn) return self.db.cursor_to_dict(txn)
push_actions = yield self.db.runInteraction("get_push_actions_for_user", f) push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions: for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions return push_actions
@defer.inlineCallbacks async def get_latest_push_action_stream_ordering(self):
def get_latest_push_action_stream_ordering(self):
def f(txn): def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone() return txn.fetchone()
result = yield self.db.runInteraction( result = await self.db.runInteraction(
"get_latest_push_action_stream_ordering", f "get_latest_push_action_stream_ordering", f
) )
return result[0] or 0 return result[0] or 0
@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def _start_rotate_notifs(self): def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs) return run_as_background_process("rotate_notifs", self._rotate_notifs)
@defer.inlineCallbacks async def _rotate_notifs(self):
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None: if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return return
self._doing_notif_rotation = True self._doing_notif_rotation = True
@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True: while True:
logger.info("Rotating notifications") logger.info("Rotating notifications")
caught_up = yield self.db.runInteraction( caught_up = await self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn "_rotate_notifs", self._rotate_notifs_txn
) )
if caught_up: if caught_up:
break break
yield self.hs.get_clock().sleep(self._rotate_delay) await self.hs.get_clock().sleep(self._rotate_delay)
finally: finally:
self._doing_notif_rotation = False self._doing_notif_rotation = False

View file

@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn) return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
@defer.inlineCallbacks async def get_largest_public_rooms(
def get_largest_public_rooms(
self, self,
network_tuple: Optional[ThirdPartyInstanceID], network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict], search_filter: Optional[dict],
@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
return results return results
ret_val = yield self.db.runInteraction( ret_val = await self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn "get_largest_public_rooms", _get_largest_public_rooms_txn
) )
defer.returnValue(ret_val) return ret_val
@cached(max_entries=10000) @cached(max_entries=10000)
def is_room_blocked(self, room_id): def is_room_blocked(self, room_id):
@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_rooms_paginate", _get_rooms_paginate_txn, "get_rooms_paginate", _get_rooms_paginate_txn,
) )
@cachedInlineCallbacks(max_entries=10000) @cached(max_entries=10000)
def get_ratelimit_for_user(self, user_id): async def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given """Check if there are any overrides for ratelimiting for the given
user user
@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely. disabled for that user entirely.
""" """
row = yield self.db.simple_select_one( row = await self.db.simple_select_one(
table="ratelimit_override", table="ratelimit_override",
keyvalues={"user_id": user_id}, keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"), retcols=("messages_per_second", "burst_count"),
@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
else: else:
return None return None
@cachedInlineCallbacks() @cached()
def get_retention_policy_for_room(self, room_id): async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room. """Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined If no retention policy has been found for this room, returns a policy defined
@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.cursor_to_dict(txn) return self.db.cursor_to_dict(txn)
ret = yield self.db.runInteraction( ret = await self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn, "get_retention_policy_for_room", get_retention_policy_for_room_txn,
) )
# If we don't know this room ID, ret will be None, in this case return the default # If we don't know this room ID, ret will be None, in this case return the default
# policy. # policy.
if not ret: if not ret:
defer.returnValue( return {
{
"min_lifetime": self.config.retention_default_min_lifetime, "min_lifetime": self.config.retention_default_min_lifetime,
"max_lifetime": self.config.retention_default_max_lifetime, "max_lifetime": self.config.retention_default_max_lifetime,
} }
)
row = ret[0] row = ret[0]
@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
if row["max_lifetime"] is None: if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention_default_max_lifetime row["max_lifetime"] = self.config.retention_default_max_lifetime
defer.returnValue(row) return row
def get_media_mxcs_in_room(self, room_id): def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room """Retrieves all the local and remote media MXC URIs in a given room
@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column, self._background_add_rooms_room_version_column,
) )
@defer.inlineCallbacks async def _background_insert_retention(self, progress, batch_size):
def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of """Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table. them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's NULLs the property's columns if missing from the retention event in the room's
@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else: else:
return False return False
end = yield self.db.runInteraction( end = await self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn, "insert_room_retention", _background_insert_retention_txn,
) )
if end: if end:
yield self.db.updates._end_background_update("insert_room_retention") await self.db.updates._end_background_update("insert_room_retention")
defer.returnValue(batch_size) return batch_size
async def _background_add_rooms_room_version_column( async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int self, progress: dict, batch_size: int
@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False, lock=False,
) )
@defer.inlineCallbacks async def store_room(
def store_room(
self, self,
room_id: str, room_id: str,
room_creator_user_id: str, room_creator_user_id: str,
@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction("store_room_txn", store_room_txn, next_id) await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e: except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e) logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.") raise StoreError(500, "Problem creating room.")
@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False, lock=False,
) )
@defer.inlineCallbacks async def set_room_is_public(self, room_id, is_public):
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id): def set_room_is_public_txn(txn, next_id):
self.db.simple_update_one_txn( self.db.simple_update_one_txn(
txn, txn,
@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction( await self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id "set_room_is_public", set_room_is_public_txn, next_id
) )
self.hs.get_notifier().on_new_replication_data() self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks async def set_room_is_public_appservice(
def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public self, room_id, appservice_id, network_id, is_public
): ):
"""Edit the appservice/network specific public room list. """Edit the appservice/network specific public room list.
@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
) )
with self._public_room_id_gen.get_next() as next_id: with self._public_room_id_gen.get_next() as next_id:
yield self.db.runInteraction( await self.db.runInteraction(
"set_room_is_public_appservice", "set_room_is_public_appservice",
set_room_is_public_appservice_txn, set_room_is_public_appservice_txn,
next_id, next_id,
@ -1327,49 +1318,44 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self): def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token() return self._public_room_id_gen.get_current_token()
@defer.inlineCallbacks async def block_room(self, room_id: str, user_id: str) -> None:
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times. """Marks the room as blocked. Can be called multiple times.
Args: Args:
room_id (str): Room to block room_id: Room to block
user_id (str): Who blocked it user_id: Who blocked it
Returns:
Deferred
""" """
yield self.db.simple_upsert( await self.db.simple_upsert(
table="blocked_rooms", table="blocked_rooms",
keyvalues={"room_id": room_id}, keyvalues={"room_id": room_id},
values={}, values={},
insertion_values={"user_id": user_id}, insertion_values={"user_id": user_id},
desc="block_room", desc="block_room",
) )
yield self.db.runInteraction( await self.db.runInteraction(
"block_room_invalidation", "block_room_invalidation",
self._invalidate_cache_and_stream, self._invalidate_cache_and_stream,
self.is_room_blocked, self.is_room_blocked,
(room_id,), (room_id,),
) )
@defer.inlineCallbacks async def get_rooms_for_retention_period_in_range(
def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
self, min_ms, max_ms, include_null=False ) -> Dict[str, dict]:
):
"""Retrieves all of the rooms within the given retention range. """Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy. Optionally includes the rooms which don't have a retention policy.
Args: Args:
min_ms (int|None): Duration in milliseconds that define the lower limit of min_ms: Duration in milliseconds that define the lower limit of
the range to handle (exclusive). If None, doesn't set a lower limit. the range to handle (exclusive). If None, doesn't set a lower limit.
max_ms (int|None): Duration in milliseconds that define the upper limit of max_ms: Duration in milliseconds that define the upper limit of
the range to handle (inclusive). If None, doesn't set an upper limit. the range to handle (inclusive). If None, doesn't set an upper limit.
include_null (bool): Whether to include rooms which retention policy is NULL include_null: Whether to include rooms which retention policy is NULL
in the returned set. in the returned set.
Returns: Returns:
dict[str, dict]: The rooms within this range, along with their retention The rooms within this range, along with their retention
policy. The key is "room_id", and maps to a dict describing the retention policy. The key is "room_id", and maps to a dict describing the retention
policy associated with this room ID. The keys for this nested dict are policy associated with this room ID. The keys for this nested dict are
"min_lifetime" (int|None), and "max_lifetime" (int|None). "min_lifetime" (int|None), and "max_lifetime" (int|None).
@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict return rooms_dict
rooms = yield self.db.runInteraction( rooms = await self.db.runInteraction(
"get_rooms_for_retention_period_in_range", "get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn, get_rooms_for_retention_period_in_range_txn,
) )
defer.returnValue(rooms) return rooms

View file

@ -16,12 +16,12 @@
import collections.abc import collections.abc
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Iterable, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@ -108,16 +108,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id) create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1") return create_event.content.get("room_version", "1")
@defer.inlineCallbacks async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
def get_room_predecessor(self, room_id):
"""Get the predecessor of an upgraded room if it exists. """Get the predecessor of an upgraded room if it exists.
Otherwise return None. Otherwise return None.
Args: Args:
room_id (str) room_id: The room ID.
Returns: Returns:
Deferred[dict|None]: A dictionary containing the structure of the predecessor A dictionary containing the structure of the predecessor
field from the room's create event. The structure is subject to other servers, field from the room's create event. The structure is subject to other servers,
but it is expected to be: but it is expected to be:
* room_id (str): The room ID of the predecessor room * room_id (str): The room ID of the predecessor room
@ -129,7 +128,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
NotFoundError if the given room is unknown NotFoundError if the given room is unknown
""" """
# Retrieve the room's create event # Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id) create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event # Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None) predecessor = create_event.content.get("predecessor", None)
@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor return predecessor
@defer.inlineCallbacks async def get_create_event_for_room(self, room_id: str) -> EventBase:
def get_create_event_for_room(self, room_id):
"""Get the create state event for a room. """Get the create state event for a room.
Args: Args:
room_id (str) room_id: The room ID.
Returns: Returns:
Deferred[EventBase]: The room creation event. The room creation event.
Raises: Raises:
NotFoundError if the room is unknown NotFoundError if the room is unknown
""" """
state_ids = yield self.get_current_state_ids(room_id) state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, "")) create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end # If we can't find the create event, assume we've hit a dead end
@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,)) raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return # Retrieve the room's create event and return
create_event = yield self.get_event(create_id) create_event = await self.get_event(create_id)
return create_event return create_event
@cached(max_entries=100000, iterable=True) @cached(max_entries=100000, iterable=True)
@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
) )
@defer.inlineCallbacks async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
def get_canonical_alias_for_room(self, room_id):
"""Get canonical alias for room, if any """Get canonical alias for room, if any
Args: Args:
room_id (str) room_id: The room ID
Returns: Returns:
Deferred[str|None]: The canonical alias, if any The canonical alias, if any
""" """
state = yield self.get_filtered_current_state_ids( state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
) )
@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id: if not event_id:
return return
event = yield self.get_event(event_id, allow_none=True) event = await self.get_event(event_id, allow_none=True)
if not event: if not event:
return return
@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows} return {row["event_id"]: row["state_group"] for row in rows}
@defer.inlineCallbacks async def get_referenced_state_groups(
def get_referenced_state_groups(self, state_groups): self, state_groups: Iterable[int]
) -> Set[int]:
"""Check if the state groups are referenced by events. """Check if the state groups are referenced by events.
Args: Args:
state_groups (Iterable[int]) state_groups
Returns: Returns:
Deferred[set[int]]: The subset of state groups that are The subset of state groups that are referenced.
referenced.
""" """
rows = yield self.db.simple_select_many_batch( rows = await self.db.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="state_group", column="state_group",
iterable=state_groups, iterable=state_groups,

View file

@ -16,8 +16,8 @@
import logging import logging
from itertools import chain from itertools import chain
from typing import Tuple
from twisted.internet import defer
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
""" """
return (ts // self.stats_bucket_size) * self.stats_bucket_size return (ts // self.stats_bucket_size) * self.stats_bucket_size
@defer.inlineCallbacks async def _populate_stats_process_users(self, progress, batch_size):
def _populate_stats_process_users(self, progress, batch_size):
""" """
This is a background update which regenerates statistics for users. This is a background update which regenerates statistics for users.
""" """
if not self.stats_enabled: if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_users") await self.db.updates._end_background_update("populate_stats_process_users")
return 1 return 1
last_user_id = progress.get("last_user_id", "") last_user_id = progress.get("last_user_id", "")
@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size)) txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn] return [r for r, in txn]
users_to_work_on = yield self.db.runInteraction( users_to_work_on = await self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch "_populate_stats_process_users", _get_next_batch
) )
# No more rooms -- complete the transaction. # No more rooms -- complete the transaction.
if not users_to_work_on: if not users_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_users") await self.db.updates._end_background_update("populate_stats_process_users")
return 1 return 1
for user_id in users_to_work_on: for user_id in users_to_work_on:
yield self._calculate_and_set_initial_state_for_user(user_id) await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id progress["last_user_id"] = user_id
yield self.db.runInteraction( await self.db.runInteraction(
"populate_stats_process_users", "populate_stats_process_users",
self.db.updates._background_update_progress_txn, self.db.updates._background_update_progress_txn,
"populate_stats_process_users", "populate_stats_process_users",
@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on) return len(users_to_work_on)
@defer.inlineCallbacks async def _populate_stats_process_rooms(self, progress, batch_size):
def _populate_stats_process_rooms(self, progress, batch_size):
""" """
This is a background update which regenerates statistics for rooms. This is a background update which regenerates statistics for rooms.
""" """
if not self.stats_enabled: if not self.stats_enabled:
yield self.db.updates._end_background_update("populate_stats_process_rooms") await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1 return 1
last_room_id = progress.get("last_room_id", "") last_room_id = progress.get("last_room_id", "")
@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size)) txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn] return [r for r, in txn]
rooms_to_work_on = yield self.db.runInteraction( rooms_to_work_on = await self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch "populate_stats_rooms_get_batch", _get_next_batch
) )
# No more rooms -- complete the transaction. # No more rooms -- complete the transaction.
if not rooms_to_work_on: if not rooms_to_work_on:
yield self.db.updates._end_background_update("populate_stats_process_rooms") await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1 return 1
for room_id in rooms_to_work_on: for room_id in rooms_to_work_on:
yield self._calculate_and_set_initial_state_for_room(room_id) await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id progress["last_room_id"] = room_id
yield self.db.runInteraction( await self.db.runInteraction(
"_populate_stats_process_rooms", "_populate_stats_process_rooms",
self.db.updates._background_update_progress_txn, self.db.updates._background_update_progress_txn,
"populate_stats_process_rooms", "populate_stats_process_rooms",
@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
return room_deltas, user_deltas return room_deltas, user_deltas
@defer.inlineCallbacks async def _calculate_and_set_initial_state_for_room(
def _calculate_and_set_initial_state_for_room(self, room_id): self, room_id: str
) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current. """Calculate and insert an entry into room_stats_current.
Args: Args:
room_id (str) room_id: The room ID under calculation.
Returns: Returns:
Deferred[tuple[dict, dict, int]]: A tuple of room state, membership A tuple of room state, membership counts and stream position.
counts and stream position.
""" """
def _fetch_current_state_stats(txn): def _fetch_current_state_stats(txn):
@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
current_state_events_count, current_state_events_count,
users_in_room, users_in_room,
pos, pos,
) = yield self.db.runInteraction( ) = await self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats "get_initial_state_for_room", _fetch_current_state_stats
) )
state_event_map = yield self.get_events(event_ids, get_prev_content=False) state_event_map = await self.get_events(event_ids, get_prev_content=False)
room_state = { room_state = {
"join_rules": None, "join_rules": None,
@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
event.content.get("m.federate", True) is True event.content.get("m.federate", True) is True
) )
yield self.update_room_state(room_id, room_state) await self.update_room_state(room_id, room_state)
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)] local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
yield self.update_stats_delta( await self.update_stats_delta(
ts=self.clock.time_msec(), ts=self.clock.time_msec(),
stats_type="room", stats_type="room",
stats_id=room_id, stats_id=room_id,
@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
}, },
) )
@defer.inlineCallbacks async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn): def _calculate_and_set_initial_state_for_user_txn(txn):
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn) pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone() (count,) = txn.fetchone()
return count, pos return count, pos
joined_rooms, pos = yield self.db.runInteraction( joined_rooms, pos = await self.db.runInteraction(
"calculate_and_set_initial_state_for_user", "calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn, _calculate_and_set_initial_state_for_user_txn,
) )
yield self.update_stats_delta( await self.update_stats_delta(
ts=self.clock.time_msec(), ts=self.clock.time_msec(),
stats_type="user", stats_type="user",
stats_id=user_id, stats_id=user_id,

View file

@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"get_state_group_delta", _get_state_group_delta_txn "get_state_group_delta", _get_state_group_delta_txn
) )
@defer.inlineCallbacks async def _get_state_groups_from_groups(
def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: List[int], state_filter: StateFilter
): ) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the """Returns the state groups for a given set of groups from the
database, filtering on types of state events. database, filtering on types of state events.
@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
results = {} results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)] chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks: for chunk in chunks:
res = yield self.db.runInteraction( res = await self.db.runInteraction(
"_get_state_groups_from_groups", "_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, self._get_state_groups_from_groups_txn,
chunk, chunk,
@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks async def _get_state_for_groups(
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
): ) -> Dict[int, StateMap[str]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
member_filter, non_member_filter = state_filter.get_member_split() member_filter, non_member_filter = state_filter.get_member_split()
@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
( (
non_member_state, non_member_state,
incomplete_groups_nm, incomplete_groups_nm,
) = yield self._get_state_for_groups_using_cache( ) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter groups, self._state_group_cache, state_filter=non_member_filter
) )
( (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
member_state,
incomplete_groups_m,
) = yield self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter groups, self._state_group_members_cache, state_filter=member_filter
) )
@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Help the cache hit ratio by expanding the filter a bit # Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded() db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups( group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter list(incomplete_groups), state_filter=db_state_filter
) )
@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
((sg,) for sg in state_groups_to_delete), ((sg,) for sg in state_groups_to_delete),
) )
@defer.inlineCallbacks async def get_previous_state_groups(
def get_previous_state_groups(self, state_groups): self, state_groups: Iterable[int]
) -> Dict[int, int]:
"""Fetch the previous groups of the given state groups. """Fetch the previous groups of the given state groups.
Args: Args:
state_groups (Iterable[int]) state_groups
Returns: Returns:
Deferred[dict[int, int]]: mapping from state group to previous A mapping from state group to previous state group.
state group.
""" """
rows = yield self.db.simple_select_many_batch( rows = await self.db.simple_select_many_batch(
table="state_group_edges", table="state_group_edges",
column="prev_state_group", column="prev_state_group",
iterable=state_groups, iterable=state_groups,

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar from typing import Awaitable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr import attr
@ -419,7 +419,7 @@ class StateGroupStorage(object):
def _get_state_groups_from_groups( def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter self, groups: List[int], state_filter: StateFilter
): ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Returns the state groups for a given set of groups, filtering on """Returns the state groups for a given set of groups, filtering on
types of state events. types of state events.
@ -429,7 +429,7 @@ class StateGroupStorage(object):
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@ -532,7 +532,7 @@ class StateGroupStorage(object):
def _get_state_for_groups( def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all() self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
): ) -> Awaitable[Dict[int, StateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally """Gets the state at each of a list of state groups, optionally
filtering by type/state_key filtering by type/state_key
@ -540,8 +540,9 @@ class StateGroupStorage(object):
groups: list of state groups for which we want to get the state. groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state. state_filter: The state filter used to fetch state.
from the database. from the database.
Returns: Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. Dict of state group to state map.
""" """
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(groups, state_filter)

View file

@ -39,15 +39,19 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_http(self): def test_get_unread_push_actions_for_user_in_range_for_http(self):
yield self.store.get_unread_push_actions_for_user_in_range_for_http( yield defer.ensureDeferred(
self.store.get_unread_push_actions_for_user_in_range_for_http(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_unread_push_actions_for_user_in_range_for_email(self): def test_get_unread_push_actions_for_user_in_range_for_email(self):
yield self.store.get_unread_push_actions_for_user_in_range_for_email( yield defer.ensureDeferred(
self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_count_aggregation(self): def test_count_aggregation(self):

View file

@ -37,12 +37,14 @@ class RoomStoreTestCase(unittest.TestCase):
self.alias = RoomAlias.from_string("#a-room-name:test") self.alias = RoomAlias.from_string("#a-room-name:test")
self.u_creator = UserID.from_string("@creator:test") self.u_creator = UserID.from_string("@creator:test")
yield self.store.store_room( yield defer.ensureDeferred(
self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id=self.u_creator.to_string(), room_creator_user_id=self.u_creator.to_string(),
is_public=True, is_public=True,
room_version=RoomVersions.V1, room_version=RoomVersions.V1,
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_room(self): def test_get_room(self):
@ -88,12 +90,14 @@ class RoomEventsStoreTestCase(unittest.TestCase):
self.room = RoomID.from_string("!abcde:test") self.room = RoomID.from_string("!abcde:test")
yield self.store.store_room( yield defer.ensureDeferred(
self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
is_public=True, is_public=True,
room_version=RoomVersions.V1, room_version=RoomVersions.V1,
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):

View file

@ -44,12 +44,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room = RoomID.from_string("!abc123:test") self.room = RoomID.from_string("!abc123:test")
yield self.store.store_room( yield defer.ensureDeferred(
self.store.store_room(
self.room.to_string(), self.room.to_string(),
room_creator_user_id="@creator:text", room_creator_user_id="@creator:text",
is_public=True, is_public=True,
room_version=RoomVersions.V1, room_version=RoomVersions.V1,
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_state_event(self, room, sender, typ, state_key, content): def inject_state_event(self, room, sender, typ, state_key, content):