Persist auth/state events at backwards extremities when we fetch them (#6526)

The main point here is to make sure that the state returned by _get_state_in_room has been authed before we try to use it as state in the room.
This commit is contained in:
Richard van der Hoff 2019-12-16 12:26:28 +00:00 committed by Richard van der Hoff
parent 83895316d4
commit ff773ff724
3 changed files with 83 additions and 165 deletions

1
changelog.d/6526.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.

View file

@ -64,8 +64,7 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server from synapse.visibility import filter_events_for_server
@ -240,7 +239,6 @@ class FederationHandler(BaseHandler):
return None return None
state = None state = None
auth_chain = []
# Get missing pdus if necessary. # Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier(): if not pdu.internal_metadata.is_outlier():
@ -346,7 +344,6 @@ class FederationHandler(BaseHandler):
# Calculate the state after each of the previous events, and # Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event. # resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu} event_map = {event_id: pdu}
try: try:
# Get the state of the events we know about # Get the state of the events we know about
@ -370,24 +367,14 @@ class FederationHandler(BaseHandler):
p, p,
) )
room_version = yield self.store.get_room_version(room_id)
with nested_logging_context(p): with nested_logging_context(p):
# note that if any of the missing prevs share missing state or # note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped # auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client. # by the get_pdu_cache in federation_client.
( (remote_state, _,) = yield self._get_state_for_room(
remote_state,
got_auth_chain,
) = yield self._get_state_for_room(
origin, room_id, p, include_event_in_state=True origin, room_id, p, include_event_in_state=True
) )
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
auth_chains.update(got_auth_chain)
remote_state_map = { remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state (x.type, x.state_key): x.event_id for x in remote_state
} }
@ -396,6 +383,7 @@ class FederationHandler(BaseHandler):
for x in remote_state: for x in remote_state:
event_map[x.event_id] = x event_map[x.event_id] = x
room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store( state_map = yield resolve_events_with_store(
room_id, room_id,
room_version, room_version,
@ -417,7 +405,6 @@ class FederationHandler(BaseHandler):
event_map.update(evs) event_map.update(evs)
state = [event_map[e] for e in six.itervalues(state_map)] state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception: except Exception:
logger.warning( logger.warning(
"[%s %s] Error attempting to resolve state at missing " "[%s %s] Error attempting to resolve state at missing "
@ -433,9 +420,7 @@ class FederationHandler(BaseHandler):
affected=event_id, affected=event_id,
) )
yield self._process_received_pdu( yield self._process_received_pdu(origin, pdu, state=state)
origin, pdu, state=state, auth_chain=auth_chain
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth): def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
@ -638,6 +623,8 @@ class FederationHandler(BaseHandler):
room_id (str) room_id (str)
event_ids (Iterable[str]) event_ids (Iterable[str])
Persists any events we don't already have as outliers.
If we fail to fetch any of the events, a warning will be logged, and the event If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to will be omitted from the result. Likewise, any events which turn out not to
be in the given room. be in the given room.
@ -657,27 +644,15 @@ class FederationHandler(BaseHandler):
room_id, room_id,
) )
room_version = yield self.store.get_room_version(room_id) yield self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
) )
for success, result in res: # we need to make sure we re-load from the database to get the rejected
if success and result: # state correct.
fetched_events[result.event_id] = result fetched_events.update(
(yield self.store.get_events(missing_events, allow_rejected=True))
)
# check for events which were in the wrong room. # check for events which were in the wrong room.
# #
@ -707,50 +682,24 @@ class FederationHandler(BaseHandler):
return fetched_events return fetched_events
@defer.inlineCallbacks @defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state, auth_chain): def _process_received_pdu(self, origin, event, state):
""" Called when we have a new pdu. We need to do auth checks and put it """ Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler. through the StateHandler.
Args:
origin: server sending the event
event: event to be persisted
state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
""" """
room_id = event.room_id room_id = event.room_id
event_id = event.event_id event_id = event.event_id
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event) logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id)
logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id,
event_id,
[e.event.event_id for e in event_infos],
)
yield self._handle_new_events(origin, event_infos)
try: try:
context = yield self._handle_new_event(origin, event, state=state) context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e: except AuthError as e:
@ -806,8 +755,6 @@ class FederationHandler(BaseHandler):
if dest == self.server_name: if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.") raise SynapseError(400, "Can't backfill from self.")
room_version = yield self.store.get_room_version(room_id)
events = yield self.federation_client.backfill( events = yield self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities dest, room_id, limit=limit, extremities=extremities
) )
@ -836,6 +783,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events) event_ids = set(e.event_id for e in events)
# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
# sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids] edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
logger.info("backfill: Got %d events with %d edges", len(events), len(edges)) logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@ -864,95 +814,11 @@ class FederationHandler(BaseHandler):
auth_events.update( auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map} {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
) )
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()
# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
a_id for event in ret_events.values() for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch,
)
results = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.federation_client.get_pdu,
[dest],
event_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
for event in results
if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = missing_auth - set(auth_events)
seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
# We now have a chunk of events plus associated state and auth chain to
# persist. We do the persistence in two steps:
# 1. Auth events and state get persisted as outliers, plus the
# backward extremities get persisted (as non-outliers).
# 2. The rest of the events in the chunk get persisted one by one, as
# each one depends on the previous event for its state.
#
# The important thing is that events in the chunk get persisted as
# non-outliers, including when those events are also in the state or
# auth chain. Caution must therefore be taken to ensure that they are
# not accidentally marked as outliers.
# Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = [] ev_infos = []
for a in auth_events.values():
# We only want to persist auth events as outliers that we haven't
# seen and aren't about to persist as part of the backfilled chunk.
if a.event_id in seen_events or a.event_id in event_map:
continue
a.internal_metadata.outlier = True # Step 1: persist the events in the chunk we fetched state for (i.e.
ev_infos.append( # the backwards extremities), with custom auth events and state
_NewEventInfo(
event=a,
auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
): auth_events[a_id]
for a_id in a.auth_event_ids()
if a_id in auth_events
},
)
)
# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
for e_id in events_to_state: for e_id in events_to_state:
# For paranoia we ensure that these events are marked as # For paranoia we ensure that these events are marked as
# non-outliers # non-outliers
@ -1194,6 +1060,57 @@ class FederationHandler(BaseHandler):
return False return False
@defer.inlineCallbacks
def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
):
"""Fetch the given events from a server, and persist them as outliers.
Logs a warning if we can't find the given event.
"""
room_version = yield self.store.get_room_version(room_id)
event_infos = []
async def get_event(event_id: str):
with nested_logging_context(event_id):
try:
event = await self.federation_client.get_pdu(
[destination], event_id, room_version, outlier=True,
)
if event is None:
logger.warning(
"Server %s didn't return event %s", destination, event_id,
)
return
# recursively fetch the auth events for this event
auth_events = await self._get_events_from_store_or_dest(
destination, room_id, event.auth_event_ids()
)
auth = {}
for auth_event_id in event.auth_event_ids():
ae = auth_events.get(auth_event_id)
if ae:
auth[(ae.type, ae.state_key)] = ae
event_infos.append(_NewEventInfo(event, None, auth))
except Exception as e:
logger.warning(
"Error fetching missing state/auth event %s: %s %s",
event_id,
type(e),
e,
)
yield concurrently_execute(get_event, events, 5)
yield self._handle_new_events(
destination, event_infos,
)
def _sanity_check_event(self, ev): def _sanity_check_event(self, ev):
""" """
Do some early sanity checks of a received event Do some early sanity checks of a received event

View file

@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):
Args: Args:
func (func): Function to execute, should return a deferred or coroutine. func (func): Function to execute, should return a deferred or coroutine.
args (list): List of arguments to pass to func, each invocation of func args (Iterable): List of arguments to pass to func, each invocation of func
gets a signle argument. gets a single argument.
limit (int): Maximum number of conccurent executions. limit (int): Maximum number of conccurent executions.
Returns: Returns: