mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 21:43:50 +01:00
Convert storage layer to async/await. (#7963)
This commit is contained in:
parent
e866e3b896
commit
3345c166a4
10 changed files with 210 additions and 185 deletions
1
changelog.d/7963.misc
Normal file
1
changelog.d/7963.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Convert various parts of the codebase to async/await.
|
|
@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import FrozenEvent
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
|
||||||
self._event_persist_queue = _EventPeristenceQueue()
|
self._event_persist_queue = _EventPeristenceQueue()
|
||||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def persist_events(
|
||||||
def persist_events(
|
|
||||||
self,
|
self,
|
||||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||||
backfilled: bool = False,
|
backfilled: bool = False,
|
||||||
):
|
) -> int:
|
||||||
"""
|
"""
|
||||||
Write events to the database
|
Write events to the database
|
||||||
Args:
|
Args:
|
||||||
|
@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
|
||||||
which might update the current state etc.
|
which might update the current state etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[int]: the stream ordering of the latest persisted event
|
the stream ordering of the latest persisted event
|
||||||
"""
|
"""
|
||||||
partitioned = {}
|
partitioned = {}
|
||||||
for event, ctx in events_and_contexts:
|
for event, ctx in events_and_contexts:
|
||||||
|
@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
|
||||||
for room_id in partitioned:
|
for room_id in partitioned:
|
||||||
self._maybe_start_persisting(room_id)
|
self._maybe_start_persisting(room_id)
|
||||||
|
|
||||||
yield make_deferred_yieldable(
|
await make_deferred_yieldable(
|
||||||
defer.gatherResults(deferreds, consumeErrors=True)
|
defer.gatherResults(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
return self.main_store.get_current_events_token()
|
||||||
|
|
||||||
return max_persisted_id
|
async def persist_event(
|
||||||
|
self, event: EventBase, context: EventContext, backfilled: bool = False
|
||||||
@defer.inlineCallbacks
|
) -> Tuple[int, int]:
|
||||||
def persist_event(
|
|
||||||
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
|
The stream ordering of `event`, and the stream ordering of the
|
||||||
and the stream ordering of the latest persisted event
|
latest persisted event
|
||||||
"""
|
"""
|
||||||
deferred = self._event_persist_queue.add_to_queue(
|
deferred = self._event_persist_queue.add_to_queue(
|
||||||
event.room_id, [(event, context)], backfilled=backfilled
|
event.room_id, [(event, context)], backfilled=backfilled
|
||||||
|
@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
|
||||||
|
|
||||||
self._maybe_start_persisting(event.room_id)
|
self._maybe_start_persisting(event.room_id)
|
||||||
|
|
||||||
yield make_deferred_yieldable(deferred)
|
await make_deferred_yieldable(deferred)
|
||||||
|
|
||||||
max_persisted_id = yield self.main_store.get_current_events_token()
|
max_persisted_id = self.main_store.get_current_events_token()
|
||||||
return (event.internal_metadata.stream_ordering, max_persisted_id)
|
return (event.internal_metadata.stream_ordering, max_persisted_id)
|
||||||
|
|
||||||
def _maybe_start_persisting(self, room_id: str):
|
def _maybe_start_persisting(self, room_id: str):
|
||||||
|
@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
|
||||||
|
|
||||||
async def _persist_events(
|
async def _persist_events(
|
||||||
self,
|
self,
|
||||||
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
|
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||||
backfilled: bool = False,
|
backfilled: bool = False,
|
||||||
):
|
):
|
||||||
"""Calculates the change to current state and forward extremities, and
|
"""Calculates the change to current state and forward extremities, and
|
||||||
|
@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
|
||||||
async def _calculate_new_extremities(
|
async def _calculate_new_extremities(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_contexts: List[Tuple[FrozenEvent, EventContext]],
|
event_contexts: List[Tuple[EventBase, EventContext]],
|
||||||
latest_event_ids: List[str],
|
latest_event_ids: List[str],
|
||||||
):
|
):
|
||||||
"""Calculates the new forward extremities for a room given events to
|
"""Calculates the new forward extremities for a room given events to
|
||||||
|
@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
|
||||||
async def _get_new_state_after_events(
|
async def _get_new_state_after_events(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
events_context: List[Tuple[FrozenEvent, EventContext]],
|
events_context: List[Tuple[EventBase, EventContext]],
|
||||||
old_latest_event_ids: Iterable[str],
|
old_latest_event_ids: Iterable[str],
|
||||||
new_latest_event_ids: Iterable[str],
|
new_latest_event_ids: Iterable[str],
|
||||||
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
|
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
|
||||||
|
@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
|
||||||
async def _is_server_still_joined(
|
async def _is_server_still_joined(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
|
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
|
||||||
delta: DeltaState,
|
delta: DeltaState,
|
||||||
current_state: Optional[StateMap[str]],
|
current_state: Optional[StateMap[str]],
|
||||||
potentially_left_users: Set[str],
|
potentially_left_users: Set[str],
|
||||||
|
|
|
@ -15,8 +15,7 @@
|
||||||
|
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
|
||||||
def __init__(self, hs, stores):
|
def __init__(self, hs, stores):
|
||||||
self.stores = stores
|
self.stores = stores
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def purge_room(self, room_id: str):
|
||||||
def purge_room(self, room_id: str):
|
|
||||||
"""Deletes all record of a room
|
"""Deletes all record of a room
|
||||||
"""
|
"""
|
||||||
|
|
||||||
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
|
state_groups_to_delete = await self.stores.main.purge_room(room_id)
|
||||||
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def purge_history(
|
||||||
def purge_history(self, room_id, token, delete_local_events):
|
self, room_id: str, token: str, delete_local_events: bool
|
||||||
|
) -> None:
|
||||||
"""Deletes room history before a certain point
|
"""Deletes room history before a certain point
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str):
|
room_id: The room ID
|
||||||
|
|
||||||
token (str): A topological token to delete events before
|
token: A topological token to delete events before
|
||||||
|
|
||||||
delete_local_events (bool):
|
delete_local_events:
|
||||||
if True, we will delete local events as well as remote ones
|
if True, we will delete local events as well as remote ones
|
||||||
(instead of just marking them as outliers and deleting their
|
(instead of just marking them as outliers and deleting their
|
||||||
state groups).
|
state groups).
|
||||||
"""
|
"""
|
||||||
state_groups = yield self.stores.main.purge_history(
|
state_groups = await self.stores.main.purge_history(
|
||||||
room_id, token, delete_local_events
|
room_id, token, delete_local_events
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("[purge] finding state groups that can be deleted")
|
logger.info("[purge] finding state groups that can be deleted")
|
||||||
|
|
||||||
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
|
sg_to_delete = await self._find_unreferenced_groups(state_groups)
|
||||||
|
|
||||||
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
|
||||||
def _find_unreferenced_groups(self, state_groups):
|
|
||||||
"""Used when purging history to figure out which state groups can be
|
"""Used when purging history to figure out which state groups can be
|
||||||
deleted.
|
deleted.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_groups (set[int]): Set of state groups referenced by events
|
state_groups: Set of state groups referenced by events
|
||||||
that are going to be deleted.
|
that are going to be deleted.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[set[int]] The set of state groups that can be deleted.
|
The set of state groups that can be deleted.
|
||||||
"""
|
"""
|
||||||
# Graph of state group -> previous group
|
# Graph of state group -> previous group
|
||||||
graph = {}
|
graph = {}
|
||||||
|
@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
|
||||||
current_search = set(itertools.islice(next_to_search, 100))
|
current_search = set(itertools.islice(next_to_search, 100))
|
||||||
next_to_search -= current_search
|
next_to_search -= current_search
|
||||||
|
|
||||||
referenced = yield self.stores.main.get_referenced_state_groups(
|
referenced = await self.stores.main.get_referenced_state_groups(
|
||||||
current_search
|
current_search
|
||||||
)
|
)
|
||||||
referenced_groups |= referenced
|
referenced_groups |= referenced
|
||||||
|
@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
|
||||||
# groups that are referenced.
|
# groups that are referenced.
|
||||||
current_search -= referenced
|
current_search -= referenced
|
||||||
|
|
||||||
edges = yield self.stores.state.get_previous_state_groups(current_search)
|
edges = await self.stores.state.get_previous_state_groups(current_search)
|
||||||
|
|
||||||
prevs = set(edges.values())
|
prevs = set(edges.values())
|
||||||
# We don't bother re-handling groups we've already seen
|
# We don't bother re-handling groups we've already seen
|
||||||
|
|
|
@ -14,13 +14,12 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Iterable, List, TypeVar
|
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
|
from synapse.events import EventBase
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -34,16 +33,16 @@ class StateFilter(object):
|
||||||
"""A filter used when querying for state.
|
"""A filter used when querying for state.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
types (dict[str, set[str]|None]): Map from type to set of state keys (or
|
types: Map from type to set of state keys (or None). This specifies
|
||||||
None). This specifies which state_keys for the given type to fetch
|
which state_keys for the given type to fetch from the DB. If None
|
||||||
from the DB. If None then all events with that type are fetched. If
|
then all events with that type are fetched. If the set is empty
|
||||||
the set is empty then no events with that type are fetched.
|
then no events with that type are fetched.
|
||||||
include_others (bool): Whether to fetch events with types that do not
|
include_others: Whether to fetch events with types that do not
|
||||||
appear in `types`.
|
appear in `types`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
types = attr.ib()
|
types = attr.ib(type=Dict[str, Optional[Set[str]]])
|
||||||
include_others = attr.ib(default=False)
|
include_others = attr.ib(default=False, type=bool)
|
||||||
|
|
||||||
def __attrs_post_init__(self):
|
def __attrs_post_init__(self):
|
||||||
# If `include_others` is set we canonicalise the filter by removing
|
# If `include_others` is set we canonicalise the filter by removing
|
||||||
|
@ -52,36 +51,35 @@ class StateFilter(object):
|
||||||
self.types = {k: v for k, v in self.types.items() if v is not None}
|
self.types = {k: v for k, v in self.types.items() if v is not None}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def all():
|
def all() -> "StateFilter":
|
||||||
"""Creates a filter that fetches everything.
|
"""Creates a filter that fetches everything.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StateFilter
|
The new state filter.
|
||||||
"""
|
"""
|
||||||
return StateFilter(types={}, include_others=True)
|
return StateFilter(types={}, include_others=True)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def none():
|
def none() -> "StateFilter":
|
||||||
"""Creates a filter that fetches nothing.
|
"""Creates a filter that fetches nothing.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StateFilter
|
The new state filter.
|
||||||
"""
|
"""
|
||||||
return StateFilter(types={}, include_others=False)
|
return StateFilter(types={}, include_others=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_types(types):
|
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
|
||||||
"""Creates a filter that only fetches the given types
|
"""Creates a filter that only fetches the given types
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
types (Iterable[tuple[str, str|None]]): A list of type and state
|
types: A list of type and state keys to fetch. A state_key of None
|
||||||
keys to fetch. A state_key of None fetches everything for
|
fetches everything for that type
|
||||||
that type
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StateFilter
|
The new state filter.
|
||||||
"""
|
"""
|
||||||
type_dict = {}
|
type_dict = {} # type: Dict[str, Optional[Set[str]]]
|
||||||
for typ, s in types:
|
for typ, s in types:
|
||||||
if typ in type_dict:
|
if typ in type_dict:
|
||||||
if type_dict[typ] is None:
|
if type_dict[typ] is None:
|
||||||
|
@ -91,24 +89,24 @@ class StateFilter(object):
|
||||||
type_dict[typ] = None
|
type_dict[typ] = None
|
||||||
continue
|
continue
|
||||||
|
|
||||||
type_dict.setdefault(typ, set()).add(s)
|
type_dict.setdefault(typ, set()).add(s) # type: ignore
|
||||||
|
|
||||||
return StateFilter(types=type_dict)
|
return StateFilter(types=type_dict)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_lazy_load_member_list(members):
|
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||||
"""Creates a filter that returns all non-member events, plus the member
|
"""Creates a filter that returns all non-member events, plus the member
|
||||||
events for the given users
|
events for the given users
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
members (iterable[str]): Set of user IDs
|
members: Set of user IDs
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StateFilter
|
The new state filter
|
||||||
"""
|
"""
|
||||||
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
|
||||||
|
|
||||||
def return_expanded(self):
|
def return_expanded(self) -> "StateFilter":
|
||||||
"""Creates a new StateFilter where type wild cards have been removed
|
"""Creates a new StateFilter where type wild cards have been removed
|
||||||
(except for memberships). The returned filter is a superset of the
|
(except for memberships). The returned filter is a superset of the
|
||||||
current one, i.e. anything that passes the current filter will pass
|
current one, i.e. anything that passes the current filter will pass
|
||||||
|
@ -130,7 +128,7 @@ class StateFilter(object):
|
||||||
return all non-member events
|
return all non-member events
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
StateFilter
|
The new state filter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self.is_full():
|
if self.is_full():
|
||||||
|
@ -167,7 +165,7 @@ class StateFilter(object):
|
||||||
include_others=True,
|
include_others=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def make_sql_filter_clause(self):
|
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
|
||||||
"""Converts the filter to an SQL clause.
|
"""Converts the filter to an SQL clause.
|
||||||
|
|
||||||
For example:
|
For example:
|
||||||
|
@ -179,13 +177,12 @@ class StateFilter(object):
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[str, list]: The SQL string (may be empty) and arguments. An
|
The SQL string (may be empty) and arguments. An empty SQL string is
|
||||||
empty SQL string is returned when the filter matches everything
|
returned when the filter matches everything (i.e. is "full").
|
||||||
(i.e. is "full").
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
where_args = []
|
where_args = [] # type: List[str]
|
||||||
|
|
||||||
if self.is_full():
|
if self.is_full():
|
||||||
return where_clause, where_args
|
return where_clause, where_args
|
||||||
|
@ -221,7 +218,7 @@ class StateFilter(object):
|
||||||
|
|
||||||
return where_clause, where_args
|
return where_clause, where_args
|
||||||
|
|
||||||
def max_entries_returned(self):
|
def max_entries_returned(self) -> Optional[int]:
|
||||||
"""Returns the maximum number of entries this filter will return if
|
"""Returns the maximum number of entries this filter will return if
|
||||||
known, otherwise returns None.
|
known, otherwise returns None.
|
||||||
|
|
||||||
|
@ -260,33 +257,33 @@ class StateFilter(object):
|
||||||
|
|
||||||
return filtered_state
|
return filtered_state
|
||||||
|
|
||||||
def is_full(self):
|
def is_full(self) -> bool:
|
||||||
"""Whether this filter fetches everything or not
|
"""Whether this filter fetches everything or not
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool
|
True if the filter fetches everything.
|
||||||
"""
|
"""
|
||||||
return self.include_others and not self.types
|
return self.include_others and not self.types
|
||||||
|
|
||||||
def has_wildcards(self):
|
def has_wildcards(self) -> bool:
|
||||||
"""Whether the filter includes wildcards or is attempting to fetch
|
"""Whether the filter includes wildcards or is attempting to fetch
|
||||||
specific state.
|
specific state.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool
|
True if the filter includes wildcards.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.include_others or any(
|
return self.include_others or any(
|
||||||
state_keys is None for state_keys in self.types.values()
|
state_keys is None for state_keys in self.types.values()
|
||||||
)
|
)
|
||||||
|
|
||||||
def concrete_types(self):
|
def concrete_types(self) -> List[Tuple[str, str]]:
|
||||||
"""Returns a list of concrete type/state_keys (i.e. not None) that
|
"""Returns a list of concrete type/state_keys (i.e. not None) that
|
||||||
will be fetched. This will be a complete list if `has_wildcards`
|
will be fetched. This will be a complete list if `has_wildcards`
|
||||||
returns False, but otherwise will be a subset (or even empty).
|
returns False, but otherwise will be a subset (or even empty).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[tuple[str,str]]
|
A list of type/state_keys tuples.
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
(t, s)
|
(t, s)
|
||||||
|
@ -295,7 +292,7 @@ class StateFilter(object):
|
||||||
for s in state_keys
|
for s in state_keys
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_member_split(self):
|
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
||||||
"""Return the filter split into two: one which assumes it's exclusively
|
"""Return the filter split into two: one which assumes it's exclusively
|
||||||
matching against member state, and one which assumes it's matching
|
matching against member state, and one which assumes it's matching
|
||||||
against non member state.
|
against non member state.
|
||||||
|
@ -307,7 +304,7 @@ class StateFilter(object):
|
||||||
state caches).
|
state caches).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[StateFilter, StateFilter]: The member and non member filters
|
The member and non member filters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if EventTypes.Member in self.types:
|
if EventTypes.Member in self.types:
|
||||||
|
@ -340,6 +337,9 @@ class StateGroupStorage(object):
|
||||||
"""Given a state group try to return a previous group and a delta between
|
"""Given a state group try to return a previous group and a delta between
|
||||||
the old and the new.
|
the old and the new.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_group: The state group used to retrieve state deltas.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
|
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
|
||||||
(prev_group, delta_ids)
|
(prev_group, delta_ids)
|
||||||
|
@ -347,55 +347,59 @@ class StateGroupStorage(object):
|
||||||
|
|
||||||
return self.stores.state.get_state_group_delta(state_group)
|
return self.stores.state.get_state_group_delta(state_group)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_groups_ids(
|
||||||
def get_state_groups_ids(self, _room_id, event_ids):
|
self, _room_id: str, event_ids: Iterable[str]
|
||||||
|
) -> Dict[int, StateMap[str]]:
|
||||||
"""Get the event IDs of all the state for the state groups for the given events
|
"""Get the event IDs of all the state for the state groups for the given events
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
_room_id (str): id of the room for these events
|
_room_id: id of the room for these events
|
||||||
event_ids (iterable[str]): ids of the events
|
event_ids: ids of the events
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[int, StateMap[str]]]:
|
|
||||||
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
dict of state_group_id -> (dict of (type, state_key) -> event id)
|
||||||
"""
|
"""
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = yield self.stores.state._get_state_for_groups(groups)
|
group_to_state = await self.stores.state._get_state_for_groups(groups)
|
||||||
|
|
||||||
return group_to_state
|
return group_to_state
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
|
||||||
def get_state_ids_for_group(self, state_group):
|
|
||||||
"""Get the event IDs of all the state in the given state group
|
"""Get the event IDs of all the state in the given state group
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state_group (int)
|
state_group: A state group for which we want to get the state IDs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
|
Resolves to a map of (type, state_key) -> event_id
|
||||||
"""
|
"""
|
||||||
group_to_state = yield self._get_state_for_groups((state_group,))
|
group_to_state = await self._get_state_for_groups((state_group,))
|
||||||
|
|
||||||
return group_to_state[state_group]
|
return group_to_state[state_group]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_groups(
|
||||||
def get_state_groups(self, room_id, event_ids):
|
self, room_id: str, event_ids: Iterable[str]
|
||||||
|
) -> Dict[int, List[EventBase]]:
|
||||||
""" Get the state groups for the given list of event_ids
|
""" Get the state groups for the given list of event_ids
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_id: ID of the room for these events.
|
||||||
|
event_ids: The event IDs to retrieve state for.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[int, list[EventBase]]]:
|
|
||||||
dict of state_group_id -> list of state events.
|
dict of state_group_id -> list of state events.
|
||||||
"""
|
"""
|
||||||
if not event_ids:
|
if not event_ids:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
|
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
|
||||||
|
|
||||||
state_event_map = yield self.stores.main.get_events(
|
state_event_map = await self.stores.main.get_events(
|
||||||
[
|
[
|
||||||
ev_id
|
ev_id
|
||||||
for group_ids in group_to_ids.values()
|
for group_ids in group_to_ids.values()
|
||||||
|
@ -423,31 +427,34 @@ class StateGroupStorage(object):
|
||||||
groups: list of state group IDs to query
|
groups: list of state group IDs to query
|
||||||
state_filter: The state filter used to fetch state
|
state_filter: The state filter used to fetch state
|
||||||
from the database.
|
from the database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_for_events(
|
||||||
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||||
|
):
|
||||||
"""Given a list of event_ids and type tuples, return a list of state
|
"""Given a list of event_ids and type tuples, return a list of state
|
||||||
dicts for each event.
|
dicts for each event.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids (list[string])
|
event_ids: The events to fetch the state of.
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
state_filter: The state filter used to fetch state.
|
||||||
from the database.
|
|
||||||
Returns:
|
Returns:
|
||||||
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
|
A dict of (event_id) -> (type, state_key) -> [state_events]
|
||||||
"""
|
"""
|
||||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
groups, state_filter
|
groups, state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
state_event_map = yield self.stores.main.get_events(
|
state_event_map = await self.stores.main.get_events(
|
||||||
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
|
||||||
get_prev_content=False,
|
get_prev_content=False,
|
||||||
)
|
)
|
||||||
|
@ -463,24 +470,24 @@ class StateGroupStorage(object):
|
||||||
|
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_ids_for_events(
|
||||||
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
|
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||||
of the state events (as opposed to the events themselves)
|
of the state events (as opposed to the events themselves)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_ids(list(str)): events whose state should be returned
|
event_ids: events whose state should be returned
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
state_filter: The state filter used to fetch state from the database.
|
||||||
from the database.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict from event_id -> (type, state_key) -> event_id
|
A dict from event_id -> (type, state_key) -> event_id
|
||||||
"""
|
"""
|
||||||
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
|
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
|
||||||
|
|
||||||
groups = set(event_to_groups.values())
|
groups = set(event_to_groups.values())
|
||||||
group_to_state = yield self.stores.state._get_state_for_groups(
|
group_to_state = await self.stores.state._get_state_for_groups(
|
||||||
groups, state_filter
|
groups, state_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -491,36 +498,36 @@ class StateGroupStorage(object):
|
||||||
|
|
||||||
return {event: event_to_state[event] for event in event_ids}
|
return {event: event_to_state[event] for event in event_ids}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_for_event(
|
||||||
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
|
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id(str): event whose state should be returned
|
event_id: event whose state should be returned
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
state_filter: The state filter used to fetch state from the database.
|
||||||
from the database.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict from (type, state_key) -> state_event
|
A dict from (type, state_key) -> state_event
|
||||||
"""
|
"""
|
||||||
state_map = yield self.get_state_for_events([event_id], state_filter)
|
state_map = await self.get_state_for_events([event_id], state_filter)
|
||||||
return state_map[event_id]
|
return state_map[event_id]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def get_state_ids_for_event(
|
||||||
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
|
self, event_id: str, state_filter: StateFilter = StateFilter.all()
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get the state dict corresponding to a particular event
|
Get the state dict corresponding to a particular event
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id(str): event whose state should be returned
|
event_id: event whose state should be returned
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
state_filter: The state filter used to fetch state from the database.
|
||||||
from the database.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A deferred dict from (type, state_key) -> state_event
|
A deferred dict from (type, state_key) -> state_event
|
||||||
"""
|
"""
|
||||||
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
|
state_map = await self.get_state_ids_for_events([event_id], state_filter)
|
||||||
return state_map[event_id]
|
return state_map[event_id]
|
||||||
|
|
||||||
def _get_state_for_groups(
|
def _get_state_for_groups(
|
||||||
|
@ -530,9 +537,8 @@ class StateGroupStorage(object):
|
||||||
filtering by type/state_key
|
filtering by type/state_key
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
groups (iterable[int]): list of state groups for which we want
|
groups: list of state groups for which we want to get the state.
|
||||||
to get the state.
|
state_filter: The state filter used to fetch state.
|
||||||
state_filter (StateFilter): The state filter used to fetch state
|
|
||||||
from the database.
|
from the database.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
|
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
|
||||||
|
@ -540,18 +546,23 @@ class StateGroupStorage(object):
|
||||||
return self.stores.state._get_state_for_groups(groups, state_filter)
|
return self.stores.state._get_state_for_groups(groups, state_filter)
|
||||||
|
|
||||||
def store_state_group(
|
def store_state_group(
|
||||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
self,
|
||||||
|
event_id: str,
|
||||||
|
room_id: str,
|
||||||
|
prev_group: Optional[int],
|
||||||
|
delta_ids: Optional[dict],
|
||||||
|
current_state_ids: dict,
|
||||||
):
|
):
|
||||||
"""Store a new set of state, returning a newly assigned state group.
|
"""Store a new set of state, returning a newly assigned state group.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
event_id (str): The event ID for which the state was calculated
|
event_id: The event ID for which the state was calculated.
|
||||||
room_id (str)
|
room_id: ID of the room for which the state was calculated.
|
||||||
prev_group (int|None): A previous state group for the room, optional.
|
prev_group: A previous state group for the room, optional.
|
||||||
delta_ids (dict|None): The delta between state at `prev_group` and
|
delta_ids: The delta between state at `prev_group` and
|
||||||
`current_state_ids`, if `prev_group` was given. Same format as
|
`current_state_ids`, if `prev_group` was given. Same format as
|
||||||
`current_state_ids`.
|
`current_state_ids`.
|
||||||
current_state_ids (dict): The state to store. Map of (type, state_key)
|
current_state_ids: The state to store. Map of (type, state_key)
|
||||||
to event_id.
|
to event_id.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -13,6 +13,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.rest.client.v1 import room
|
from synapse.rest.client.v1 import room
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
|
||||||
event = self.successResultOf(event)
|
event = self.successResultOf(event)
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
purge = storage.purge_events.purge_history(self.room_id, event, True)
|
purge = defer.ensureDeferred(
|
||||||
|
storage.purge_events.purge_history(self.room_id, event, True)
|
||||||
|
)
|
||||||
self.pump()
|
self.pump()
|
||||||
self.assertEqual(self.successResultOf(purge), None)
|
self.assertEqual(self.successResultOf(purge), None)
|
||||||
|
|
||||||
|
@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Purge everything before this topological token
|
# Purge everything before this topological token
|
||||||
purge = storage.purge_history(self.room_id, event, True)
|
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
|
||||||
self.pump()
|
self.pump()
|
||||||
f = self.failureResultOf(purge)
|
f = self.failureResultOf(purge)
|
||||||
self.assertIn("greater than forward", f.value.args[0])
|
self.assertIn("greater than forward", f.value.args[0])
|
||||||
|
|
|
@ -97,9 +97,11 @@ class RoomEventsStoreTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def inject_room_event(self, **kwargs):
|
def inject_room_event(self, **kwargs):
|
||||||
yield self.storage.persistence.persist_event(
|
yield defer.ensureDeferred(
|
||||||
|
self.storage.persistence.persist_event(
|
||||||
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def STALE_test_room_name(self):
|
def STALE_test_room_name(self):
|
||||||
|
|
|
@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield defer.ensureDeferred(
|
||||||
|
self.storage.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = yield self.storage.state.get_state_groups_ids(
|
state_group_map = yield defer.ensureDeferred(
|
||||||
self.room, [e2.event_id]
|
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
|
||||||
)
|
)
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_map = list(state_group_map.values())[0]
|
state_map = list(state_group_map.values())[0]
|
||||||
|
@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
|
||||||
)
|
)
|
||||||
|
|
||||||
state_group_map = yield self.storage.state.get_state_groups(
|
state_group_map = yield defer.ensureDeferred(
|
||||||
self.room, [e2.event_id]
|
self.storage.state.get_state_groups(self.room, [e2.event_id])
|
||||||
)
|
)
|
||||||
self.assertEqual(len(state_group_map), 1)
|
self.assertEqual(len(state_group_map), 1)
|
||||||
state_list = list(state_group_map.values())[0]
|
state_list = list(state_group_map.values())[0]
|
||||||
|
@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check we get the full state as of the final event
|
# check we get the full state as of the final event
|
||||||
state = yield self.storage.state.get_state_for_event(e5.event_id)
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(e5.event_id)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertIsNotNone(e4)
|
self.assertIsNotNone(e4)
|
||||||
|
|
||||||
|
@ -164,23 +168,29 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check we can filter to the m.room.name event (with a '' state key)
|
# check we can filter to the m.room.name event (with a '' state key)
|
||||||
state = yield self.storage.state.get_state_for_event(
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(
|
||||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
|
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
|
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
|
||||||
|
|
||||||
# check we can filter to the m.room.name event (with a wildcard None state key)
|
# check we can filter to the m.room.name event (with a wildcard None state key)
|
||||||
state = yield self.storage.state.get_state_for_event(
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(
|
||||||
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
|
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
|
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
|
||||||
|
|
||||||
# check we can grab the m.room.member events (with a wildcard None state key)
|
# check we can grab the m.room.member events (with a wildcard None state key)
|
||||||
state = yield self.storage.state.get_state_for_event(
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(
|
||||||
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
|
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual(
|
self.assertStateMapEqual(
|
||||||
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
|
{(e3.type, e3.state_key): e3, (e5.type, e5.state_key): e5}, state
|
||||||
|
@ -188,13 +198,15 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
|
|
||||||
# check we can grab a specific room member without filtering out the
|
# check we can grab a specific room member without filtering out the
|
||||||
# other event types
|
# other event types
|
||||||
state = yield self.storage.state.get_state_for_event(
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(
|
||||||
e5.event_id,
|
e5.event_id,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: {self.u_alice.to_string()}},
|
types={EventTypes.Member: {self.u_alice.to_string()}},
|
||||||
include_others=True,
|
include_others=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual(
|
self.assertStateMapEqual(
|
||||||
{
|
{
|
||||||
|
@ -206,12 +218,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# check that we can grab everything except members
|
# check that we can grab everything except members
|
||||||
state = yield self.storage.state.get_state_for_event(
|
state = yield defer.ensureDeferred(
|
||||||
|
self.storage.state.get_state_for_event(
|
||||||
e5.event_id,
|
e5.event_id,
|
||||||
state_filter=StateFilter(
|
state_filter=StateFilter(
|
||||||
types={EventTypes.Member: set()}, include_others=True
|
types={EventTypes.Member: set()}, include_others=True
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
|
||||||
self.assertStateMapEqual(
|
self.assertStateMapEqual(
|
||||||
{(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
|
{(e1.type, e1.state_key): e1, (e2.type, e2.state_key): e2}, state
|
||||||
|
@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
|
||||||
#######################################################
|
#######################################################
|
||||||
|
|
||||||
room_id = self.room.to_string()
|
room_id = self.room.to_string()
|
||||||
group_ids = yield self.storage.state.get_state_groups_ids(
|
group_ids = yield defer.ensureDeferred(
|
||||||
room_id, [e5.event_id]
|
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
|
||||||
)
|
)
|
||||||
group = list(group_ids.keys())[0]
|
group = list(group_ids.keys())[0]
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
self.store = self.hs.get_datastore()
|
self.store = self.hs.get_datastore()
|
||||||
self.storage = self.hs.get_storage()
|
self.storage = self.hs.get_storage()
|
||||||
|
|
||||||
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
|
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_filtering(self):
|
def test_filtering(self):
|
||||||
|
@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
event, context = yield defer.ensureDeferred(
|
event, context = yield defer.ensureDeferred(
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield defer.ensureDeferred(
|
||||||
|
self.storage.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield defer.ensureDeferred(
|
||||||
|
self.storage.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
|
||||||
self.event_creation_handler.create_new_client_event(builder)
|
self.event_creation_handler.create_new_client_event(builder)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield self.storage.persistence.persist_event(event, context)
|
yield defer.ensureDeferred(
|
||||||
|
self.storage.persistence.persist_event(event, context)
|
||||||
|
)
|
||||||
return event
|
return event
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -638,14 +638,8 @@ class DeferredMockCallable(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
async def create_room(hs, room_id: str, creator_id: str):
|
||||||
def create_room(hs, room_id, creator_id):
|
|
||||||
"""Creates and persist a creation event for the given room
|
"""Creates and persist a creation event for the given room
|
||||||
|
|
||||||
Args:
|
|
||||||
hs
|
|
||||||
room_id (str)
|
|
||||||
creator_id (str)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
persistence_store = hs.get_storage().persistence
|
persistence_store = hs.get_storage().persistence
|
||||||
|
@ -653,7 +647,7 @@ def create_room(hs, room_id, creator_id):
|
||||||
event_builder_factory = hs.get_event_builder_factory()
|
event_builder_factory = hs.get_event_builder_factory()
|
||||||
event_creation_handler = hs.get_event_creation_handler()
|
event_creation_handler = hs.get_event_creation_handler()
|
||||||
|
|
||||||
yield store.store_room(
|
await store.store_room(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
room_creator_user_id=creator_id,
|
room_creator_user_id=creator_id,
|
||||||
is_public=False,
|
is_public=False,
|
||||||
|
@ -671,8 +665,6 @@ def create_room(hs, room_id, creator_id):
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
event, context = yield defer.ensureDeferred(
|
event, context = await event_creation_handler.create_new_client_event(builder)
|
||||||
event_creation_handler.create_new_client_event(builder)
|
|
||||||
)
|
|
||||||
|
|
||||||
yield persistence_store.persist_event(event, context)
|
await persistence_store.persist_event(event, context)
|
||||||
|
|
1
tox.ini
1
tox.ini
|
@ -206,6 +206,7 @@ commands = mypy \
|
||||||
synapse/storage/data_stores/main/ui_auth.py \
|
synapse/storage/data_stores/main/ui_auth.py \
|
||||||
synapse/storage/database.py \
|
synapse/storage/database.py \
|
||||||
synapse/storage/engines \
|
synapse/storage/engines \
|
||||||
|
synapse/storage/state.py \
|
||||||
synapse/storage/util \
|
synapse/storage/util \
|
||||||
synapse/streams \
|
synapse/streams \
|
||||||
synapse/util/caches/stream_change_cache.py \
|
synapse/util/caches/stream_change_cache.py \
|
||||||
|
|
Loading…
Reference in a new issue