0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-10 22:58:53 +02:00

Convert state and stream stores and related code to async (#8194)

This commit is contained in:
Patrick Cloke 2020-08-28 09:37:55 -04:00 committed by GitHub
parent b055dc9322
commit aec7085179
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 51 additions and 45 deletions

1
changelog.d/8194.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -451,7 +451,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in old_room_member_state_events.items():
for old_event in old_room_member_state_events.values():
# Only transfer ban events
if (
"membership" in old_event.content

View file

@ -27,6 +27,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@ -163,15 +164,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
Args:
room_id (str)
room_id: The room to get the state IDs of.
Returns:
deferred: dict of (type, state_key) -> event_id
The current state of the room.
"""
def _get_current_state_ids_txn(txn):
@ -184,14 +185,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: StateFilter = StateFilter.all()
):
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
@ -202,14 +203,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
from the database.
Returns:
defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
Map from type/state_key to event ID.
"""
where_clause, where_args = state_filter.make_sql_filter_clause()
if not where_clause:
# We delegate to the cached version
return self.get_current_state_ids(room_id)
return await self.get_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(txn):
results = {}
@ -231,7 +232,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)

View file

@ -14,8 +14,7 @@
# limitations under the License.
import logging
from twisted.internet import defer
from typing import Any, Dict, List, Tuple
from synapse.storage._base import SQLBaseStore
@ -23,7 +22,9 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@ -37,12 +38,12 @@ class StateDeltasStore(SQLBaseStore):
if it's new state.
Args:
prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
Deferred[tuple[int, list[dict]]: A tuple consisting of:
A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return defer.succeed((max_stream_id, []))
return (max_stream_id, [])
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@ -102,7 +103,7 @@ class StateDeltasStore(SQLBaseStore):
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
@ -114,8 +115,8 @@ class StateDeltasStore(SQLBaseStore):
retcol="COALESCE(MAX(stream_id), -1)",
)
def get_max_stream_id_in_current_state_deltas(self):
return self.db_pool.runInteraction(
async def get_max_stream_id_in_current_state_deltas(self):
return await self.db_pool.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)

View file

@ -539,7 +539,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows, token
def get_room_event_before_stream_ordering(self, room_id: str, stream_ordering: int):
async def get_room_event_before_stream_ordering(
self, room_id: str, stream_ordering: int
) -> Tuple[int, int, str]:
"""Gets details of the first event in a room at or before a stream ordering
Args:
@ -547,8 +549,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering:
Returns:
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
A tuple of (stream ordering, topological ordering, event_id)
"""
def _f(txn):
@ -563,7 +564,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
return self.db_pool.runInteraction("get_room_event_before_stream_ordering", _f)
return await self.db_pool.runInteraction(
"get_room_event_before_stream_ordering", _f
)
async def get_room_events_max_id(self, room_id: Optional[str] = None) -> str:
"""Returns the current token for rooms stream.

View file

@ -17,8 +17,6 @@ import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@ -103,7 +101,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
async def get_state_group_delta(self, state_group):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@ -135,7 +133,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_state_group_delta", _get_state_group_delta_txn
)
@ -367,9 +365,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
fetched_keys=non_member_types,
)
def store_state_group(
async def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@ -383,7 +381,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to event_id.
Returns:
Deferred[int]: The state group ID
The state group ID
"""
def _store_state_group_txn(txn):
@ -484,11 +482,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_group
return self.db_pool.runInteraction("store_state_group", _store_state_group_txn)
return await self.db_pool.runInteraction(
"store_state_group", _store_state_group_txn
)
def purge_unreferenced_state_groups(
async def purge_unreferenced_state_groups(
self, room_id: str, state_groups_to_delete
) -> defer.Deferred:
) -> None:
"""Deletes no longer referenced state groups and de-deltas any state
groups that reference them.
@ -499,7 +499,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
to delete.
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_unreferenced_state_groups",
self._purge_unreferenced_state_groups,
room_id,
@ -594,7 +594,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return {row["state_group"]: row["prev_state_group"] for row in rows}
def purge_room_state(self, room_id, state_groups_to_delete):
async def purge_room_state(self, room_id, state_groups_to_delete):
"""Deletes all record of a room from state tables
Args:
@ -602,7 +602,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_groups_to_delete (list[int]): State groups to delete
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"purge_room_state",
self._purge_room_state_txn,
room_id,

View file

@ -333,7 +333,7 @@ class StateGroupStorage(object):
def __init__(self, hs, stores):
self.stores = stores
def get_state_group_delta(self, state_group: int):
async def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
@ -341,11 +341,11 @@ class StateGroupStorage(object):
state_group: The state group used to retrieve state deltas.
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
Tuple[Optional[int], Optional[StateMap[str]]]:
(prev_group, delta_ids)
"""
return self.stores.state.get_state_group_delta(state_group)
return await self.stores.state.get_state_group_delta(state_group)
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
@ -525,7 +525,7 @@ class StateGroupStorage(object):
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
@ -546,14 +546,14 @@ class StateGroupStorage(object):
"""
return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
async def store_state_group(
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
) -> int:
"""Store a new set of state, returning a newly assigned state group.
Args:
@ -567,8 +567,8 @@ class StateGroupStorage(object):
to event_id.
Returns:
Deferred[int]: The state group ID
The state group ID
"""
return self.stores.state.store_state_group(
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)