0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 21:43:50 +01:00

Convert storage layer to async/await. (#7963)

This commit is contained in:
Patrick Cloke 2020-07-28 16:09:53 -04:00 committed by GitHub
parent e866e3b896
commit 3345c166a4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 210 additions and 185 deletions

1
changelog.d/7963.misc Normal file
View file

@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View file

@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue() self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks async def persist_events(
def persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ) -> int:
""" """
Write events to the database Write events to the database
Args: Args:
@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc. which might update the current state etc.
Returns: Returns:
Deferred[int]: the stream ordering of the latest persisted event the stream ordering of the latest persisted event
""" """
partitioned = {} partitioned = {}
for event, ctx in events_and_contexts: for event, ctx in events_and_contexts:
@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned: for room_id in partitioned:
self._maybe_start_persisting(room_id) self._maybe_start_persisting(room_id)
yield make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True) defer.gatherResults(deferreds, consumeErrors=True)
) )
max_persisted_id = yield self.main_store.get_current_events_token() return self.main_store.get_current_events_token()
return max_persisted_id async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
@defer.inlineCallbacks ) -> Tuple[int, int]:
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
""" """
Returns: Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``, The stream ordering of `event`, and the stream ordering of the
and the stream ordering of the latest persisted event latest persisted event
""" """
deferred = self._event_persist_queue.add_to_queue( deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled event.room_id, [(event, context)], backfilled=backfilled
@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id) self._maybe_start_persisting(event.room_id)
yield make_deferred_yieldable(deferred) await make_deferred_yieldable(deferred)
max_persisted_id = yield self.main_store.get_current_events_token() max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id) return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str): def _maybe_start_persisting(self, room_id: str):
@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events( async def _persist_events(
self, self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]], events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False, backfilled: bool = False,
): ):
"""Calculates the change to current state and forward extremities, and """Calculates the change to current state and forward extremities, and
@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities( async def _calculate_new_extremities(
self, self,
room_id: str, room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]], event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str], latest_event_ids: List[str],
): ):
"""Calculates the new forward extremities for a room given events to """Calculates the new forward extremities for a room given events to
@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events( async def _get_new_state_after_events(
self, self,
room_id: str, room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]], events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str], old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str], new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]: ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined( async def _is_server_still_joined(
self, self,
room_id: str, room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]], ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState, delta: DeltaState,
current_state: Optional[StateMap[str]], current_state: Optional[StateMap[str]],
potentially_left_users: Set[str], potentially_left_users: Set[str],

View file

@ -15,8 +15,7 @@
import itertools import itertools
import logging import logging
from typing import Set
from twisted.internet import defer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores): def __init__(self, hs, stores):
self.stores = stores self.stores = stores
@defer.inlineCallbacks async def purge_room(self, room_id: str):
def purge_room(self, room_id: str):
"""Deletes all record of a room """Deletes all record of a room
""" """
state_groups_to_delete = yield self.stores.main.purge_room(room_id) state_groups_to_delete = await self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete) await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks async def purge_history(
def purge_history(self, room_id, token, delete_local_events): self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point """Deletes room history before a certain point
Args: Args:
room_id (str): room_id: The room ID
token (str): A topological token to delete events before token: A topological token to delete events before
delete_local_events (bool): delete_local_events:
if True, we will delete local events as well as remote ones if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their (instead of just marking them as outliers and deleting their
state groups). state groups).
""" """
state_groups = yield self.stores.main.purge_history( state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events room_id, token, delete_local_events
) )
logger.info("[purge] finding state groups that can be deleted") logger.info("[purge] finding state groups that can be deleted")
sg_to_delete = yield self._find_unreferenced_groups(state_groups) sg_to_delete = await self._find_unreferenced_groups(state_groups)
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete) await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
def _find_unreferenced_groups(self, state_groups):
"""Used when purging history to figure out which state groups can be """Used when purging history to figure out which state groups can be
deleted. deleted.
Args: Args:
state_groups (set[int]): Set of state groups referenced by events state_groups: Set of state groups referenced by events
that are going to be deleted. that are going to be deleted.
Returns: Returns:
Deferred[set[int]] The set of state groups that can be deleted. The set of state groups that can be deleted.
""" """
# Graph of state group -> previous group # Graph of state group -> previous group
graph = {} graph = {}
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100)) current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search next_to_search -= current_search
referenced = yield self.stores.main.get_referenced_state_groups( referenced = await self.stores.main.get_referenced_state_groups(
current_search current_search
) )
referenced_groups |= referenced referenced_groups |= referenced
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced. # groups that are referenced.
current_search -= referenced current_search -= referenced
edges = yield self.stores.state.get_previous_state_groups(current_search) edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values()) prevs = set(edges.values())
# We don't bother re-handling groups we've already seen # We don't bother re-handling groups we've already seen

View file

@ -14,13 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Iterable, List, TypeVar from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap from synapse.types import StateMap
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state. """A filter used when querying for state.
Attributes: Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or types: Map from type to set of state keys (or None). This specifies
None). This specifies which state_keys for the given type to fetch which state_keys for the given type to fetch from the DB. If None
from the DB. If None then all events with that type are fetched. If then all events with that type are fetched. If the set is empty
the set is empty then no events with that type are fetched. then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not include_others: Whether to fetch events with types that do not
appear in `types`. appear in `types`.
""" """
types = attr.ib() types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False) include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self): def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None} self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod @staticmethod
def all(): def all() -> "StateFilter":
"""Creates a filter that fetches everything. """Creates a filter that fetches everything.
Returns: Returns:
StateFilter The new state filter.
""" """
return StateFilter(types={}, include_others=True) return StateFilter(types={}, include_others=True)
@staticmethod @staticmethod
def none(): def none() -> "StateFilter":
"""Creates a filter that fetches nothing. """Creates a filter that fetches nothing.
Returns: Returns:
StateFilter The new state filter.
""" """
return StateFilter(types={}, include_others=False) return StateFilter(types={}, include_others=False)
@staticmethod @staticmethod
def from_types(types): def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types """Creates a filter that only fetches the given types
Args: Args:
types (Iterable[tuple[str, str|None]]): A list of type and state types: A list of type and state keys to fetch. A state_key of None
keys to fetch. A state_key of None fetches everything for fetches everything for that type
that type
Returns: Returns:
StateFilter The new state filter.
""" """
type_dict = {} type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types: for typ, s in types:
if typ in type_dict: if typ in type_dict:
if type_dict[typ] is None: if type_dict[typ] is None:
@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None type_dict[typ] = None
continue continue
type_dict.setdefault(typ, set()).add(s) type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict) return StateFilter(types=type_dict)
@staticmethod @staticmethod
def from_lazy_load_member_list(members): def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member """Creates a filter that returns all non-member events, plus the member
events for the given users events for the given users
Args: Args:
members (iterable[str]): Set of user IDs members: Set of user IDs
Returns: Returns:
StateFilter The new state filter
""" """
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True) return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self): def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed """Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the (except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass current one, i.e. anything that passes the current filter will pass
@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events return all non-member events
Returns: Returns:
StateFilter The new state filter.
""" """
if self.is_full(): if self.is_full():
@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True, include_others=True,
) )
def make_sql_filter_clause(self): def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause. """Converts the filter to an SQL clause.
For example: For example:
@ -179,13 +177,12 @@ class StateFilter(object):
Returns: Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An The SQL string (may be empty) and arguments. An empty SQL string is
empty SQL string is returned when the filter matches everything returned when the filter matches everything (i.e. is "full").
(i.e. is "full").
""" """
where_clause = "" where_clause = ""
where_args = [] where_args = [] # type: List[str]
if self.is_full(): if self.is_full():
return where_clause, where_args return where_clause, where_args
@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args return where_clause, where_args
def max_entries_returned(self): def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if """Returns the maximum number of entries this filter will return if
known, otherwise returns None. known, otherwise returns None.
@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state return filtered_state
def is_full(self): def is_full(self) -> bool:
"""Whether this filter fetches everything or not """Whether this filter fetches everything or not
Returns: Returns:
bool True if the filter fetches everything.
""" """
return self.include_others and not self.types return self.include_others and not self.types
def has_wildcards(self): def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch """Whether the filter includes wildcards or is attempting to fetch
specific state. specific state.
Returns: Returns:
bool True if the filter includes wildcards.
""" """
return self.include_others or any( return self.include_others or any(
state_keys is None for state_keys in self.types.values() state_keys is None for state_keys in self.types.values()
) )
def concrete_types(self): def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that """Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards` will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty). returns False, but otherwise will be a subset (or even empty).
Returns: Returns:
list[tuple[str,str]] A list of type/state_keys tuples.
""" """
return [ return [
(t, s) (t, s)
@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys for s in state_keys
] ]
def get_member_split(self): def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively """Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching matching against member state, and one which assumes it's matching
against non member state. against non member state.
@ -307,7 +304,7 @@ class StateFilter(object):
state caches). state caches).
Returns: Returns:
tuple[StateFilter, StateFilter]: The member and non member filters The member and non member filters
""" """
if EventTypes.Member in self.types: if EventTypes.Member in self.types:
@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between """Given a state group try to return a previous group and a delta between
the old and the new. the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns: Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]: Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids) (prev_group, delta_ids)
@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group) return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks async def get_state_groups_ids(
def get_state_groups_ids(self, _room_id, event_ids): self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
Args: Args:
_room_id (str): id of the room for these events _room_id: id of the room for these events
event_ids (iterable[str]): ids of the events event_ids: ids of the events
Returns: Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
""" """
if not event_ids: if not event_ids:
return {} return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups) group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state return group_to_state
@defer.inlineCallbacks async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
def get_state_ids_for_group(self, state_group):
"""Get the event IDs of all the state in the given state group """Get the event IDs of all the state in the given state group
Args: Args:
state_group (int) state_group: A state group for which we want to get the state IDs.
Returns: Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id Resolves to a map of (type, state_key) -> event_id
""" """
group_to_state = yield self._get_state_for_groups((state_group,)) group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group] return group_to_state[state_group]
@defer.inlineCallbacks async def get_state_groups(
def get_state_groups(self, room_id, event_ids): self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids """ Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns: Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events. dict of state_group_id -> list of state events.
""" """
if not event_ids: if not event_ids:
return {} return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids) group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
[ [
ev_id ev_id
for group_ids in group_to_ids.values() for group_ids in group_to_ids.values()
@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query groups: list of state group IDs to query
state_filter: The state filter used to fetch state state_filter: The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map. Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
""" """
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks async def get_state_for_events(
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()): self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
Args: Args:
event_ids (list[string]) event_ids: The events to fetch the state of.
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state.
from the database.
Returns: Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events] A dict of (event_id) -> (type, state_key) -> [state_events]
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
state_event_map = yield self.stores.main.get_events( state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()], [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False, get_prev_content=False,
) )
@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks async def get_state_ids_for_events(
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()): self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves) of the state events (as opposed to the events themselves)
Args: Args:
event_ids(list(str)): events whose state should be returned event_ids: events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from event_id -> (type, state_key) -> event_id A dict from event_id -> (type, state_key) -> event_id
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
@ -491,36 +498,36 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks async def get_state_for_event(
def get_state_for_event(self, event_id, state_filter=StateFilter.all()): self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id: event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_for_events([event_id], state_filter) state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
@defer.inlineCallbacks async def get_state_ids_for_event(
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()): self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
Args: Args:
event_id(str): event whose state should be returned event_id: event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state state_filter: The state filter used to fetch state from the database.
from the database.
Returns: Returns:
A deferred dict from (type, state_key) -> state_event A deferred dict from (type, state_key) -> state_event
""" """
state_map = yield self.get_state_ids_for_events([event_id], state_filter) state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id] return state_map[event_id]
def _get_state_for_groups( def _get_state_for_groups(
@ -530,9 +537,8 @@ class StateGroupStorage(object):
filtering by type/state_key filtering by type/state_key
Args: Args:
groups (iterable[int]): list of state groups for which we want groups: list of state groups for which we want to get the state.
to get the state. state_filter: The state filter used to fetch state.
state_filter (StateFilter): The state filter used to fetch state
from the database. from the database.
Returns: Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map. Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@ -540,18 +546,23 @@ class StateGroupStorage(object):
return self.stores.state._get_state_for_groups(groups, state_filter) return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group( def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
): ):
"""Store a new set of state, returning a newly assigned state group. """Store a new set of state, returning a newly assigned state group.
Args: Args:
event_id (str): The event ID for which the state was calculated event_id: The event ID for which the state was calculated.
room_id (str) room_id: ID of the room for which the state was calculated.
prev_group (int|None): A previous state group for the room, optional. prev_group: A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as `current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`. `current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key) current_state_ids: The state to store. Map of (type, state_key)
to event_id. to event_id.
Returns: Returns:

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
from synapse.rest.client.v1 import room from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event) event = self.successResultOf(event)
# Purge everything before this topological token # Purge everything before this topological token
purge = storage.purge_events.purge_history(self.room_id, event, True) purge = defer.ensureDeferred(
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump() self.pump()
self.assertEqual(self.successResultOf(purge), None) self.assertEqual(self.successResultOf(purge), None)
@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
) )
# Purge everything before this topological token # Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True) purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump() self.pump()
f = self.failureResultOf(purge) f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0]) self.assertIn("greater than forward", f.value.args[0])

View file

@ -97,9 +97,11 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def inject_room_event(self, **kwargs): def inject_room_event(self, **kwargs):
yield self.storage.persistence.persist_event( yield defer.ensureDeferred(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
) )
)
@defer.inlineCallbacks @defer.inlineCallbacks
def STALE_test_room_name(self): def STALE_test_room_name(self):

View file

@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event
@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.storage.state.get_state_groups_ids( state_group_map = yield defer.ensureDeferred(
self.room, [e2.event_id] self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0] state_map = list(state_group_map.values())[0]
@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"} self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
) )
state_group_map = yield self.storage.state.get_state_groups( state_group_map = yield defer.ensureDeferred(
self.room, [e2.event_id] self.storage.state.get_state_groups(self.room, [e2.event_id])
) )
self.assertEqual(len(state_group_map), 1) self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0] state_list = list(state_group_map.values())[0]
@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we get the full state as of the final event # check we get the full state as of the final event
state = yield self.storage.state.get_state_for_event(e5.event_id) state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)
self.assertIsNotNone(e4) self.assertIsNotNone(e4)
@ -164,23 +168,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check we can filter to the m.room.name event (with a '' state key) # check we can filter to the m.room.name event (with a '' state key)
state = yield self.storage.state.get_state_for_event( state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")]) e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
) )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key) # check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event( state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
) )
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state) self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key) # check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event( state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)]) e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
) )
)
self.assertStateMapEqual( self.assertStateMapEqual(
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state {(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
@ -188,13 +198,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the # check we can grab a specific room member without filtering out the
# other event types # other event types
state = yield self.storage.state.get_state_for_event( state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}}, types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True, include_others=True,
), ),
) )
)
self.assertStateMapEqual( self.assertStateMapEqual(
{ {
@ -206,12 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
) )
# check that we can grab everything except members # check that we can grab everything except members
state = yield self.storage.state.get_state_for_event( state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, e5.event_id,
state_filter=StateFilter( state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True types={EventTypes.Member: set()}, include_others=True
), ),
) )
)
self.assertStateMapEqual( self.assertStateMapEqual(
{(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state {(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
####################################################### #######################################################
room_id = self.room.to_string() room_id = self.room.to_string()
group_ids = yield self.storage.state.get_state_groups_ids( group_ids = yield defer.ensureDeferred(
room_id, [e5.event_id] self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
) )
group = list(group_ids.keys())[0] group = list(group_ids.keys())[0]

View file

@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore() self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage() self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM") yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks @defer.inlineCallbacks
def test_filtering(self): def test_filtering(self):
@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield defer.ensureDeferred( event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks
@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks
@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
yield self.storage.persistence.persist_event(event, context) yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event return event
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -638,14 +638,8 @@ class DeferredMockCallable(object):
) )
@defer.inlineCallbacks async def create_room(hs, room_id: str, creator_id: str):
def create_room(hs, room_id, creator_id):
"""Creates and persist a creation event for the given room """Creates and persist a creation event for the given room
Args:
hs
room_id (str)
creator_id (str)
""" """
persistence_store = hs.get_storage().persistence persistence_store = hs.get_storage().persistence
@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory() event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler() event_creation_handler = hs.get_event_creation_handler()
yield store.store_room( await store.store_room(
room_id=room_id, room_id=room_id,
room_creator_user_id=creator_id, room_creator_user_id=creator_id,
is_public=False, is_public=False,
@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id):
}, },
) )
event, context = yield defer.ensureDeferred( event, context = await event_creation_handler.create_new_client_event(builder)
event_creation_handler.create_new_client_event(builder)
)
yield persistence_store.persist_event(event, context) await persistence_store.persist_event(event, context)

View file

@ -206,6 +206,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \ synapse/storage/database.py \
synapse/storage/engines \ synapse/storage/engines \
synapse/storage/state.py \
synapse/storage/util \ synapse/storage/util \
synapse/streams \ synapse/streams \
synapse/util/caches/stream_change_cache.py \ synapse/util/caches/stream_change_cache.py \