0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-21 01:22:02 +01:00

Convert receipts and events databases to async/await. (#8076)

This commit is contained in:
Patrick Cloke 2020-08-14 10:05:19 -04:00 committed by GitHub
parent dc22090a67
commit e8861957d9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 80 additions and 82 deletions

1
changelog.d/8076.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -17,13 +17,11 @@
import itertools import itertools
import logging import logging
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
@ -113,15 +111,14 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name() hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master" ), "Can only instantiate EventsStore on master"
@defer.inlineCallbacks async def _persist_events_and_state_updates(
def _persist_events_and_state_updates(
self, self,
events_and_contexts: List[Tuple[EventBase, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]], current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState], state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]], new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False, backfilled: bool = False,
): ) -> None:
"""Persist a set of events alongside updates to the current state and """Persist a set of events alongside updates to the current state and
forward extremities tables. forward extremities tables.
@ -136,7 +133,7 @@ class PersistEventsStore:
backfilled backfilled
Returns: Returns:
Deferred: resolves when the events have been persisted Resolves when the events have been persisted
""" """
# We want to calculate the stream orderings as late as possible, as # We want to calculate the stream orderings as late as possible, as
@ -168,7 +165,7 @@ class PersistEventsStore:
for (event, context), stream in zip(events_and_contexts, stream_orderings): for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream event.internal_metadata.stream_ordering = stream
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"persist_events", "persist_events",
self._persist_events_txn, self._persist_events_txn,
events_and_contexts=events_and_contexts, events_and_contexts=events_and_contexts,
@ -206,16 +203,15 @@ class PersistEventsStore:
(room_id,), list(latest_event_ids) (room_id,), list(latest_event_ids)
) )
@defer.inlineCallbacks async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
def _get_events_which_are_prevs(self, event_ids):
"""Filter the supplied list of event_ids to get those which are prev_events of """Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events. existing (non-outlier/rejected) events.
Args: Args:
event_ids (Iterable[str]): event ids to filter event_ids: event ids to filter
Returns: Returns:
Deferred[List[str]]: filtered event ids Filtered event ids
""" """
results = [] results = []
@ -240,14 +236,13 @@ class PersistEventsStore:
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed")) results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
) )
return results return results
@defer.inlineCallbacks async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
def _get_prevs_before_rejected(self, event_ids):
"""Get soft-failed ancestors to remove from the extremities. """Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or Given a set of events, find all those that have been soft-failed or
@ -259,11 +254,11 @@ class PersistEventsStore:
are separated by soft failed events. are separated by soft failed events.
Args: Args:
event_ids (Iterable[str]): Events to find prev events for. Note event_ids: Events to find prev events for. Note that these must have
that these must have already been persisted. already been persisted.
Returns: Returns:
Deferred[set[str]] The previous events.
""" """
# The set of event_ids to return. This includes all soft-failed events # The set of event_ids to return. This includes all soft-failed events
@ -304,7 +299,7 @@ class PersistEventsStore:
existing_prevs.add(prev_event_id) existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100): for chunk in batch_iter(event_ids, 100):
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
) )

View file

@ -15,8 +15,6 @@
import logging import logging
from twisted.internet import defer
from synapse.api.constants import EventContentFields from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
@ -94,8 +92,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
where_clause="NOT have_censored", where_clause="NOT have_censored",
) )
@defer.inlineCallbacks async def _background_reindex_fields_sender(self, progress, batch_size):
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
@ -155,19 +152,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows) return len(rows)
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
) )
if not result: if not result:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
) )
return result return result
@defer.inlineCallbacks async def _background_reindex_origin_server_ts(self, progress, batch_size):
def _background_reindex_origin_server_ts(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"] target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"] max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0) rows_inserted = progress.get("rows_inserted", 0)
@ -234,19 +230,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows_to_update) return len(rows_to_update)
result = yield self.db_pool.runInteraction( result = await self.db_pool.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
) )
if not result: if not result:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.EVENT_ORIGIN_SERVER_TS_NAME self.EVENT_ORIGIN_SERVER_TS_NAME
) )
return result return result
@defer.inlineCallbacks async def _cleanup_extremities_bg_update(self, progress, batch_size):
def _cleanup_extremities_bg_update(self, progress, batch_size):
"""Background update to clean out extremities that should have been """Background update to clean out extremities that should have been
deleted previously. deleted previously.
@ -414,26 +409,25 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(original_set) return len(original_set)
num_handled = yield self.db_pool.runInteraction( num_handled = await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
) )
if not num_handled: if not num_handled:
yield self.db_pool.updates._end_background_update( await self.db_pool.updates._end_background_update(
self.DELETE_SOFT_FAILED_EXTREMITIES self.DELETE_SOFT_FAILED_EXTREMITIES
) )
def _drop_table_txn(txn): def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check") txn.execute("DROP TABLE _extremities_to_check")
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
) )
return num_handled return num_handled
@defer.inlineCallbacks async def _redactions_received_ts(self, progress, batch_size):
def _redactions_received_ts(self, progress, batch_size):
"""Handles filling out the `received_ts` column in redactions. """Handles filling out the `received_ts` column in redactions.
""" """
last_event_id = progress.get("last_event_id", "") last_event_id = progress.get("last_event_id", "")
@ -480,17 +474,16 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return len(rows) return len(rows)
count = yield self.db_pool.runInteraction( count = await self.db_pool.runInteraction(
"_redactions_received_ts", _redactions_received_ts_txn "_redactions_received_ts", _redactions_received_ts_txn
) )
if not count: if not count:
yield self.db_pool.updates._end_background_update("redactions_received_ts") await self.db_pool.updates._end_background_update("redactions_received_ts")
return count return count
@defer.inlineCallbacks async def _event_fix_redactions_bytes(self, progress, batch_size):
def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON. """Undoes hex encoded censored redacted event JSON.
""" """
@ -511,16 +504,15 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
txn.execute("DROP INDEX redactions_censored_redacts") txn.execute("DROP INDEX redactions_censored_redacts")
yield self.db_pool.runInteraction( await self.db_pool.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
) )
yield self.db_pool.updates._end_background_update("event_fix_redactions_bytes") await self.db_pool.updates._end_background_update("event_fix_redactions_bytes")
return 1 return 1
@defer.inlineCallbacks async def _event_store_labels(self, progress, batch_size):
def _event_store_labels(self, progress, batch_size):
"""Background update handler which will store labels for existing events.""" """Background update handler which will store labels for existing events."""
last_event_id = progress.get("last_event_id", "") last_event_id = progress.get("last_event_id", "")
@ -575,11 +567,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
return nbrows return nbrows
num_rows = yield self.db_pool.runInteraction( num_rows = await self.db_pool.runInteraction(
desc="event_store_labels", func=_event_store_labels_txn desc="event_store_labels", func=_event_store_labels_txn
) )
if not num_rows: if not num_rows:
yield self.db_pool.updates._end_background_update("event_store_labels") await self.db_pool.updates._end_background_update("event_store_labels")
return num_rows return num_rows

View file

@ -16,7 +16,7 @@
import abc import abc
import logging import logging
from typing import List, Tuple from typing import List, Optional, Tuple
from twisted.internet import defer from twisted.internet import defer
@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -56,9 +56,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
""" """
raise NotImplementedError() raise NotImplementedError()
@cachedInlineCallbacks() @cached()
def get_users_with_read_receipts_in_room(self, room_id): async def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read") receipts = await self.get_receipts_for_room(room_id, "m.read")
return {r["user_id"] for r in receipts} return {r["user_id"] for r in receipts}
@cached(num_args=2) @cached(num_args=2)
@ -84,9 +84,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
@cachedInlineCallbacks(num_args=2) @cached(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type): async def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self.db_pool.simple_select_list( rows = await self.db_pool.simple_select_list(
table="receipts_linearized", table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type}, keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"), retcols=("room_id", "event_id"),
@ -95,8 +95,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return {row["room_id"]: row["event_id"] for row in rows} return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
def f(txn): def f(txn):
sql = ( sql = (
"SELECT rl.room_id, rl.event_id," "SELECT rl.room_id, rl.event_id,"
@ -110,7 +109,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,)) txn.execute(sql, (user_id,))
return txn.fetchall() return txn.fetchall()
rows = yield self.db_pool.runInteraction( rows = await self.db_pool.runInteraction(
"get_receipts_for_user_with_orderings", f "get_receipts_for_user_with_orderings", f
) )
return { return {
@ -122,56 +121,61 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows for row in rows
} }
@defer.inlineCallbacks async def get_linearized_receipts_for_rooms(
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for multiple rooms for sending to clients. """Get receipts for multiple rooms for sending to clients.
Args: Args:
room_ids (list): List of room_ids. room_ids: List of room_ids.
to_key (int): Max stream id to fetch receipts up to. to_key: Max stream id to fetch receipts up to.
from_key (int): Min stream id to fetch receipts from. None fetches from_key: Min stream id to fetch receipts from. None fetches
from the start. from the start.
Returns: Returns:
list: A list of receipts. A list of receipts.
""" """
room_ids = set(room_ids) room_ids = set(room_ids)
if from_key is not None: if from_key is not None:
# Only ask the database about rooms where there have been new # Only ask the database about rooms where there have been new
# receipts added since `from_key` # receipts added since `from_key`
room_ids = yield self._receipts_stream_cache.get_entities_changed( room_ids = self._receipts_stream_cache.get_entities_changed(
room_ids, from_key room_ids, from_key
) )
results = yield self._get_linearized_receipts_for_rooms( results = await self._get_linearized_receipts_for_rooms(
room_ids, to_key, from_key=from_key room_ids, to_key, from_key=from_key
) )
return [ev for res in results.values() for ev in res] return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): async def get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
Args: Args:
room_id (str): The room id. room_id: The room id.
to_key (int): Max stream id to fetch receipts up to. to_key: Max stream id to fetch receipts up to.
from_key (int): Min stream id to fetch receipts from. None fetches from_key: Min stream id to fetch receipts from. None fetches
from the start. from the start.
Returns: Returns:
Deferred[list]: A list of receipts. A list of receipts.
""" """
if from_key is not None: if from_key is not None:
# Check the cache first to see if any new receipts have been added # Check the cache first to see if any new receipts have been added
# since `from_key`. If not we can no-op. # since `from_key`. If not we can no-op.
if not self._receipts_stream_cache.has_entity_changed(room_id, from_key): if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
defer.succeed([]) return []
return self._get_linearized_receipts_for_room(room_id, to_key, from_key) return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)
@cachedInlineCallbacks(num_args=3, tree=True) @cached(num_args=3, tree=True)
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): async def _get_linearized_receipts_for_room(
self, room_id: str, to_key: int, from_key: Optional[int] = None
) -> List[dict]:
"""See get_linearized_receipts_for_room """See get_linearized_receipts_for_room
""" """
@ -195,7 +199,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return rows return rows
rows = yield self.db_pool.runInteraction("get_linearized_receipts_for_room", f) rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)
if not rows: if not rows:
return [] return []
@ -345,7 +349,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
def _invalidate_get_users_with_receipts_in_room( def _invalidate_get_users_with_receipts_in_room(
self, room_id, receipt_type, user_id self, room_id: str, receipt_type: str, user_id: str
): ):
if receipt_type != "m.read": if receipt_type != "m.read":
return return
@ -471,15 +475,21 @@ class ReceiptsStore(ReceiptsWorkerStore):
return rx_ts return rx_ts
@defer.inlineCallbacks async def insert_receipt(
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): self,
room_id: str,
receipt_type: str,
user_id: str,
event_ids: List[str],
data: dict,
) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server. """Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph Automatically does conversion between linearized and graph
representations. representations.
""" """
if not event_ids: if not event_ids:
return return None
if len(event_ids) == 1: if len(event_ids) == 1:
linearized_event_id = event_ids[0] linearized_event_id = event_ids[0]
@ -506,13 +516,13 @@ class ReceiptsStore(ReceiptsWorkerStore):
else: else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
linearized_event_id = yield self.db_pool.runInteraction( linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv", graph_to_linear "insert_receipt_conv", graph_to_linear
) )
stream_id_manager = self._receipts_id_gen.get_next() stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id: with stream_id_manager as stream_id:
event_ts = yield self.db_pool.runInteraction( event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt", "insert_linearized_receipt",
self.insert_linearized_receipt_txn, self.insert_linearized_receipt_txn,
room_id, room_id,
@ -534,7 +544,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
now - event_ts, now - event_ts,
) )
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data) await self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token() max_persisted_id = self._receipts_id_gen.get_current_token()