Extract _resolve_state_at_missing_prevs (#10624)

This is a follow-up to #10615: it takes the code that constructs the state at a backwards extremity, and extracts it to a separate method.
This commit is contained in:
Richard van der Hoff 2021-08-19 17:31:40 +01:00 committed by GitHub
parent 000aa89be6
commit 50af1efe4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 125 additions and 105 deletions

1
changelog.d/10624.misc Normal file
View file

@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View file

@ -285,12 +285,12 @@ class FederationHandler(BaseHandler):
# - Fetching any missing prev events to fill in gaps in the graph # - Fetching any missing prev events to fill in gaps in the graph
# - Fetching state if we have a hole in the graph # - Fetching state if we have a hole in the graph
if not pdu.internal_metadata.is_outlier(): if not pdu.internal_metadata.is_outlier():
prevs = set(pdu.prev_event_ids()) if sent_to_us_directly:
seen = await self.store.have_events_in_timeline(prevs) prevs = set(pdu.prev_event_ids())
missing_prevs = prevs - seen seen = await self.store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen
if missing_prevs: if missing_prevs:
if sent_to_us_directly:
# We only backfill backwards to the min depth. # We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id) min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("min_depth: %d", min_depth) logger.debug("min_depth: %d", min_depth)
@ -351,106 +351,8 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id, affected=pdu.event_id,
) )
else: else:
# We don't have all of the prev_events for this event. state = await self._resolve_state_at_missing_prevs(origin, pdu)
#
# In this case, we need to fall back to asking another server in the
# federation for the state at this event. That's ok provided we then
# resolve the state against other bits of the DAG before using it (which
# will ensure that you can't just take over a room by sending an event,
# withholding its prev_events, and declaring yourself to be an admin in
# the subsequent state request).
#
# Since we're pulling this event as a missing prev_event, then clearly
# this event is not going to become the only forward-extremity and we are
# guaranteed to resolve its state against our existing forward
# extremities, so that should be fine.
#
# XXX this really feels like it could/should be merged with the above,
# but there is an interaction with min_depth that I'm not really
# following.
logger.info(
"Event %s is missing prev_events %s: calculating state for a "
"backwards extremity",
event_id,
shortstr(missing_prevs),
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = await self.state_store.get_state_groups_ids(
room_id, seen
)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
# Ask the remote server for the states we don't
# know about
for p in missing_prevs:
logger.info(
"Requesting state after missing prev_event %s", p
)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
remote_state = (
await self._get_state_after_missing_prev_event(
origin, room_id, p
)
)
remote_state_map = {
(x.type, x.state_key): x.event_id
for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
state_res_store=StateResolutionStore(self.store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing "
"prev_events",
exc_info=True,
)
raise FederationError(
"ERROR",
403,
"We can't get valid state history.",
affected=event_id,
)
# A second round of checks for all events. Check that the event passes auth # A second round of checks for all events. Check that the event passes auth
# based on `auth_events`, this allows us to assert that the event would # based on `auth_events`, this allows us to assert that the event would
@ -1493,6 +1395,123 @@ class FederationHandler(BaseHandler):
event_infos, event_infos,
) )
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
still don't have all the prev_events.
If we already have all the prev_events for `event`, this method does nothing.
Otherwise, the missing prevs become new backwards extremities, and we fall back
to asking the remote server for the state after each missing `prev_event`,
and resolving across them.
That's ok provided we then resolve the state against other bits of the DAG
before using it - in other words, that the received event `event` is not going
to become the only forwards_extremity in the room (which will ensure that you
can't just take over a room by sending an event, withholding its prev_events,
and declaring yourself to be an admin in the subsequent state request).
In other words: we should only call this method if `event` has been *pulled*
as part of a batch of missing prev events, or similar.
Params:
dest: the remote server to ask for state at the missing prevs. Typically,
this will be the server we got `event` from.
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
prevs = set(event.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
missing_prevs = prevs - seen
if not missing_prevs:
return None
logger.info(
"Event %s is missing prev_events %s: calculating state for a "
"backwards extremity",
event_id,
shortstr(missing_prevs),
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
event_map = {event_id: event}
try:
# Get the state of the events we know about
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps: List[StateMap[str]] = list(ours.values())
# we don't need this any more, let's delete it.
del ours
# Ask the remote server for the states we don't
# know about
for p in missing_prevs:
logger.info("Requesting state after missing prev_event %s", p)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
remote_state = await self._get_state_after_missing_prev_event(
dest, room_id, p
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
state_res_store=StateResolutionStore(self.store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
exc_info=True,
)
raise FederationError(
"ERROR",
403,
"We can't get valid state history.",
affected=event_id,
)
return state
def _sanity_check_event(self, ev: EventBase) -> None: def _sanity_check_event(self, ev: EventBase) -> None:
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event