Refactor _resolve_state_at_missing_prevs to return an EventContext (#13404)

Previously, `_resolve_state_at_missing_prevs` returned the resolved
state before an event and a partial state flag. These were unwieldy to
carry around and would only ever be used to build an event context. Build
the event context directly instead.

Signed-off-by: Sean Quah <seanq@matrix.org>
This commit is contained in:
Sean Quah 2022-08-01 13:53:56 +01:00 committed by GitHub
parent 05aeeb3a80
commit 224d792dd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 68 additions and 86 deletions

1
changelog.d/13404.misc Normal file
View file

@ -0,0 +1 @@
Refactor `_resolve_state_at_missing_prevs` to compute an `EventContext` instead.

View file

@ -23,7 +23,6 @@ from typing import (
Dict, Dict,
Iterable, Iterable,
List, List,
Optional,
Sequence, Sequence,
Set, Set,
Tuple, Tuple,
@ -278,9 +277,8 @@ class FederationEventHandler:
) )
try: try:
await self._process_received_pdu( context = await self._state_handler.compute_event_context(pdu)
origin, pdu, state_ids=None, partial_state=None await self._process_received_pdu(origin, pdu, context)
)
except PartialStateConflictError: except PartialStateConflictError:
# The room was un-partial stated while we were processing the PDU. # The room was un-partial stated while we were processing the PDU.
# Try once more, with full state this time. # Try once more, with full state this time.
@ -288,9 +286,8 @@ class FederationEventHandler:
"Room %s was un-partial stated while processing the PDU, trying again.", "Room %s was un-partial stated while processing the PDU, trying again.",
room_id, room_id,
) )
await self._process_received_pdu( context = await self._state_handler.compute_event_context(pdu)
origin, pdu, state_ids=None, partial_state=None await self._process_received_pdu(origin, pdu, context)
)
async def on_send_membership_event( async def on_send_membership_event(
self, origin: str, event: EventBase self, origin: str, event: EventBase
@ -320,6 +317,7 @@ class FederationEventHandler:
The event and context of the event after inserting it into the room graph. The event and context of the event after inserting it into the room graph.
Raises: Raises:
RuntimeError if any prev_events are missing
SynapseError if the event is not accepted into the room SynapseError if the event is not accepted into the room
PartialStateConflictError if the room was un-partial stated in between PartialStateConflictError if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should computing the state at the event and persisting it. The caller should
@ -380,7 +378,7 @@ class FederationEventHandler:
# need to. # need to.
await self._event_creation_handler.cache_joined_hosts_for_event(event, context) await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
await self._check_for_soft_fail(event, None, origin=origin) await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context) await self._run_push_actions_and_persist_event(event, context)
return event, context return event, context
@ -538,36 +536,10 @@ class FederationEventHandler:
# #
# This is the same operation as we do when we receive a regular event # This is the same operation as we do when we receive a regular event
# over federation. # over federation.
state_ids, partial_state = await self._resolve_state_at_missing_prevs( context = await self._compute_event_context_with_maybe_missing_prevs(
destination, event destination, event
) )
if context.partial_state:
# There are three possible cases for (state_ids, partial_state):
# * `state_ids` and `partial_state` are both `None` if we had all the
# prev_events. The prev_events may or may not have partial state and
# we won't know until we compute the event context.
# * `state_ids` is not `None` and `partial_state` is `False` if we were
# missing some prev_events (but we have full state for any we did
# have). We calculated the full state after the prev_events.
# * `state_ids` is not `None` and `partial_state` is `True` if we were
# missing some, but not all, prev_events. At least one of the
# prev_events we did have had partial state, so we calculated a partial
# state after the prev_events.
context = None
if state_ids is not None and partial_state:
# the state after the prev events is still partial. We can't de-partial
# state the event, so don't bother building the event context.
pass
else:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
if context is None or context.partial_state:
# this can happen if some or all of the event's prev_events still have # this can happen if some or all of the event's prev_events still have
# partial state. We were careful to only pick events from the db without # partial state. We were careful to only pick events from the db without
# partial-state prev events, so that implies that a prev event has # partial-state prev events, so that implies that a prev event has
@ -840,26 +812,25 @@ class FederationEventHandler:
try: try:
try: try:
state_ids, partial_state = await self._resolve_state_at_missing_prevs( context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event origin, event
) )
await self._process_received_pdu( await self._process_received_pdu(
origin, origin,
event, event,
state_ids=state_ids, context,
partial_state=partial_state,
backfilled=backfilled, backfilled=backfilled,
) )
except PartialStateConflictError: except PartialStateConflictError:
# The room was un-partial stated while we were processing the event. # The room was un-partial stated while we were processing the event.
# Try once more, with full state this time. # Try once more, with full state this time.
state_ids, partial_state = await self._resolve_state_at_missing_prevs( context = await self._compute_event_context_with_maybe_missing_prevs(
origin, event origin, event
) )
# We ought to have full state now, barring some unlikely race where we left and # We ought to have full state now, barring some unlikely race where we left and
# rejoined the room in the background. # rejoined the room in the background.
if state_ids is not None and partial_state: if context.partial_state:
raise AssertionError( raise AssertionError(
f"Event {event.event_id} still has a partial resolved state " f"Event {event.event_id} still has a partial resolved state "
f"after room {event.room_id} was un-partial stated" f"after room {event.room_id} was un-partial stated"
@ -868,8 +839,7 @@ class FederationEventHandler:
await self._process_received_pdu( await self._process_received_pdu(
origin, origin,
event, event,
state_ids=state_ids, context,
partial_state=partial_state,
backfilled=backfilled, backfilled=backfilled,
) )
except FederationError as e: except FederationError as e:
@ -878,15 +848,18 @@ class FederationEventHandler:
else: else:
raise raise
async def _resolve_state_at_missing_prevs( async def _compute_event_context_with_maybe_missing_prevs(
self, dest: str, event: EventBase self, dest: str, event: EventBase
) -> Tuple[Optional[StateMap[str]], Optional[bool]]: ) -> EventContext:
"""Calculate the state at an event with missing prev_events. """Build an EventContext structure for a non-outlier event whose prev_events may
be missing.
This is used when we have pulled a batch of events from a remote server, and This is used when we have pulled a batch of events from a remote server, and may
still don't have all the prev_events. not have all the prev_events.
If we already have all the prev_events for `event`, this method does nothing. To build an EventContext, we need to calculate the state before the event. If we
already have all the prev_events for `event`, we can simply use the state after
the prev_events to calculate the state before `event`.
Otherwise, the missing prevs become new backwards extremities, and we fall back Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`, to asking the remote server for the state after each missing `prev_event`,
@ -907,10 +880,7 @@ class FederationEventHandler:
event: an event to check for missing prevs. event: an event to check for missing prevs.
Returns: Returns:
if we already had all the prev events, `None, None`. Otherwise, returns a The event context.
tuple containing:
* the event ids of the state at `event`.
* a boolean indicating whether the state may be partial.
Raises: Raises:
FederationError if we fail to get the state from the remote server after any FederationError if we fail to get the state from the remote server after any
@ -924,7 +894,7 @@ class FederationEventHandler:
missing_prevs = prevs - seen missing_prevs = prevs - seen
if not missing_prevs: if not missing_prevs:
return None, None return await self._state_handler.compute_event_context(event)
logger.info( logger.info(
"Event %s is missing prev_events %s: calculating state for a " "Event %s is missing prev_events %s: calculating state for a "
@ -990,7 +960,9 @@ class FederationEventHandler:
"We can't get valid state history.", "We can't get valid state history.",
affected=event_id, affected=event_id,
) )
return state_map, partial_state return await self._state_handler.compute_event_context(
event, state_ids_before_event=state_map, partial_state=partial_state
)
async def _get_state_ids_after_missing_prev_event( async def _get_state_ids_after_missing_prev_event(
self, self,
@ -1159,8 +1131,7 @@ class FederationEventHandler:
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
state_ids: Optional[StateMap[str]], context: EventContext,
partial_state: Optional[bool],
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
"""Called when we have a new non-outlier event. """Called when we have a new non-outlier event.
@ -1182,32 +1153,18 @@ class FederationEventHandler:
event: event to be persisted event: event to be persisted
state_ids: Normally None, but if we are handling a gap in the graph context: The `EventContext` to persist the event with.
(ie, we are missing one or more prev_events), the resolved state at the
event
partial_state:
`True` if `state_ids` is partial and omits non-critical membership
events.
`False` if `state_ids` is the full state.
`None` if `state_ids` is not provided. In this case, the flag will be
calculated based on `event`'s prev events.
backfilled: True if this is part of a historical batch of events (inhibits backfilled: True if this is part of a historical batch of events (inhibits
notification to clients, and validation of device keys.) notification to clients, and validation of device keys.)
PartialStateConflictError: if the room was un-partial stated in between PartialStateConflictError: if the room was un-partial stated in between
computing the state at the event and persisting it. The caller should retry computing the state at the event and persisting it. The caller should
exactly once in this case. recompute `context` and retry exactly once when this happens.
""" """
logger.debug("Processing event: %s", event) logger.debug("Processing event: %s", event)
assert not event.internal_metadata.outlier assert not event.internal_metadata.outlier
context = await self._state_handler.compute_event_context(
event,
state_ids_before_event=state_ids,
partial_state=partial_state,
)
try: try:
await self._check_event_auth(origin, event, context) await self._check_event_auth(origin, event, context)
except AuthError as e: except AuthError as e:
@ -1219,7 +1176,7 @@ class FederationEventHandler:
# For new (non-backfilled and non-outlier) events we check if the event # For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we # passes auth based on the current state. If it doesn't then we
# "soft-fail" the event. # "soft-fail" the event.
await self._check_for_soft_fail(event, state_ids, origin=origin) await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context, backfilled) await self._run_push_actions_and_persist_event(event, context, backfilled)
@ -1782,7 +1739,7 @@ class FederationEventHandler:
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, self,
event: EventBase, event: EventBase,
state_ids: Optional[StateMap[str]], context: EventContext,
origin: str, origin: str,
) -> None: ) -> None:
"""Checks if we should soft fail the event; if so, marks the event as """Checks if we should soft fail the event; if so, marks the event as
@ -1793,7 +1750,7 @@ class FederationEventHandler:
Args: Args:
event event
state_ids: The state at the event if we don't have all the event's prev events context: The `EventContext` which we are about to persist the event with.
origin: The host the event originates from. origin: The host the event originates from.
""" """
if await self._store.is_partial_state_room(event.room_id): if await self._store.is_partial_state_room(event.room_id):
@ -1819,11 +1776,15 @@ class FederationEventHandler:
auth_types = auth_types_for_event(room_version_obj, event) auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state". # Calculate the "current state".
if state_ids is not None: seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
# If we're explicitly given the state then we won't have all the has_missing_prevs = bool(prev_event_ids - seen_event_ids)
# prev events, and so we have a gap in the graph. In this case if has_missing_prevs:
# we want to be a little careful as we might have been down for # We don't have all the prev_events of this event, which means we have a
# a while and have an incorrect view of the current state, # gap in the graph, and the new event is going to become a new backwards
# extremity.
#
# In this case we want to be a little careful as we might have been
# down for a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to # however we still want to do checks as gaps are easy to
# maliciously manufacture. # maliciously manufacture.
# #
@ -1836,6 +1797,7 @@ class FederationEventHandler:
event.room_id, extrem_ids event.room_id, extrem_ids
) )
state_sets: List[StateMap[str]] = list(state_sets_d.values()) state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_ids = await context.get_prev_state_ids()
state_sets.append(state_ids) state_sets.append(state_ids)
current_state_ids = ( current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store( await self._state_resolution_handler.resolve_events_with_store(

View file

@ -278,6 +278,10 @@ class StateHandler:
flag will be calculated based on `event`'s prev events. flag will be calculated based on `event`'s prev events.
Returns: Returns:
The event context. The event context.
Raises:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
""" """
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()
@ -432,6 +436,10 @@ class StateHandler:
Returns: Returns:
The resolved state The resolved state
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)

View file

@ -338,6 +338,10 @@ class StateStorageController:
event_ids: events to get state groups for event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete await_full_state: if true, will block if we do not yet have complete
state at these events. state at these events.
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie. they are outliers or unknown)
""" """
if await_full_state: if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids) await self._partial_state_events_tracker.await_full_state(event_ids)

View file

@ -280,14 +280,21 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# we poke this directly into _process_received_pdu, to avoid the # we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event. # federation handler wanting to backfill the fake event.
state_handler = self.hs.get_state_handler()
context = self.get_success(
state_handler.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
)
)
self.get_success( self.get_success(
federation_event_handler._process_received_pdu( federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, self.OTHER_SERVER_NAME,
event, event,
state_ids={ context,
(e.type, e.state_key): e.event_id for e in current_state
},
partial_state=False,
) )
) )