Persist auth/state events at backwards extremities when we fetch them (#6526)

The main point here is to make sure that the state returned by _get_state_for_room has been authed before we try to use it as state in the room.
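As a rough, self-contained model of what the change achieves (illustrative only: Event, Store, persist_outlier and fetch_and_persist_state are invented for this sketch and are not Synapse's actual classes or methods), the state events fetched from a remote server for a backwards extremity are now auth-checked and persisted as outliers, and then re-read from the store so that rejections are known about, before that state is used:

    import asyncio
    from dataclasses import dataclass, field
    from typing import Dict, List, Set


    @dataclass
    class Event:
        event_id: str
        auth_event_ids: List[str]


    @dataclass
    class Store:
        """Toy event store: persisting an event records whether it was rejected."""

        events: Dict[str, Event] = field(default_factory=dict)
        rejected: Set[str] = field(default_factory=set)

        async def persist_outlier(self, event: Event, passes_auth: bool) -> None:
            self.events[event.event_id] = event
            if not passes_auth:
                self.rejected.add(event.event_id)

        async def get_events(self, event_ids, allow_rejected: bool = False):
            return {
                e_id: self.events[e_id]
                for e_id in event_ids
                if e_id in self.events
                and (allow_rejected or e_id not in self.rejected)
            }


    async def fetch_and_persist_state(store: Store, remote_state: List[Event]):
        """Auth-check and persist the remote state events as outliers, then
        re-read them from the store so rejections are taken into account
        before the events are used as the state at the backwards extremity."""
        for ev in remote_state:
            # stand-in for the real auth checks against the event's auth chain
            passes_auth = all(a in store.events for a in ev.auth_event_ids)
            await store.persist_outlier(ev, passes_auth)

        return await store.get_events(
            [e.event_id for e in remote_state], allow_rejected=True
        )


    async def main() -> None:
        store = Store()
        state = await fetch_and_persist_state(
            store, [Event("$create", []), Event("$spoofed", ["$unknown_auth"])]
        )
        print(sorted(state))           # both events are persisted as outliers...
        print(sorted(store.rejected))  # ...but the un-authed one is marked rejected


    asyncio.run(main())

In the diff below this corresponds to the new _get_events_and_persist helper plus the re-load via store.get_events(..., allow_rejected=True) in _get_events_from_store_or_dest.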
Richard van der Hoff 2019-12-16 12:26:28 +00:00 committed by GitHub
parent 9d173b312c
commit bc7de87650
3 changed files with 83 additions and 169 deletions

changelog.d/6526.bugfix (new file)

@@ -0,0 +1 @@
+Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs.

synapse/handlers/federation.py

@@ -65,8 +65,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRes
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import batch_iter, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -238,7 +237,6 @@ class FederationHandler(BaseHandler):
             return None
 
         state = None
-        auth_chain = []
 
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
@@ -348,7 +346,6 @@ class FederationHandler(BaseHandler):
             # Calculate the state after each of the previous events, and
             # resolve them to find the correct state at the current event.
 
-            auth_chains = set()
             event_map = {event_id: pdu}
             try:
                 # Get the state of the events we know about
@@ -369,24 +366,14 @@ class FederationHandler(BaseHandler):
                         "Requesting state at missing prev_event %s", event_id,
                     )
 
-                    room_version = await self.store.get_room_version(room_id)
-
                     with nested_logging_context(p):
                         # note that if any of the missing prevs share missing state or
                         # auth events, the requests to fetch those events are deduped
                         # by the get_pdu_cache in federation_client.
-                        (
-                            remote_state,
-                            got_auth_chain,
-                        ) = await self._get_state_for_room(
+                        (remote_state, _,) = await self._get_state_for_room(
                             origin, room_id, p, include_event_in_state=True
                         )
 
-                        # XXX hrm I'm not convinced that duplicate events will compare
-                        # for equality, so I'm not sure this does what the author
-                        # hoped.
-                        auth_chains.update(got_auth_chain)
-
                         remote_state_map = {
                             (x.type, x.state_key): x.event_id for x in remote_state
                         }
@@ -395,6 +382,7 @@ class FederationHandler(BaseHandler):
                         for x in remote_state:
                             event_map[x.event_id] = x
 
+                room_version = await self.store.get_room_version(room_id)
                 state_map = await resolve_events_with_store(
                     room_id,
                     room_version,
@@ -416,7 +404,6 @@ class FederationHandler(BaseHandler):
                 event_map.update(evs)
 
                 state = [event_map[e] for e in six.itervalues(state_map)]
-                auth_chain = list(auth_chains)
             except Exception:
                 logger.warning(
                     "[%s %s] Error attempting to resolve state at missing "
@@ -432,9 +419,7 @@ class FederationHandler(BaseHandler):
                     affected=event_id,
                 )
 
-        await self._process_received_pdu(
-            origin, pdu, state=state, auth_chain=auth_chain
-        )
+        await self._process_received_pdu(origin, pdu, state=state)
 
     async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
@@ -633,10 +618,7 @@ class FederationHandler(BaseHandler):
     ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
 
-        Args:
-            destination
-            room_id
-            event_ids
+        Persists any events we don't already have as outliers.
 
         If we fail to fetch any of the events, a warning will be logged, and the event
         will be omitted from the result. Likewise, any events which turn out not to
@@ -656,27 +638,15 @@ class FederationHandler(BaseHandler):
                 room_id,
             )
 
-            room_version = await self.store.get_room_version(room_id)
-
-            # XXX 20 requests at once? really?
-            for batch in batch_iter(missing_events, 20):
-                deferreds = [
-                    run_in_background(
-                        self.federation_client.get_pdu,
-                        destinations=[destination],
-                        event_id=e_id,
-                        room_version=room_version,
-                    )
-                    for e_id in batch
-                ]
-
-                res = await make_deferred_yieldable(
-                    defer.DeferredList(deferreds, consumeErrors=True)
-                )
-
-                for success, result in res:
-                    if success and result:
-                        fetched_events[result.event_id] = result
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, events=missing_events
+            )
+
+            # we need to make sure we re-load from the database to get the rejected
+            # state correct.
+            fetched_events.update(
+                (await self.store.get_events(missing_events, allow_rejected=True))
+            )
 
         # check for events which were in the wrong room.
         #
@@ -705,50 +675,26 @@ class FederationHandler(BaseHandler):
 
         return fetched_events
 
-    async def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _process_received_pdu(
+        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+    ):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
+
+        Args:
+            origin: server sending the event
+
+            event: event to be persisted
+
+            state: Normally None, but if we are handling a gap in the graph
+                (ie, we are missing one or more prev_events), the resolved state at the
+                event
         """
         room_id = event.room_id
         event_id = event.event_id
 
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
-
-        seen_ids = await self.store.have_seen_events(event_ids)
-
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
-            event_infos = []
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-
-                e.internal_metadata.outlier = True
-                auth_ids = e.auth_event_ids()
-                auth = {
-                    (e.type, e.state_key): e
-                    for e in auth_chain
-                    if e.event_id in auth_ids or e.type == EventTypes.Create
-                }
-                event_infos.append(_NewEventInfo(event=e, auth_events=auth))
-                seen_ids.add(e.event_id)
-
-            logger.info(
-                "[%s %s] persisting newly-received auth/state events %s",
-                room_id,
-                event_id,
-                [e.event.event_id for e in event_infos],
-            )
-
-            await self._handle_new_events(origin, event_infos)
-
         try:
             context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
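Illustrative note on the new state argument (toy code, not Synapse's API; state_for_process_received_pdu is an invented helper): it stays None unless some of the event's prev_events are unknown, in which case the caller passes in the state it resolved across the gap:

    from typing import Iterable, List, Optional, Set


    def state_for_process_received_pdu(
        prev_event_ids: Iterable[str],
        seen_event_ids: Set[str],
        resolved_state: List[str],
    ) -> Optional[List[str]]:
        """Return None when all prev_events are known (no gap); otherwise
        return the state resolved across the gap, to be passed in explicitly."""
        missing = set(prev_event_ids) - seen_event_ids
        return resolved_state if missing else None


    assert state_for_process_received_pdu(["$a", "$b"], {"$a", "$b"}, ["$s"]) is None
    assert state_for_process_received_pdu(["$a", "$b"], {"$a"}, ["$s"]) == ["$s"]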
@@ -803,8 +749,6 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = await self.store.get_room_version(room_id)
-
         events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
@@ -833,6 +777,9 @@ class FederationHandler(BaseHandler):
 
         event_ids = set(e.event_id for e in events)
 
+        # build a list of events whose prev_events weren't in the batch.
+        # (XXX: this will include events whose prev_events we already have; that doesn't
+        #  sound right?)
         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
 
         logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -861,95 +808,11 @@ class FederationHandler(BaseHandler):
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id in event.auth_event_ids()
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch,
-                )
-
-                results = await make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
-                            run_in_background(
-                                self.federation_client.get_pdu,
-                                [dest],
-                                event_id,
-                                room_version=room_version,
-                                outlier=True,
-                                timeout=10000,
-                            )
-                            for event_id in missing_auth - failed_to_fetch
-                        ],
-                        consumeErrors=True,
-                    )
-                ).addErrback(unwrapFirstError)
-                auth_events.update({a.event_id: a for a in results if a})
-                required_auth.update(
-                    a_id
-                    for event in results
-                    if event
-                    for a_id in event.auth_event_ids()
-                )
-                missing_auth = required_auth - set(auth_events)
-
-                failed_to_fetch = missing_auth - set(auth_events)
-
-        seen_events = await self.store.have_seen_events(
-            set(auth_events.keys()) | set(state_events.keys())
-        )
-
-        # We now have a chunk of events plus associated state and auth chain to
-        # persist. We do the persistence in two steps:
-        #  1. Auth events and state get persisted as outliers, plus the
-        #     backward extremities get persisted (as non-outliers).
-        #  2. The rest of the events in the chunk get persisted one by one, as
-        #     each one depends on the previous event for its state.
-        #
-        # The important thing is that events in the chunk get persisted as
-        # non-outliers, including when those events are also in the state or
-        # auth chain. Caution must therefore be taken to ensure that they are
-        # not accidentally marked as outliers.
-
-        # Step 1a: persist auth events that *don't* appear in the chunk
+
         ev_infos = []
-        for a in auth_events.values():
-            # We only want to persist auth events as outliers that we haven't
-            # seen and aren't about to persist as part of the backfilled chunk.
-            if a.event_id in seen_events or a.event_id in event_map:
-                continue
 
-            a.internal_metadata.outlier = True
-            ev_infos.append(
-                _NewEventInfo(
-                    event=a,
-                    auth_events={
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in a.auth_event_ids()
-                        if a_id in auth_events
-                    },
-                )
-            )
-
-        # Step 1b: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities) as non-outliers.
+        # Step 1: persist the events in the chunk we fetched state for (i.e.
+        # the backwards extremities), with custom auth events and state
         for e_id in events_to_state:
             # For paranoia we ensure that these events are marked as
             # non-outliers
@@ -1191,6 +1054,56 @@ class FederationHandler(BaseHandler):
 
         return False
 
+    async def _get_events_and_persist(
+        self, destination: str, room_id: str, events: Iterable[str]
+    ):
+        """Fetch the given events from a server, and persist them as outliers.
+
+        Logs a warning if we can't find the given event.
+        """
+
+        room_version = await self.store.get_room_version(room_id)
+
+        event_infos = []
+
+        async def get_event(event_id: str):
+            with nested_logging_context(event_id):
+                try:
+                    event = await self.federation_client.get_pdu(
+                        [destination], event_id, room_version, outlier=True,
+                    )
+
+                    if event is None:
+                        logger.warning(
+                            "Server %s didn't return event %s", destination, event_id,
+                        )
+                        return
+
+                    # recursively fetch the auth events for this event
+                    auth_events = await self._get_events_from_store_or_dest(
+                        destination, room_id, event.auth_event_ids()
+                    )
+                    auth = {}
+                    for auth_event_id in event.auth_event_ids():
+                        ae = auth_events.get(auth_event_id)
+                        if ae:
+                            auth[(ae.type, ae.state_key)] = ae
+
+                    event_infos.append(_NewEventInfo(event, None, auth))
+
+                except Exception as e:
+                    logger.warning(
+                        "Error fetching missing state/auth event %s: %s %s",
+                        event_id,
+                        type(e),
+                        e,
+                    )
+
+        await concurrently_execute(get_event, events, 5)
+
+        await self._handle_new_events(
+            destination, event_infos,
+        )
+
     def _sanity_check_event(self, ev):
         """
         Do some early sanity checks of a received event

synapse/util/async_helpers.py

@@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):
 
     Args:
         func (func): Function to execute, should return a deferred or coroutine.
-        args (list): List of arguments to pass to func, each invocation of func
-            gets a signle argument.
+        args (Iterable): List of arguments to pass to func, each invocation of func
+            gets a single argument.
         limit (int): Maximum number of conccurent executions.
 
     Returns:
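For reference, a rough asyncio-based sketch of the bounded-concurrency pattern that concurrently_execute provides (this is not Synapse's Twisted-based implementation, and the semaphore approach shown here is only one way of capping the number of in-flight calls at limit):

    import asyncio
    from typing import Awaitable, Callable, Iterable, TypeVar

    T = TypeVar("T")


    async def concurrently_execute_sketch(
        func: Callable[[T], Awaitable[None]], args: Iterable[T], limit: int
    ) -> None:
        """Run func over each arg, with at most `limit` invocations in flight."""
        semaphore = asyncio.Semaphore(limit)

        async def run_one(arg: T) -> None:
            async with semaphore:
                await func(arg)

        await asyncio.gather(*(run_one(arg) for arg in args))


    async def main() -> None:
        async def fetch(event_id: str) -> None:
            await asyncio.sleep(0.01)
            print("fetched", event_id)

        await concurrently_execute_sketch(fetch, [f"$ev{i}" for i in range(10)], limit=5)


    asyncio.run(main())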