0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-01 10:18:54 +02:00

Convert events worker database to async/await. (#8071)

This commit is contained in:
Patrick Cloke 2020-08-18 16:20:49 -04:00 committed by GitHub
parent acfb7c3b5d
commit f40645e60b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 106 additions and 97 deletions

1
changelog.d/8071.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
auth_events: the existing room state.
Raises:
AuthError if the checks fail

View file

@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@ -2155,9 +2151,9 @@ class FederationHandler(BaseHandler):
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_auth_events = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
(e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@ -2173,9 +2169,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.

View file

@ -960,7 +960,7 @@ class EventCreationHandler(object):
allow_none=True,
)
is_admin_redaction = (
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)
@ -1080,8 +1080,8 @@ class EventCreationHandler(object):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]

View file

@ -716,7 +716,7 @@ class RoomMemberHandler(object):
guest_access = await self.store.get_event(guest_access_id)
return (
return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content

View file

@ -51,5 +51,5 @@ class SpamCheckerApi(object):
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()

View file

@ -641,7 +641,7 @@ class StateResolutionStore(object):
allow_rejected (bool): If True return rejected events.
Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
return self.store.get_events(

View file

@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
async def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of events
"""
return self.get_auth_chain_ids(
event_ids = await self.get_auth_chain_ids(
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
)
return await self.get_events_as_list(event_ids)
def get_auth_chain_ids(
self,
@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list (list)
limit (int)
"""
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
events = await self.get_events_as_list(ids)
return events
return await self.get_events_as_list(ids)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):

View file

@ -19,9 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, overload
from constantly import NamedConstant, Names
from typing_extensions import Literal
from twisted.internet import defer
@ -32,7 +33,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
from synapse.events import make_event_from_dict
from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
@defer.inlineCallbacks
def get_event(
# Inform mypy that if allow_none is False (the default) then get_event
# always returns an EventBase.
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
) -> EventBase:
...
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
) -> Optional[EventBase]:
...
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
):
) -> Optional[EventBase]:
"""Get an event from the database by event_id.
Args:
@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
Deferred[EventBase|None]
The event, or None if the event was not found.
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
return event
@defer.inlineCallbacks
def get_events(
async def get_events(
self,
event_ids: List[str],
event_ids: Iterable[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
) -> Dict[str, EventBase]:
"""Get events from the database
Args:
@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejeted events from the response.
Returns:
Deferred : Dict from event_id to event.
A mapping from event_id to event.
"""
events = yield self.get_events_as_list(
events = await self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
@defer.inlineCallbacks
def get_events_as_list(
async def get_events_as_list(
self,
event_ids: List[str],
event_ids: Collection[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
) -> List[EventBase]:
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
List of events fetched from the database. The events are in the same
order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
event_entry_map = yield self._get_events_from_cache_or_db(
event_entry_map = await self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = yield self.get_event(
prev = await self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
return events
@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = yield self._get_events_from_db(
missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _get_events_from_db(self, event_ids, allow_rejected=False):
async def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Deferred[Dict[str, _EventCacheEntry]]:
Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
"""
@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
row_map = yield self._enqueue_events(events_to_fetch)
row_map = await self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@ -650,8 +672,7 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
@defer.inlineCallbacks
def _enqueue_events(self, events):
async def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@ -660,7 +681,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
Dict[str, Dict]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@ -683,7 +704,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
row_map = yield events_d
row_map = await events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@ -842,33 +863,29 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
async def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
return {r["event_id"] for r in rows}
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
async def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
Deferred[set[str]]: The events we have already seen.
set[str]: The events we have already seen.
"""
results = set()
@ -884,7 +901,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@ -914,8 +931,7 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
@defer.inlineCallbacks
def get_room_complexity(self, room_id):
async def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@ -926,9 +942,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
Deferred[dict[str:int]] of complexity version to complexity.
dict[str:int] of complexity version to complexity.
"""
state_events = yield self.get_current_state_event_counts(room_id)
state_events = await self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@ -1165,9 +1181,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
res = yield self.db_pool.simple_select_one(
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
res = await self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},

View file

@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:

View file

@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=defer.succeed({"123": mock_event})
return_value=make_awaitable({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

View file

@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)