
Convert state resolution to async/await (#7942)

Patrick Cloke 2020-07-24 10:59:51 -04:00 committed by GitHub
parent e739b20588
commit b975fa2e99
GPG key ID: 4AEE18F83AFDEB23 (no known key found for this signature in database)
18 changed files with 198 additions and 184 deletions
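
The whole commit applies one mechanical pattern: methods decorated with @defer.inlineCallbacks become async def coroutines and their yields become awaits, while call sites that are still inlineCallbacks generators wrap the coroutine in defer.ensureDeferred() so it can be yielded as a Deferred. A minimal, self-contained sketch of that bridge (the function names are illustrative, not taken from this diff):

from twisted.internet import defer, task


async def get_current_state_ids(room_id):
    # Stand-in for a converted StateHandler method: a native coroutine.
    return {("m.room.create", ""): "$create:" + room_id}


@defer.inlineCallbacks
def legacy_caller(room_id):
    # inlineCallbacks generators can only yield Deferreds, not coroutines,
    # so unconverted callers wrap the coroutine in defer.ensureDeferred().
    state_ids = yield defer.ensureDeferred(get_current_state_ids(room_id))
    return state_ids


def main(reactor):
    d = legacy_caller("!room:example.org")
    d.addCallback(print)
    return d


if __name__ == "__main__":
    task.react(main)

The ensureDeferred() wrappers below appear only in callers that are still inlineCallbacks generators; fully converted callers just await the coroutine directly.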

changelog.d/7942.misc (new file)

@@ -0,0 +1 @@
+Convert state resolution to async/await.


@@ -127,8 +127,10 @@ class Auth(object):
         if current_state:
             member = current_state.get((EventTypes.Member, user_id), None)
         else:
-            member = yield self.state.get_current_state(
-                room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+            member = yield defer.ensureDeferred(
+                self.state.get_current_state(
+                    room_id=room_id, event_type=EventTypes.Member, state_key=user_id
+                )
             )
 
         membership = member.membership if member else None
@@ -665,8 +667,10 @@ class Auth(object):
             )
             return member_event.membership, member_event.event_id
         except AuthError:
-            visibility = yield self.state.get_current_state(
-                room_id, EventTypes.RoomHistoryVisibility, ""
+            visibility = yield defer.ensureDeferred(
+                self.state.get_current_state(
+                    room_id, EventTypes.RoomHistoryVisibility, ""
+                )
             )
             if (
                 visibility


@@ -106,8 +106,8 @@ class EventBuilder(object):
             Deferred[FrozenEvent]
         """
-        state_ids = yield self._state.get_current_state_ids(
-            self.room_id, prev_event_ids
+        state_ids = yield defer.ensureDeferred(
+            self._state.get_current_state_ids(self.room_id, prev_event_ids)
         )
         auth_ids = yield self._auth.compute_auth_events(self, state_ids)


@@ -330,7 +330,9 @@ class FederationSender(object):
         room_id = receipt.room_id
 
         # Work out which remote servers should be poked and poke them.
-        domains = yield self.state.get_current_hosts_in_room(room_id)
+        domains = yield defer.ensureDeferred(
+            self.state.get_current_hosts_in_room(room_id)
+        )
         domains = [
             d
             for d in domains


@@ -928,8 +928,8 @@ class PresenceHandler(BasePresenceHandler):
                 # TODO: Check that this is actually a new server joining the
                 # room.
-                user_ids = await self.state.get_current_users_in_room(room_id)
-                user_ids = list(filter(self.is_mine_id, user_ids))
+                users = await self.state.get_current_users_in_room(room_id)
+                user_ids = list(filter(self.is_mine_id, users))
 
                 states_d = await self.current_state_for_users(user_ids)


@@ -304,7 +304,9 @@ class RulesForRoom(object):
                 push_rules_delta_state_cache_metric.inc_hits()
             else:
-                current_state_ids = yield context.get_current_state_ids()
+                current_state_ids = yield defer.ensureDeferred(
+                    context.get_current_state_ids()
+                )
                 push_rules_delta_state_cache_metric.inc_misses()
 
             push_rules_state_size_counter.inc(len(current_state_ids))


@@ -16,14 +16,12 @@
 import logging
 from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set
+from typing import Awaitable, Dict, Iterable, List, Optional, Set
 
 import attr
 from frozendict import frozendict
 from prometheus_client import Histogram
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
 from synapse.events import EventBase
@@ -31,6 +29,7 @@ from synapse.events.snapshot import EventContext
 from synapse.logging.utils import log_function
 from synapse.state import v1, v2
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.roommember import ProfileInfo
 from synapse.types import StateMap
 from synapse.util import Clock
 from synapse.util.async_helpers import Linearizer
@@ -108,8 +107,7 @@ class StateHandler(object):
         self.hs = hs
         self._state_resolution_handler = hs.get_state_resolution_handler()
 
-    @defer.inlineCallbacks
-    def get_current_state(
+    async def get_current_state(
         self, room_id, event_type=None, state_key="", latest_event_ids=None
     ):
         """ Retrieves the current state for the room. This is done by
@@ -126,20 +124,20 @@ class StateHandler(object):
             map from (type, state_key) to event
         """
         if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         if event_type:
             event_id = state.get((event_type, state_key))
             event = None
             if event_id:
-                event = yield self.store.get_event(event_id, allow_none=True)
+                event = await self.store.get_event(event_id, allow_none=True)
             return event
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             list(state.values()), get_prev_content=False
         )
         state = {
@@ -148,8 +146,7 @@ class StateHandler(object):
 
         return state
 
-    @defer.inlineCallbacks
-    def get_current_state_ids(self, room_id, latest_event_ids=None):
+    async def get_current_state_ids(self, room_id, latest_event_ids=None):
         """Get the current state, or the state at a set of events, for a room
 
         Args:
@@ -164,41 +161,38 @@ class StateHandler(object):
             (event_type, state_key) -> event_id
         """
         if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = ret.state
 
         return state
 
-    @defer.inlineCallbacks
-    def get_current_users_in_room(self, room_id, latest_event_ids=None):
+    async def get_current_users_in_room(
+        self, room_id: str, latest_event_ids: Optional[List[str]] = None
+    ) -> Dict[str, ProfileInfo]:
         """
         Get the users who are currently in a room.
 
         Args:
-            room_id (str): The ID of the room.
-            latest_event_ids (List[str]|None): Precomputed list of latest
-                event IDs. Will be computed if None.
+            room_id: The ID of the room.
+            latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
 
         Returns:
-            Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
-                profileinfo.
+            Dictionary of user IDs to their profileinfo.
         """
         if not latest_event_ids:
-            latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+            latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
 
         logger.debug("calling resolve_state_groups from get_current_users_in_room")
-        entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
-        joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
+        entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        joined_users = await self.store.get_joined_users_from_state(room_id, entry)
         return joined_users
 
-    @defer.inlineCallbacks
-    def get_current_hosts_in_room(self, room_id):
-        event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-        return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+    async def get_current_hosts_in_room(self, room_id):
+        event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+        return await self.get_hosts_in_room_at_events(room_id, event_ids)
 
-    @defer.inlineCallbacks
-    def get_hosts_in_room_at_events(self, room_id, event_ids):
+    async def get_hosts_in_room_at_events(self, room_id, event_ids):
         """Get the hosts that were in a room at the given event ids
 
         Args:
@@ -208,12 +202,11 @@ class StateHandler(object):
         Returns:
             Deferred[list[str]]: the hosts in the room at the given events
         """
-        entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
-        joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
+        entry = await self.resolve_state_groups_for_events(room_id, event_ids)
+        joined_hosts = await self.store.get_joined_hosts(room_id, entry)
         return joined_hosts
 
-    @defer.inlineCallbacks
-    def compute_event_context(
+    async def compute_event_context(
         self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
     ):
         """Build an EventContext structure for the event.
@@ -278,7 +271,7 @@ class StateHandler(object):
         # otherwise, we'll need to resolve the state across the prev_events.
         logger.debug("calling resolve_state_groups from compute_event_context")
-        entry = yield self.resolve_state_groups_for_events(
+        entry = await self.resolve_state_groups_for_events(
             event.room_id, event.prev_event_ids()
         )
@@ -295,7 +288,7 @@ class StateHandler(object):
         #
         if not state_group_before_event:
-            state_group_before_event = yield self.state_store.store_state_group(
+            state_group_before_event = await self.state_store.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=state_group_before_event_prev_group,
@@ -335,7 +328,7 @@ class StateHandler(object):
             state_ids_after_event[key] = event.event_id
             delta_ids = {key: event.event_id}
 
-        state_group_after_event = yield self.state_store.store_state_group(
+        state_group_after_event = await self.state_store.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=state_group_before_event,
@@ -353,8 +346,7 @@ class StateHandler(object):
         )
 
     @measure_func()
-    @defer.inlineCallbacks
-    def resolve_state_groups_for_events(self, room_id, event_ids):
+    async def resolve_state_groups_for_events(self, room_id, event_ids):
         """ Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -373,7 +365,7 @@ class StateHandler(object):
         # map from state group id to the state in that state group (where
         # 'state' is a map from state key to event id)
         # dict[int, dict[(str, str), str]]
-        state_groups_ids = yield self.state_store.get_state_groups_ids(
+        state_groups_ids = await self.state_store.get_state_groups_ids(
             room_id, event_ids
         )
@@ -382,7 +374,7 @@ class StateHandler(object):
         elif len(state_groups_ids) == 1:
             name, state_list = list(state_groups_ids.items()).pop()
 
-            prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
+            prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
 
             return _StateCacheEntry(
                 state=state_list,
@@ -391,9 +383,9 @@ class StateHandler(object):
                 delta_ids=delta_ids,
             )
 
-        room_version = yield self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
-        result = yield self._state_resolution_handler.resolve_state_groups(
+        result = await self._state_resolution_handler.resolve_state_groups(
             room_id,
             room_version,
             state_groups_ids,
@@ -402,8 +394,7 @@ class StateHandler(object):
         )
         return result
 
-    @defer.inlineCallbacks
-    def resolve_events(self, room_version, state_sets, event):
+    async def resolve_events(self, room_version, state_sets, event):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
         )
@@ -414,7 +405,7 @@ class StateHandler(object):
         state_map = {ev.event_id: ev for st in state_sets for ev in st}
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = yield resolve_events_with_store(
+            new_state = await resolve_events_with_store(
                 self.clock,
                 event.room_id,
                 room_version,
@@ -451,9 +442,8 @@ class StateResolutionHandler(object):
             reset_expiry_on_get=True,
         )
 
-    @defer.inlineCallbacks
     @log_function
-    def resolve_state_groups(
+    async def resolve_state_groups(
         self, room_id, room_version, state_groups_ids, event_map, state_res_store
     ):
         """Resolves conflicts between a set of state groups
@@ -479,13 +469,13 @@ class StateResolutionHandler(object):
             state_res_store (StateResolutionStore)
 
         Returns:
-            Deferred[_StateCacheEntry]: resolved state
+            _StateCacheEntry: resolved state
         """
         logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
 
         group_names = frozenset(state_groups_ids.keys())
 
-        with (yield self.resolve_linearizer.queue(group_names)):
+        with (await self.resolve_linearizer.queue(group_names)):
             if self._state_cache is not None:
                 cache = self._state_cache.get(group_names, None)
                 if cache:
@@ -517,7 +507,7 @@ class StateResolutionHandler(object):
             if conflicted_state:
                 logger.info("Resolving conflicted state for %r", room_id)
                 with Measure(self.clock, "state._resolve_events"):
-                    new_state = yield resolve_events_with_store(
+                    new_state = await resolve_events_with_store(
                         self.clock,
                         room_id,
                         room_version,
@@ -598,7 +588,7 @@ def resolve_events_with_store(
     state_sets: List[StateMap[str]],
    event_map: Optional[Dict[str, EventBase]],
    state_res_store: "StateResolutionStore",
-):
+) -> Awaitable[StateMap[str]]:
     """
     Args:
         room_id: the room we are working in
@@ -619,8 +609,7 @@ def resolve_events_with_store(
         state_res_store: a place to fetch events from
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        a map from (type, state_key) to event_id.
     """
     v = KNOWN_ROOM_VERSIONS[room_version]
     if v.state_res == StateResolutionVersions.V1:
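
One detail in the signature hunk above: the module-level resolve_events_with_store only gains an Awaitable[StateMap[str]] return annotation rather than becoming async itself, which suggests it simply hands back the awaitable produced by the version-specific implementation (the trailing context shows the v1/v2 dispatch). A simplified sketch of that shape, with reduced types and hypothetical names, not the real Synapse signatures:

import asyncio
from typing import Awaitable, Dict, List, Tuple

# Simplified stand-in for synapse.types.StateMap[str].
StateMap = Dict[Tuple[str, str], str]


async def _resolve_v1(state_sets: List[StateMap]) -> StateMap:
    return {("m.room.power_levels", ""): "$pl_v1"}


async def _resolve_v2(state_sets: List[StateMap]) -> StateMap:
    return {("m.room.power_levels", ""): "$pl_v2"}


def resolve_events_with_store(room_version: str, state_sets: List[StateMap]) -> Awaitable[StateMap]:
    # A plain function can still be awaited by its callers as long as it
    # returns an awaitable; here it returns the coroutine from the
    # version-specific implementation instead of awaiting it itself.
    if room_version == "1":
        return _resolve_v1(state_sets)
    return _resolve_v2(state_sets)


if __name__ == "__main__":
    print(asyncio.run(resolve_events_with_store("2", [])))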


@@ -15,9 +15,7 @@
 import hashlib
 import logging
-from typing import Callable, Dict, List, Optional
-
-from twisted.internet import defer
+from typing import Awaitable, Callable, Dict, List, Optional
 
 from synapse import event_auth
 from synapse.api.constants import EventTypes
@@ -32,12 +30,11 @@
 logger = logging.getLogger(__name__)
 
 POWER_KEY = (EventTypes.PowerLevels, "")
 
 
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
     room_id: str,
     state_sets: List[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable,
+    state_map_factory: Callable[[List[str]], Awaitable],
 ):
     """
     Args:
@@ -56,7 +53,7 @@ def resolve_events_with_store(
         state_map_factory: will be called
             with a list of event_ids that are needed, and should return with
-            a Deferred of dict of event_id to event.
+            an Awaitable that resolves to a dict of event_id to event.
 
     Returns:
         Deferred[dict[(str, str), str]]:
@@ -80,7 +77,7 @@ def resolve_events_with_store(
     # dict[str, FrozenEvent]: a map from state event id to event. Only includes
     # the state events which are in conflict (and those in event_map)
-    state_map = yield state_map_factory(needed_events)
+    state_map = await state_map_factory(needed_events)
     if event_map is not None:
         state_map.update(event_map)
@@ -110,7 +107,7 @@ def resolve_events_with_store(
         "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
     )
 
-    state_map_new = yield state_map_factory(new_needed_events)
+    state_map_new = await state_map_factory(new_needed_events)
     for event in state_map_new.values():
         if event.room_id != room_id:
             raise Exception(


@@ -18,8 +18,6 @@
 import itertools
 import logging
 from typing import Dict, List, Optional
 
-from twisted.internet import defer
-
 import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
@@ -32,14 +30,13 @@ from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
-# We want to yield to the reactor occasionally during state res when dealing
+# We want to await to the reactor occasionally during state res when dealing
 # with large data sets, so that we don't exhaust the reactor. This is done by
-# yielding to reactor during loops every N iterations.
-_YIELD_AFTER_ITERATIONS = 100
+# awaiting to reactor during loops every N iterations.
+_AWAIT_AFTER_ITERATIONS = 100
 
 
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
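
The comment rename above keeps the same cooperative-scheduling trick: every _AWAIT_AFTER_ITERATIONS iterations the loop awaits a zero-second sleep so the reactor can run other work. A standalone sketch of the pattern, using asyncio.sleep(0) in place of Twisted's clock.sleep(0) so it runs on its own (function and variable names here are illustrative):

import asyncio

_AWAIT_AFTER_ITERATIONS = 100  # same constant name as in the hunk above


async def process_all(items, handle):
    # A long, CPU-bound loop would otherwise starve the event loop; handing
    # control back every N iterations keeps other tasks responsive.
    for idx, item in enumerate(items, start=1):
        handle(item)
        if idx % _AWAIT_AFTER_ITERATIONS == 0:
            await asyncio.sleep(0)


if __name__ == "__main__":
    asyncio.run(process_all(range(10_000), lambda _: None))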
@@ -87,7 +84,7 @@ def resolve_events_with_store(
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
 
     full_conflicted_set = set(
         itertools.chain(
@@ -95,7 +92,7 @@ def resolve_events_with_store(
         )
     )
 
-    events = yield state_res_store.get_events(
+    events = await state_res_store.get_events(
         [eid for eid in full_conflicted_set if eid not in event_map],
         allow_rejected=True,
     )
@@ -118,14 +115,14 @@ def resolve_events_with_store(
         eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
     )
 
-    sorted_power_events = yield _reverse_topological_power_sort(
+    sorted_power_events = await _reverse_topological_power_sort(
         clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
     )
 
     logger.debug("sorted %d power events", len(sorted_power_events))
 
     # Now sequentially auth each one
-    resolved_state = yield _iterative_auth_checks(
+    resolved_state = await _iterative_auth_checks(
         clock,
         room_id,
         room_version,
@@ -148,13 +145,13 @@ def resolve_events_with_store(
     logger.debug("sorting %d remaining events", len(leftover_events))
 
     pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
-    leftover_events = yield _mainline_sort(
+    leftover_events = await _mainline_sort(
         clock, room_id, leftover_events, pl, event_map, state_res_store
     )
 
     logger.debug("resolving remaining events")
 
-    resolved_state = yield _iterative_auth_checks(
+    resolved_state = await _iterative_auth_checks(
         clock,
         room_id,
         room_version,
@@ -174,8 +171,7 @@ def resolve_events_with_store(
     return resolved_state
 
 
-@defer.inlineCallbacks
-def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     """Return the power level of the sender of the given event according to
     their auth events.
@@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     Returns:
         Deferred[int]
     """
-    event = yield _get_event(room_id, event_id, event_map, state_res_store)
+    event = await _get_event(room_id, event_id, event_map, state_res_store)
 
     pl = None
     for aid in event.auth_event_ids():
-        aev = yield _get_event(
+        aev = await _get_event(
             room_id, aid, event_map, state_res_store, allow_none=True
         )
         if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     if pl is None:
         # Couldn't find power level. Check if they're the creator of the room
         for aid in event.auth_event_ids():
-            aev = yield _get_event(
+            aev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
     return int(level)
 
 
-@defer.inlineCallbacks
-def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
@@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
         Deferred[set[str]]: Set of event IDs
     """
 
-    difference = yield state_res_store.get_auth_chain_difference(
+    difference = await state_res_store.get_auth_chain_difference(
         [set(state_set.values()) for state_set in state_sets]
     )
@@ -292,8 +287,7 @@ def _is_power_event(event):
     return False
 
 
-@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(
+async def _add_event_and_auth_chain_to_graph(
     graph, room_id, event_id, event_map, state_res_store, auth_diff
 ):
     """Helper function for _reverse_topological_power_sort that add the event
@@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph(
         eid = state.pop()
         graph.setdefault(eid, set())
 
-        event = yield _get_event(room_id, eid, event_map, state_res_store)
+        event = await _get_event(room_id, eid, event_map, state_res_store)
         for aid in event.auth_event_ids():
             if aid in auth_diff:
                 if aid not in graph:
@@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph(
                     graph.setdefault(eid, set()).add(aid)
 
 
-@defer.inlineCallbacks
-def _reverse_topological_power_sort(
+async def _reverse_topological_power_sort(
     clock, room_id, event_ids, event_map, state_res_store, auth_diff
 ):
     """Returns a list of the event_ids sorted by reverse topological ordering,
@@ -344,26 +337,26 @@ def _reverse_topological_power_sort(
 
     graph = {}
     for idx, event_id in enumerate(event_ids, start=1):
-        yield _add_event_and_auth_chain_to_graph(
+        await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
         )
 
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
 
     event_to_pl = {}
     for idx, event_id in enumerate(graph, start=1):
-        pl = yield _get_power_level_for_sender(
+        pl = await _get_power_level_for_sender(
             room_id, event_id, event_map, state_res_store
         )
         event_to_pl[event_id] = pl
 
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
        # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
 
     def _get_power_order(event_id):
         ev = event_map[event_id]
@@ -378,8 +371,7 @@ def _reverse_topological_power_sort(
     return sorted_events
 
 
-@defer.inlineCallbacks
-def _iterative_auth_checks(
+async def _iterative_auth_checks(
     clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
 ):
     """Sequentially apply auth checks to each event in given list, updating the
@@ -405,7 +397,7 @@ def _iterative_auth_checks(
 
         auth_events = {}
         for aid in event.auth_event_ids():
-            ev = yield _get_event(
+            ev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
@@ -420,7 +412,7 @@ def _iterative_auth_checks(
         for key in event_auth.auth_types_for_event(event):
             if key in resolved_state:
                 ev_id = resolved_state[key]
-                ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
+                ev = await _get_event(room_id, ev_id, event_map, state_res_store)
 
                 if ev.rejected_reason is None:
                     auth_events[key] = event_map[ev_id]
@@ -438,16 +430,15 @@ def _iterative_auth_checks(
         except AuthError:
             pass
 
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
 
     return resolved_state
 
 
-@defer.inlineCallbacks
-def _mainline_sort(
+async def _mainline_sort(
     clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
 ):
     """Returns a sorted list of event_ids sorted by mainline ordering based on
@@ -474,21 +465,21 @@ def _mainline_sort(
     idx = 0
     while pl:
         mainline.append(pl)
-        pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
+        pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
         auth_events = pl_ev.auth_event_ids()
         pl = None
         for aid in auth_events:
-            ev = yield _get_event(
+            ev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                 pl = aid
                 break
 
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
 
         idx += 1
@@ -498,23 +489,24 @@ def _mainline_sort(
     order_map = {}
     for idx, ev_id in enumerate(event_ids, start=1):
-        depth = yield _get_mainline_depth_for_event(
+        depth = await _get_mainline_depth_for_event(
             event_map[ev_id], mainline_map, event_map, state_res_store
         )
         order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
 
-        # We yield occasionally when we're working with large data sets to
+        # We await occasionally when we're working with large data sets to
         # ensure that we don't block the reactor loop for too long.
-        if idx % _YIELD_AFTER_ITERATIONS == 0:
-            yield clock.sleep(0)
+        if idx % _AWAIT_AFTER_ITERATIONS == 0:
+            await clock.sleep(0)
 
     event_ids.sort(key=lambda ev_id: order_map[ev_id])
 
     return event_ids
 
 
-@defer.inlineCallbacks
-def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+async def _get_mainline_depth_for_event(
+    event, mainline_map, event_map, state_res_store
+):
     """Get the mainline depths for the given event based on the mainline map
 
     Args:
@@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
         event = None
         for aid in auth_events:
-            aev = yield _get_event(
+            aev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
     return 0
 
 
-@defer.inlineCallbacks
-def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
@@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
         Deferred[Optional[FrozenEvent]]
     """
     if event_id not in event_map:
-        events = yield state_res_store.get_events([event_id], allow_rejected=True)
+        events = await state_res_store.get_events([event_id], allow_rejected=True)
         event_map.update(events)
 
     event = event_map.get(event_id)


@@ -259,7 +259,7 @@ class PushRulesWorkerStore(
         # To do this we set the state_group to a new object as object() != object()
         state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         result = yield self._bulk_get_push_rules_for_room(
             event.room_id, state_group, current_state_ids, event=event
         )


@@ -497,7 +497,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         # To do this we set the state_group to a new object as object() != object()
         state_group = object()
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         result = yield self._get_joined_users_from_context(
             event.room_id, state_group, current_state_ids, event=event, context=context
         )


@@ -198,7 +198,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
                 room_id
             )
 
-            users_with_profile = yield state.get_current_users_in_room(room_id)
+            users_with_profile = yield defer.ensureDeferred(
+                state.get_current_users_in_room(room_id)
+            )
             user_ids = set(users_with_profile)
 
             # Update each user in the user directory.


@@ -29,7 +29,6 @@ from synapse.events import FrozenEvent
 from synapse.events.snapshot import EventContext
 from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
 from synapse.storage.data_stores import DataStores
 from synapse.storage.data_stores.main.events import DeltaState
 from synapse.types import StateMap
@@ -648,6 +647,10 @@ class EventsPersistenceStorage(object):
             room_version = await self.main_store.get_room_version_id(room_id)
 
             logger.debug("calling resolve_state_groups from preserve_events")
+
+            # Avoid a circular import.
+            from synapse.state import StateResolutionStore
+
             res = await self._state_resolution_handler.resolve_state_groups(
                 room_id,
                 room_version,


@@ -26,21 +26,24 @@ from synapse.rest import admin
 from synapse.rest.client.v1 import login
 from synapse.types import JsonDict, ReadReceipt
 
+from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase, override_config
 
 
 class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
+        mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+        # Ensure a new Awaitable is created for each call.
+        mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+            ["test", "host2"]
+        )
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
+            state_handler=mock_state_handler,
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
     @override_config({"send_federation": True})
     def test_send_receipts(self):
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def test_send_receipts_with_backoff(self):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_current_hosts_in_room
-        mock_state_handler = hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         # stub out get_users_who_share_room_with_user so that it claims that
         # `@user2:host2` is in the room
         def get_users_who_share_room_with_user(user_id):


@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import itertools
+from typing import List
 
 import attr
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(event_map),
         )
 
-        state_before = self.successResultOf(state_d)
+        state_before = self.successResultOf(defer.ensureDeferred(state_d))
 
         state_after = dict(state_before)
         if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(self.event_map),
         )
 
-        state = self.successResultOf(state_d)
+        state = self.successResultOf(defer.ensureDeferred(state_d))
 
         self.assert_dict(self.expected_combined_state, state)
 
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
             Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
-        return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        return defer.succeed(
+            {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        )
 
-    def _get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
         presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
+            event_ids: The event IDs of the events to fetch the auth
                 chain for. Must be state events.
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            List of event IDs of the auth chain.
         """
 
         # Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
-        return set(chains[0]).union(*chains[1:]) - common
+        return defer.succeed(set(chains[0]).union(*chains[1:]) - common)


@@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
@@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(


@@ -97,17 +97,19 @@ class StateGroupStore(object):
         self._group_to_state[state_group] = dict(current_state_ids)
 
-        return state_group
+        return defer.succeed(state_group)
 
     def get_events(self, event_ids, **kwargs):
-        return {
-            e_id: self._event_id_to_event[e_id]
-            for e_id in event_ids
-            if e_id in self._event_id_to_event
-        }
+        return defer.succeed(
+            {
+                e_id: self._event_id_to_event[e_id]
+                for e_id in event_ids
+                if e_id in self._event_id_to_event
+            }
+        )
 
     def get_state_group_delta(self, name):
-        return None, None
+        return defer.succeed((None, None))
 
     def register_events(self, events):
         for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
         self._event_to_state_group[event_id] = state_group
 
     def get_room_version_id(self, room_id):
-        return RoomVersions.V1.identifier
+        return defer.succeed(RoomVersions.V1.identifier)
 
 
 class DictObj(dict):
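
The StateGroupStore test double above starts returning already-fired Deferreds via defer.succeed() so the newly async production code can await them (a Deferred can be awaited from a coroutine driven by defer.ensureDeferred). A minimal sketch of that interaction, with a made-up FakeStore rather than the real test class:

from twisted.internet import defer, task


class FakeStore:
    # Test double returning an already-fired Deferred, as the hunks above do.
    def get_room_version_id(self, room_id):
        return defer.succeed("1")


async def uses_store(store):
    # Awaiting a Deferred works inside a coroutine run via ensureDeferred.
    return await store.get_room_version_id("!room:example.org")


def main(reactor):
    d = defer.ensureDeferred(uses_store(FakeStore()))
    d.addCallback(print)
    return d


if __name__ == "__main__":
    task.react(main)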
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(
             {e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
         prev_state_ids = yield context.get_prev_state_ids()
 
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
+    @defer.inlineCallbacks
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = self.store.store_state_group(
+        sg1 = yield self.store.store_state_group(
             prev_event_id_1,
             event.room_id,
             None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = self.store.store_state_group(
+        sg2 = yield self.store.store_state_group(
             prev_event_id_2,
             event.room_id,
             None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
-        return self.state.compute_event_context(event)
+        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+        return result


@@ -17,7 +17,7 @@
 """
 Utilities for running the unit tests
 """
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
 
 TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
 
     # if next didn't raise, the awaitable hasn't completed.
     raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+    """Create an awaitable that just returns a result."""
+    return result
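
A short usage note on make_awaitable: a coroutine object can only be awaited once, which is why the federation sender test earlier in this diff sets side_effect to a lambda that builds a fresh awaitable per call instead of assigning one coroutine to return_value. A sketch of that usage, repeating make_awaitable so it runs standalone (the mock name mirrors the test above, the room IDs are made up):

import asyncio
from unittest.mock import Mock


async def make_awaitable(result):
    """Create an awaitable that just returns a result (as defined above)."""
    return result


mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# A new coroutine is produced for every call, so the mock can be awaited
# repeatedly without "cannot reuse already awaited coroutine" errors.
mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_id: make_awaitable(
    ["test", "host2"]
)

if __name__ == "__main__":
    print(asyncio.run(mock_state_handler.get_current_hosts_in_room("!a:example.org")))
    print(asyncio.run(mock_state_handler.get_current_hosts_in_room("!a:example.org")))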