forked from MirrorHub/synapse
Don't pull out state in compute_event_context
for unconflicted state (#13267)
This commit is contained in:
parent
599c403d99
commit
0ca4172b5d
7 changed files with 95 additions and 136 deletions
1
changelog.d/13267.misc
Normal file
1
changelog.d/13267.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Don't pull out state in `compute_event_context` for unconflicted state.
|
|
@ -1444,7 +1444,12 @@ class EventCreationHandler:
|
||||||
if state_entry.state_group in self._external_cache_joined_hosts_updates:
|
if state_entry.state_group in self._external_cache_joined_hosts_updates:
|
||||||
return
|
return
|
||||||
|
|
||||||
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
|
state = await state_entry.get_state(
|
||||||
|
self._storage_controllers.state, StateFilter.all()
|
||||||
|
)
|
||||||
|
joined_hosts = await self.store.get_joined_hosts(
|
||||||
|
event.room_id, state, state_entry
|
||||||
|
)
|
||||||
|
|
||||||
# Note that the expiry times must be larger than the expiry time in
|
# Note that the expiry times must be larger than the expiry time in
|
||||||
# _external_cache_joined_hosts_updates.
|
# _external_cache_joined_hosts_updates.
|
||||||
|
|
|
@ -31,7 +31,6 @@ from typing import (
|
||||||
Sequence,
|
Sequence,
|
||||||
Set,
|
Set,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -47,6 +46,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
|
||||||
from synapse.state import v1, v2
|
from synapse.state import v1, v2
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import StateMap
|
from synapse.types import StateMap
|
||||||
from synapse.util.async_helpers import Linearizer
|
from synapse.util.async_helpers import Linearizer
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
|
@ -54,6 +54,7 @@ from synapse.util.metrics import Measure, measure_func
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.controllers import StateStorageController
|
||||||
from synapse.storage.databases.main import DataStore
|
from synapse.storage.databases.main import DataStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -83,17 +84,20 @@ def _gen_state_id() -> str:
|
||||||
|
|
||||||
|
|
||||||
class _StateCacheEntry:
|
class _StateCacheEntry:
|
||||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
__slots__ = ["state", "state_group", "prev_group", "delta_ids"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
state: StateMap[str],
|
state: Optional[StateMap[str]],
|
||||||
state_group: Optional[int],
|
state_group: Optional[int],
|
||||||
prev_group: Optional[int] = None,
|
prev_group: Optional[int] = None,
|
||||||
delta_ids: Optional[StateMap[str]] = None,
|
delta_ids: Optional[StateMap[str]] = None,
|
||||||
):
|
):
|
||||||
|
if state is None and state_group is None:
|
||||||
|
raise Exception("Either state or state group must be not None")
|
||||||
|
|
||||||
# A map from (type, state_key) to event_id.
|
# A map from (type, state_key) to event_id.
|
||||||
self.state = frozendict(state)
|
self.state = frozendict(state) if state is not None else None
|
||||||
|
|
||||||
# the ID of a state group if one and only one is involved.
|
# the ID of a state group if one and only one is involved.
|
||||||
# otherwise, None otherwise?
|
# otherwise, None otherwise?
|
||||||
|
@ -102,20 +106,30 @@ class _StateCacheEntry:
|
||||||
self.prev_group = prev_group
|
self.prev_group = prev_group
|
||||||
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
|
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
|
||||||
|
|
||||||
# The `state_id` is a unique ID we generate that can be used as ID for
|
async def get_state(
|
||||||
# this collection of state. Usually this would be the same as the
|
self,
|
||||||
# state group, but on worker instances we can't generate a new state
|
state_storage: "StateStorageController",
|
||||||
# group each time we resolve state, so we generate a separate one that
|
state_filter: Optional["StateFilter"] = None,
|
||||||
# isn't persisted and is used solely for caches.
|
) -> StateMap[str]:
|
||||||
# `state_id` is either a state_group (and so an int) or a string. This
|
"""Get the state map for this entry, either from the in-memory state or
|
||||||
# ensures we don't accidentally persist a state_id as a stateg_group
|
looking up the state group in the DB.
|
||||||
if state_group:
|
"""
|
||||||
self.state_id: Union[str, int] = state_group
|
|
||||||
else:
|
if self.state is not None:
|
||||||
self.state_id = _gen_state_id()
|
return self.state
|
||||||
|
|
||||||
|
assert self.state_group is not None
|
||||||
|
|
||||||
|
return await state_storage.get_state_ids_for_group(
|
||||||
|
self.state_group, state_filter
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
return len(self.state)
|
# The len should is used to estimate how large this cache entry is, for
|
||||||
|
# cache eviction purposes. This is why if `self.state` is None it's fine
|
||||||
|
# to return 1.
|
||||||
|
|
||||||
|
return len(self.state) if self.state else 1
|
||||||
|
|
||||||
|
|
||||||
class StateHandler:
|
class StateHandler:
|
||||||
|
@ -153,7 +167,7 @@ class StateHandler:
|
||||||
"""
|
"""
|
||||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||||
return ret.state
|
return await ret.get_state(self._state_storage_controller, StateFilter.all())
|
||||||
|
|
||||||
async def get_current_users_in_room(
|
async def get_current_users_in_room(
|
||||||
self, room_id: str, latest_event_ids: List[str]
|
self, room_id: str, latest_event_ids: List[str]
|
||||||
|
@ -177,7 +191,8 @@ class StateHandler:
|
||||||
|
|
||||||
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
||||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
|
||||||
|
return await self.store.get_joined_users_from_state(room_id, state, entry)
|
||||||
|
|
||||||
async def get_hosts_in_room_at_events(
|
async def get_hosts_in_room_at_events(
|
||||||
self, room_id: str, event_ids: Collection[str]
|
self, room_id: str, event_ids: Collection[str]
|
||||||
|
@ -192,7 +207,8 @@ class StateHandler:
|
||||||
The hosts in the room at the given events
|
The hosts in the room at the given events
|
||||||
"""
|
"""
|
||||||
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||||
return await self.store.get_joined_hosts(room_id, entry)
|
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
|
||||||
|
return await self.store.get_joined_hosts(room_id, state, entry)
|
||||||
|
|
||||||
async def compute_event_context(
|
async def compute_event_context(
|
||||||
self,
|
self,
|
||||||
|
@ -227,10 +243,19 @@ class StateHandler:
|
||||||
#
|
#
|
||||||
if state_ids_before_event:
|
if state_ids_before_event:
|
||||||
# if we're given the state before the event, then we use that
|
# if we're given the state before the event, then we use that
|
||||||
state_group_before_event = None
|
|
||||||
state_group_before_event_prev_group = None
|
state_group_before_event_prev_group = None
|
||||||
deltas_to_state_group_before_event = None
|
deltas_to_state_group_before_event = None
|
||||||
entry = None
|
|
||||||
|
# .. though we need to get a state group for it.
|
||||||
|
state_group_before_event = (
|
||||||
|
await self._state_storage_controller.store_state_group(
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
prev_group=None,
|
||||||
|
delta_ids=None,
|
||||||
|
current_state_ids=state_ids_before_event,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# otherwise, we'll need to resolve the state across the prev_events.
|
# otherwise, we'll need to resolve the state across the prev_events.
|
||||||
|
@ -264,19 +289,14 @@ class StateHandler:
|
||||||
await_full_state=False,
|
await_full_state=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
state_ids_before_event = entry.state
|
|
||||||
state_group_before_event = entry.state_group
|
|
||||||
state_group_before_event_prev_group = entry.prev_group
|
state_group_before_event_prev_group = entry.prev_group
|
||||||
deltas_to_state_group_before_event = entry.delta_ids
|
deltas_to_state_group_before_event = entry.delta_ids
|
||||||
|
|
||||||
#
|
# We make sure that we have a state group assigned to the state.
|
||||||
# make sure that we have a state group at that point. If it's not a state event,
|
if entry.state_group is None:
|
||||||
# that will be the state group for the new event. If it *is* a state event,
|
state_ids_before_event = await entry.get_state(
|
||||||
# it might get rejected (in which case we'll need to persist it with the
|
self._state_storage_controller, StateFilter.all()
|
||||||
# previous state group)
|
)
|
||||||
#
|
|
||||||
|
|
||||||
if not state_group_before_event:
|
|
||||||
state_group_before_event = (
|
state_group_before_event = (
|
||||||
await self._state_storage_controller.store_state_group(
|
await self._state_storage_controller.store_state_group(
|
||||||
event.event_id,
|
event.event_id,
|
||||||
|
@ -286,14 +306,10 @@ class StateHandler:
|
||||||
current_state_ids=state_ids_before_event,
|
current_state_ids=state_ids_before_event,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assign the new state group to the cached state entry.
|
|
||||||
#
|
|
||||||
# Note that this can race in that we could generate multiple state
|
|
||||||
# groups for the same state entry, but that is just inefficient
|
|
||||||
# rather than dangerous.
|
|
||||||
if entry and entry.state_group is None:
|
|
||||||
entry.state_group = state_group_before_event
|
entry.state_group = state_group_before_event
|
||||||
|
else:
|
||||||
|
state_group_before_event = entry.state_group
|
||||||
|
state_ids_before_event = None
|
||||||
|
|
||||||
#
|
#
|
||||||
# now if it's not a state event, we're done
|
# now if it's not a state event, we're done
|
||||||
|
@ -313,6 +329,10 @@ class StateHandler:
|
||||||
#
|
#
|
||||||
# otherwise, we'll need to create a new state group for after the event
|
# otherwise, we'll need to create a new state group for after the event
|
||||||
#
|
#
|
||||||
|
if state_ids_before_event is None:
|
||||||
|
state_ids_before_event = await entry.get_state(
|
||||||
|
self._state_storage_controller, StateFilter.all()
|
||||||
|
)
|
||||||
|
|
||||||
key = (event.type, event.state_key)
|
key = (event.type, event.state_key)
|
||||||
if key in state_ids_before_event:
|
if key in state_ids_before_event:
|
||||||
|
@ -372,9 +392,6 @@ class StateHandler:
|
||||||
state_group_ids_set = set(state_group_ids)
|
state_group_ids_set = set(state_group_ids)
|
||||||
if len(state_group_ids_set) == 1:
|
if len(state_group_ids_set) == 1:
|
||||||
(state_group_id,) = state_group_ids_set
|
(state_group_id,) = state_group_ids_set
|
||||||
state = await self._state_storage_controller.get_state_for_groups(
|
|
||||||
state_group_ids_set
|
|
||||||
)
|
|
||||||
(
|
(
|
||||||
prev_group,
|
prev_group,
|
||||||
delta_ids,
|
delta_ids,
|
||||||
|
@ -382,7 +399,7 @@ class StateHandler:
|
||||||
state_group_id
|
state_group_id
|
||||||
)
|
)
|
||||||
return _StateCacheEntry(
|
return _StateCacheEntry(
|
||||||
state=state[state_group_id],
|
state=None,
|
||||||
state_group=state_group_id,
|
state_group=state_group_id,
|
||||||
prev_group=prev_group,
|
prev_group=prev_group,
|
||||||
delta_ids=delta_ids,
|
delta_ids=delta_ids,
|
||||||
|
|
|
@ -43,4 +43,6 @@ class StorageControllers:
|
||||||
|
|
||||||
self.persistence = None
|
self.persistence = None
|
||||||
if stores.persist_events:
|
if stores.persist_events:
|
||||||
self.persistence = EventsPersistenceStorageController(hs, stores)
|
self.persistence = EventsPersistenceStorageController(
|
||||||
|
hs, stores, self.state
|
||||||
|
)
|
||||||
|
|
|
@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
|
||||||
from synapse.logging import opentracing
|
from synapse.logging import opentracing
|
||||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||||
|
from synapse.storage.controllers.state import StateStorageController
|
||||||
from synapse.storage.databases import Databases
|
from synapse.storage.databases import Databases
|
||||||
from synapse.storage.databases.main.events import DeltaState
|
from synapse.storage.databases.main.events import DeltaState
|
||||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||||
|
from synapse.storage.state import StateFilter
|
||||||
from synapse.types import (
|
from synapse.types import (
|
||||||
PersistedEventPosition,
|
PersistedEventPosition,
|
||||||
RoomStreamToken,
|
RoomStreamToken,
|
||||||
|
@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
|
||||||
current state and forward extremity changes.
|
current state and forward extremity changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs: "HomeServer", stores: Databases):
|
def __init__(
|
||||||
|
self,
|
||||||
|
hs: "HomeServer",
|
||||||
|
stores: Databases,
|
||||||
|
state_controller: StateStorageController,
|
||||||
|
):
|
||||||
# We ultimately want to split out the state store from the main store,
|
# We ultimately want to split out the state store from the main store,
|
||||||
# so we use separate variables here even though they point to the same
|
# so we use separate variables here even though they point to the same
|
||||||
# store for now.
|
# store for now.
|
||||||
|
@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
|
||||||
self._process_event_persist_queue_task
|
self._process_event_persist_queue_task
|
||||||
)
|
)
|
||||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||||
|
self._state_controller = state_controller
|
||||||
|
|
||||||
async def _process_event_persist_queue_task(
|
async def _process_event_persist_queue_task(
|
||||||
self,
|
self,
|
||||||
|
@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
|
||||||
state_res_store=StateResolutionStore(self.main_store),
|
state_res_store=StateResolutionStore(self.main_store),
|
||||||
)
|
)
|
||||||
|
|
||||||
return res.state
|
return await res.get_state(self._state_controller, StateFilter.all())
|
||||||
|
|
||||||
async def _persist_event_batch(
|
async def _persist_event_batch(
|
||||||
self, _room_id: str, task: _PersistEventsTask
|
self, _room_id: str, task: _PersistEventsTask
|
||||||
|
|
|
@ -31,7 +31,6 @@ import attr
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import (
|
from synapse.metrics.background_process_metrics import (
|
||||||
run_as_background_process,
|
run_as_background_process,
|
||||||
|
@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
|
|
||||||
return shared_room_ids or frozenset()
|
return shared_room_ids or frozenset()
|
||||||
|
|
||||||
async def get_joined_users_from_context(
|
|
||||||
self, event: EventBase, context: EventContext
|
|
||||||
) -> Dict[str, ProfileInfo]:
|
|
||||||
state_group: Union[object, int] = context.state_group
|
|
||||||
if not state_group:
|
|
||||||
# If state_group is None it means it has yet to be assigned a
|
|
||||||
# state group, i.e. we need to make sure that calls with a state_group
|
|
||||||
# of None don't hit previous cached calls with a None state_group.
|
|
||||||
# To do this we set the state_group to a new object as object() != object()
|
|
||||||
state_group = object()
|
|
||||||
|
|
||||||
current_state_ids = await context.get_current_state_ids()
|
|
||||||
assert current_state_ids is not None
|
|
||||||
assert state_group is not None
|
|
||||||
return await self._get_joined_users_from_context(
|
|
||||||
event.room_id, state_group, current_state_ids, event=event, context=context
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_joined_users_from_state(
|
async def get_joined_users_from_state(
|
||||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
|
||||||
) -> Dict[str, ProfileInfo]:
|
) -> Dict[str, ProfileInfo]:
|
||||||
state_group: Union[object, int] = state_entry.state_group
|
state_group: Union[object, int] = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
with Measure(self._clock, "get_joined_users_from_state"):
|
with Measure(self._clock, "get_joined_users_from_state"):
|
||||||
return await self._get_joined_users_from_context(
|
return await self._get_joined_users_from_context(
|
||||||
room_id, state_group, state_entry.state, context=state_entry
|
room_id, state_group, state, context=state_entry
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000)
|
@cached(num_args=2, iterable=True, max_entries=100000)
|
||||||
async def _get_joined_users_from_context(
|
async def _get_joined_users_from_context(
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
state_group: Union[object, int],
|
state_group: Union[object, int],
|
||||||
current_state_ids: StateMap[str],
|
current_state_ids: StateMap[str],
|
||||||
cache_context: _CacheContext,
|
|
||||||
event: Optional[EventBase] = None,
|
event: Optional[EventBase] = None,
|
||||||
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None,
|
context: Optional["_StateCacheEntry"] = None,
|
||||||
) -> Dict[str, ProfileInfo]:
|
) -> Dict[str, ProfileInfo]:
|
||||||
# We don't use `state_group`, it's there so that we can cache based
|
# We don't use `state_group`, it's there so that we can cache based
|
||||||
# on it. However, it's important that it's never None, since two current_states
|
# on it. However, it's important that it's never None, since two current_states
|
||||||
|
@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_joined_hosts(
|
async def get_joined_hosts(
|
||||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
state_group: Union[object, int] = state_entry.state_group
|
state_group: Union[object, int] = state_entry.state_group
|
||||||
if not state_group:
|
if not state_group:
|
||||||
|
@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
assert state_group is not None
|
assert state_group is not None
|
||||||
with Measure(self._clock, "get_joined_hosts"):
|
with Measure(self._clock, "get_joined_hosts"):
|
||||||
return await self._get_joined_hosts(
|
return await self._get_joined_hosts(
|
||||||
room_id, state_group, state_entry=state_entry
|
room_id, state_group, state, state_entry=state_entry
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(num_args=2, max_entries=10000, iterable=True)
|
@cached(num_args=2, max_entries=10000, iterable=True)
|
||||||
|
@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
self,
|
self,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
state_group: Union[object, int],
|
state_group: Union[object, int],
|
||||||
|
state: StateMap[str],
|
||||||
state_entry: "_StateCacheEntry",
|
state_entry: "_StateCacheEntry",
|
||||||
) -> FrozenSet[str]:
|
) -> FrozenSet[str]:
|
||||||
# We don't use `state_group`, it's there so that we can cache based on
|
# We don't use `state_group`, it's there so that we can cache based on
|
||||||
|
@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# The cache doesn't match the state group or prev state group,
|
# The cache doesn't match the state group or prev state group,
|
||||||
# so we calculate the result from first principles.
|
# so we calculate the result from first principles.
|
||||||
joined_users = await self.get_joined_users_from_state(
|
joined_users = await self.get_joined_users_from_state(
|
||||||
room_id, state_entry
|
room_id, state, state_entry
|
||||||
)
|
)
|
||||||
|
|
||||||
cache.hosts_to_joined_users = {}
|
cache.hosts_to_joined_users = {}
|
||||||
|
|
|
@ -23,7 +23,6 @@ from synapse.util import Clock
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import TestHomeServer
|
from tests.server import TestHomeServer
|
||||||
from tests.test_utils import event_injection
|
|
||||||
|
|
||||||
|
|
||||||
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
|
||||||
# It now knows about Charlie's server.
|
# It now knows about Charlie's server.
|
||||||
self.assertEqual(self.store._known_servers_count, 2)
|
self.assertEqual(self.store._known_servers_count, 2)
|
||||||
|
|
||||||
def test_get_joined_users_from_context(self) -> None:
|
|
||||||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
|
||||||
bob_event = self.get_success(
|
|
||||||
event_injection.inject_member_event(
|
|
||||||
self.hs, room, self.u_bob, Membership.JOIN
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# first, create a regular event
|
|
||||||
event, context = self.get_success(
|
|
||||||
event_injection.create_event(
|
|
||||||
self.hs,
|
|
||||||
room_id=room,
|
|
||||||
sender=self.u_alice,
|
|
||||||
prev_event_ids=[bob_event.event_id],
|
|
||||||
type="m.test.1",
|
|
||||||
content={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
users = self.get_success(
|
|
||||||
self.store.get_joined_users_from_context(event, context)
|
|
||||||
)
|
|
||||||
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
|
|
||||||
|
|
||||||
# Regression test for #7376: create a state event whose key matches bob's
|
|
||||||
# user_id, but which is *not* a membership event, and persist that; then check
|
|
||||||
# that `get_joined_users_from_context` returns the correct users for the next event.
|
|
||||||
non_member_event = self.get_success(
|
|
||||||
event_injection.inject_event(
|
|
||||||
self.hs,
|
|
||||||
room_id=room,
|
|
||||||
sender=self.u_bob,
|
|
||||||
prev_event_ids=[bob_event.event_id],
|
|
||||||
type="m.test.2",
|
|
||||||
state_key=self.u_bob,
|
|
||||||
content={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
event, context = self.get_success(
|
|
||||||
event_injection.create_event(
|
|
||||||
self.hs,
|
|
||||||
room_id=room,
|
|
||||||
sender=self.u_alice,
|
|
||||||
prev_event_ids=[non_member_event.event_id],
|
|
||||||
type="m.test.3",
|
|
||||||
content={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
users = self.get_success(
|
|
||||||
self.store.get_joined_users_from_context(event, context)
|
|
||||||
)
|
|
||||||
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
|
|
||||||
|
|
||||||
def test__null_byte_in_display_name_properly_handled(self) -> None:
|
def test__null_byte_in_display_name_properly_handled(self) -> None:
|
||||||
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue