Don't pull out state in compute_event_context for unconflicted state (#13267)

This commit is contained in:
Erik Johnston 2022-07-14 14:57:02 +01:00 committed by GitHub
parent 599c403d99
commit 0ca4172b5d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 95 additions and 136 deletions

1
changelog.d/13267.misc Normal file
View file

@ -0,0 +1 @@
Don't pull out state in `compute_event_context` for unconflicted state.

View file

@ -1444,7 +1444,12 @@ class EventCreationHandler:
if state_entry.state_group in self._external_cache_joined_hosts_updates: if state_entry.state_group in self._external_cache_joined_hosts_updates:
return return
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) state = await state_entry.get_state(
self._storage_controllers.state, StateFilter.all()
)
joined_hosts = await self.store.get_joined_hosts(
event.room_id, state, state_entry
)
# Note that the expiry times must be larger than the expiry time in # Note that the expiry times must be larger than the expiry time in
# _external_cache_joined_hosts_updates. # _external_cache_joined_hosts_updates.

View file

@ -31,7 +31,6 @@ from typing import (
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
Union,
) )
import attr import attr
@ -47,6 +46,7 @@ from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServ
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@ -54,6 +54,7 @@ from synapse.util.metrics import Measure, measure_func
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.controllers import StateStorageController
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -83,17 +84,20 @@ def _gen_state_id() -> str:
class _StateCacheEntry: class _StateCacheEntry:
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] __slots__ = ["state", "state_group", "prev_group", "delta_ids"]
def __init__( def __init__(
self, self,
state: StateMap[str], state: Optional[StateMap[str]],
state_group: Optional[int], state_group: Optional[int],
prev_group: Optional[int] = None, prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None, delta_ids: Optional[StateMap[str]] = None,
): ):
if state is None and state_group is None:
raise Exception("Either state or state group must be not None")
# A map from (type, state_key) to event_id. # A map from (type, state_key) to event_id.
self.state = frozendict(state) self.state = frozendict(state) if state is not None else None
# the ID of a state group if one and only one is involved. # the ID of a state group if one and only one is involved.
# otherwise, None otherwise? # otherwise, None otherwise?
@ -102,20 +106,30 @@ class _StateCacheEntry:
self.prev_group = prev_group self.prev_group = prev_group
self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
# The `state_id` is a unique ID we generate that can be used as ID for async def get_state(
# this collection of state. Usually this would be the same as the self,
# state group, but on worker instances we can't generate a new state state_storage: "StateStorageController",
# group each time we resolve state, so we generate a separate one that state_filter: Optional["StateFilter"] = None,
# isn't persisted and is used solely for caches. ) -> StateMap[str]:
# `state_id` is either a state_group (and so an int) or a string. This """Get the state map for this entry, either from the in-memory state or
# ensures we don't accidentally persist a state_id as a state_group looking up the state group in the DB.
if state_group: """
self.state_id: Union[str, int] = state_group
else: if self.state is not None:
self.state_id = _gen_state_id() return self.state
assert self.state_group is not None
return await state_storage.get_state_ids_for_group(
self.state_group, state_filter
)
def __len__(self) -> int: def __len__(self) -> int:
return len(self.state) # The len is used to estimate how large this cache entry is, for
# cache eviction purposes. This is why if `self.state` is None it's fine
# to return 1.
return len(self.state) if self.state else 1
class StateHandler: class StateHandler:
@ -153,7 +167,7 @@ class StateHandler:
""" """
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state return await ret.get_state(self._state_storage_controller, StateFilter.all())
async def get_current_users_in_room( async def get_current_users_in_room(
self, room_id: str, latest_event_ids: List[str] self, room_id: str, latest_event_ids: List[str]
@ -177,7 +191,8 @@ class StateHandler:
logger.debug("calling resolve_state_groups from get_current_users_in_room") logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry) state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_users_from_state(room_id, state, entry)
async def get_hosts_in_room_at_events( async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str] self, room_id: str, event_ids: Collection[str]
@ -192,7 +207,8 @@ class StateHandler:
The hosts in the room at the given events The hosts in the room at the given events
""" """
entry = await self.resolve_state_groups_for_events(room_id, event_ids) entry = await self.resolve_state_groups_for_events(room_id, event_ids)
return await self.store.get_joined_hosts(room_id, entry) state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
async def compute_event_context( async def compute_event_context(
self, self,
@ -227,10 +243,19 @@ class StateHandler:
# #
if state_ids_before_event: if state_ids_before_event:
# if we're given the state before the event, then we use that # if we're given the state before the event, then we use that
state_group_before_event = None
state_group_before_event_prev_group = None state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None deltas_to_state_group_before_event = None
entry = None
# .. though we need to get a state group for it.
state_group_before_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=state_ids_before_event,
)
)
else: else:
# otherwise, we'll need to resolve the state across the prev_events. # otherwise, we'll need to resolve the state across the prev_events.
@ -264,19 +289,14 @@ class StateHandler:
await_full_state=False, await_full_state=False,
) )
state_ids_before_event = entry.state
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids deltas_to_state_group_before_event = entry.delta_ids
# # We make sure that we have a state group assigned to the state.
# make sure that we have a state group at that point. If it's not a state event, if entry.state_group is None:
# that will be the state group for the new event. If it *is* a state event, state_ids_before_event = await entry.get_state(
# it might get rejected (in which case we'll need to persist it with the self._state_storage_controller, StateFilter.all()
# previous state group) )
#
if not state_group_before_event:
state_group_before_event = ( state_group_before_event = (
await self._state_storage_controller.store_state_group( await self._state_storage_controller.store_state_group(
event.event_id, event.event_id,
@ -286,14 +306,10 @@ class StateHandler:
current_state_ids=state_ids_before_event, current_state_ids=state_ids_before_event,
) )
) )
# Assign the new state group to the cached state entry.
#
# Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event entry.state_group = state_group_before_event
else:
state_group_before_event = entry.state_group
state_ids_before_event = None
# #
# now if it's not a state event, we're done # now if it's not a state event, we're done
@ -313,6 +329,10 @@ class StateHandler:
# #
# otherwise, we'll need to create a new state group for after the event # otherwise, we'll need to create a new state group for after the event
# #
if state_ids_before_event is None:
state_ids_before_event = await entry.get_state(
self._state_storage_controller, StateFilter.all()
)
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in state_ids_before_event: if key in state_ids_before_event:
@ -372,9 +392,6 @@ class StateHandler:
state_group_ids_set = set(state_group_ids) state_group_ids_set = set(state_group_ids)
if len(state_group_ids_set) == 1: if len(state_group_ids_set) == 1:
(state_group_id,) = state_group_ids_set (state_group_id,) = state_group_ids_set
state = await self._state_storage_controller.get_state_for_groups(
state_group_ids_set
)
( (
prev_group, prev_group,
delta_ids, delta_ids,
@ -382,7 +399,7 @@ class StateHandler:
state_group_id state_group_id
) )
return _StateCacheEntry( return _StateCacheEntry(
state=state[state_group_id], state=None,
state_group=state_group_id, state_group=state_group_id,
prev_group=prev_group, prev_group=prev_group,
delta_ids=delta_ids, delta_ids=delta_ids,

View file

@ -43,4 +43,6 @@ class StorageControllers:
self.persistence = None self.persistence = None
if stores.persist_events: if stores.persist_events:
self.persistence = EventsPersistenceStorageController(hs, stores) self.persistence = EventsPersistenceStorageController(
hs, stores, self.state
)

View file

@ -48,9 +48,11 @@ from synapse.events.snapshot import EventContext
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState from synapse.storage.databases.main.events import DeltaState
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import ( from synapse.types import (
PersistedEventPosition, PersistedEventPosition,
RoomStreamToken, RoomStreamToken,
@ -308,7 +310,12 @@ class EventsPersistenceStorageController:
current state and forward extremity changes. current state and forward extremity changes.
""" """
def __init__(self, hs: "HomeServer", stores: Databases): def __init__(
self,
hs: "HomeServer",
stores: Databases,
state_controller: StateStorageController,
):
# We ultimately want to split out the state store from the main store, # We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same # so we use separate variables here even though they point to the same
# store for now. # store for now.
@ -325,6 +332,7 @@ class EventsPersistenceStorageController:
self._process_event_persist_queue_task self._process_event_persist_queue_task
) )
self._state_resolution_handler = hs.get_state_resolution_handler() self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller
async def _process_event_persist_queue_task( async def _process_event_persist_queue_task(
self, self,
@ -504,7 +512,7 @@ class EventsPersistenceStorageController:
state_res_store=StateResolutionStore(self.main_store), state_res_store=StateResolutionStore(self.main_store),
) )
return res.state return await res.get_state(self._state_controller, StateFilter.all())
async def _persist_event_batch( async def _persist_event_batch(
self, _room_id: str, task: _PersistEventsTask self, _room_id: str, task: _PersistEventsTask

View file

@ -31,7 +31,6 @@ import attr
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import ( from synapse.metrics.background_process_metrics import (
run_as_background_process, run_as_background_process,
@ -780,26 +779,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return shared_room_ids or frozenset() return shared_room_ids or frozenset()
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = context.state_group
if not state_group:
# If state_group is None it means it has yet to be assigned a
# state group, i.e. we need to make sure that calls with a state_group
# of None don't hit previous cached calls with a None state_group.
# To do this we set the state_group to a new object as object() != object()
state_group = object()
current_state_ids = await context.get_current_state_ids()
assert current_state_ids is not None
assert state_group is not None
return await self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
async def get_joined_users_from_state( async def get_joined_users_from_state(
self, room_id: str, state_entry: "_StateCacheEntry" self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> Dict[str, ProfileInfo]: ) -> Dict[str, ProfileInfo]:
state_group: Union[object, int] = state_entry.state_group state_group: Union[object, int] = state_entry.state_group
if not state_group: if not state_group:
@ -812,18 +793,17 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None assert state_group is not None
with Measure(self._clock, "get_joined_users_from_state"): with Measure(self._clock, "get_joined_users_from_state"):
return await self._get_joined_users_from_context( return await self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry room_id, state_group, state, context=state_entry
) )
@cached(num_args=2, cache_context=True, iterable=True, max_entries=100000) @cached(num_args=2, iterable=True, max_entries=100000)
async def _get_joined_users_from_context( async def _get_joined_users_from_context(
self, self,
room_id: str, room_id: str,
state_group: Union[object, int], state_group: Union[object, int],
current_state_ids: StateMap[str], current_state_ids: StateMap[str],
cache_context: _CacheContext,
event: Optional[EventBase] = None, event: Optional[EventBase] = None,
context: Optional[Union[EventContext, "_StateCacheEntry"]] = None, context: Optional["_StateCacheEntry"] = None,
) -> Dict[str, ProfileInfo]: ) -> Dict[str, ProfileInfo]:
# We don't use `state_group`, it's there so that we can cache based # We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states # on it. However, it's important that it's never None, since two current_states
@ -1017,7 +997,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) )
async def get_joined_hosts( async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry" self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
) -> FrozenSet[str]: ) -> FrozenSet[str]:
state_group: Union[object, int] = state_entry.state_group state_group: Union[object, int] = state_entry.state_group
if not state_group: if not state_group:
@ -1030,7 +1010,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
assert state_group is not None assert state_group is not None
with Measure(self._clock, "get_joined_hosts"): with Measure(self._clock, "get_joined_hosts"):
return await self._get_joined_hosts( return await self._get_joined_hosts(
room_id, state_group, state_entry=state_entry room_id, state_group, state, state_entry=state_entry
) )
@cached(num_args=2, max_entries=10000, iterable=True) @cached(num_args=2, max_entries=10000, iterable=True)
@ -1038,6 +1018,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self, self,
room_id: str, room_id: str,
state_group: Union[object, int], state_group: Union[object, int],
state: StateMap[str],
state_entry: "_StateCacheEntry", state_entry: "_StateCacheEntry",
) -> FrozenSet[str]: ) -> FrozenSet[str]:
# We don't use `state_group`, it's there so that we can cache based on # We don't use `state_group`, it's there so that we can cache based on
@ -1093,7 +1074,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# The cache doesn't match the state group or prev state group, # The cache doesn't match the state group or prev state group,
# so we calculate the result from first principles. # so we calculate the result from first principles.
joined_users = await self.get_joined_users_from_state( joined_users = await self.get_joined_users_from_state(
room_id, state_entry room_id, state, state_entry
) )
cache.hosts_to_joined_users = {} cache.hosts_to_joined_users = {}

View file

@ -23,7 +23,6 @@ from synapse.util import Clock
from tests import unittest from tests import unittest
from tests.server import TestHomeServer from tests.server import TestHomeServer
from tests.test_utils import event_injection
class RoomMemberStoreTestCase(unittest.HomeserverTestCase): class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
@ -110,60 +109,6 @@ class RoomMemberStoreTestCase(unittest.HomeserverTestCase):
# It now knows about Charlie's server. # It now knows about Charlie's server.
self.assertEqual(self.store._known_servers_count, 2) self.assertEqual(self.store._known_servers_count, 2)
def test_get_joined_users_from_context(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)
bob_event = self.get_success(
event_injection.inject_member_event(
self.hs, room, self.u_bob, Membership.JOIN
)
)
# first, create a regular event
event, context = self.get_success(
event_injection.create_event(
self.hs,
room_id=room,
sender=self.u_alice,
prev_event_ids=[bob_event.event_id],
type="m.test.1",
content={},
)
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
)
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
# Regression test for #7376: create a state event whose key matches bob's
# user_id, but which is *not* a membership event, and persist that; then check
# that `get_joined_users_from_context` returns the correct users for the next event.
non_member_event = self.get_success(
event_injection.inject_event(
self.hs,
room_id=room,
sender=self.u_bob,
prev_event_ids=[bob_event.event_id],
type="m.test.2",
state_key=self.u_bob,
content={},
)
)
event, context = self.get_success(
event_injection.create_event(
self.hs,
room_id=room,
sender=self.u_alice,
prev_event_ids=[non_member_event.event_id],
type="m.test.3",
content={},
)
)
users = self.get_success(
self.store.get_joined_users_from_context(event, context)
)
self.assertEqual(users.keys(), {self.u_alice, self.u_bob})
def test__null_byte_in_display_name_properly_handled(self) -> None: def test__null_byte_in_display_name_properly_handled(self) -> None:
room = self.helper.create_room_as(self.u_alice, tok=self.t_alice) room = self.helper.create_room_as(self.u_alice, tok=self.t_alice)