Clean up federation event auth code (#10539)

* drop old-room hack

pretty sure we don't need this any more.

* Remove incorrect comment about modifying `context`

It doesn't look like the supplied context is ever modified.

* Stop `_auth_and_persist_event` modifying its parameters

This is only called in three places. Two of them don't pass `auth_events`, and
the third doesn't use the dict after passing it in, so this should be non-functional.

* Stop `_check_event_auth` modifying its parameters

`_check_event_auth` is only called in three places. `on_send_membership_event`
doesn't pass an `auth_events`, and `prep` and `_auth_and_persist_event` do not
use the map after passing it in.

* Stop `_update_auth_events_and_context_for_auth` modifying its parameters

Return the updated auth event dict, rather than modifying the parameter.

This is only called from `_check_event_auth`.

* Improve documentation on `_auth_and_persist_event`

Rename `auth_events` parameter to better reflect what it contains.

* Improve documentation on `_NewEventInfo`

* Improve documentation on `_check_event_auth`

rename `auth_events` parameter to better describe what it contains

* changelog
This commit is contained in:
Richard van der Hoff 2021-08-06 13:54:23 +01:00 committed by GitHub
parent f4ade972ad
commit 1bebc0b78c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 56 deletions

1
changelog.d/10539.misc Normal file
View file

@ -0,0 +1 @@
Clean up some of the federation event authentication code for clarity.

View file

@ -109,21 +109,33 @@ soft_failed_event_counter = Counter(
) )
@attr.s(slots=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _NewEventInfo: class _NewEventInfo:
"""Holds information about a received event, ready for passing to _auth_and_persist_events """Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes: Attributes:
event: the received event event: the received event
state: the state at that event state: the state at that event, according to /state_ids from a remote
homeserver. Only populated for backfilled events which are going to be a
new backwards extremity.
claimed_auth_event_map: a map of (type, state_key) => event for the event's
claimed auth_events.
This can include events which have not yet been persisted, in the case that
we are backfilling a batch of events.
Note: May be incomplete: if we were unable to find all of the claimed auth
events. Also, treat the contents with caution: the events might also have
been rejected, might not yet have been authorized themselves, or they might
be in the wrong room.
auth_events: the auth_event map for that event
""" """
event = attr.ib(type=EventBase) event: EventBase
state = attr.ib(type=Optional[Sequence[EventBase]], default=None) state: Optional[Sequence[EventBase]]
auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None) claimed_auth_event_map: StateMap[EventBase]
class FederationHandler(BaseHandler): class FederationHandler(BaseHandler):
@ -1086,7 +1098,7 @@ class FederationHandler(BaseHandler):
_NewEventInfo( _NewEventInfo(
event=ev, event=ev,
state=events_to_state[e_id], state=events_to_state[e_id],
auth_events={ claimed_auth_event_map={
( (
auth_events[a_id].type, auth_events[a_id].type,
auth_events[a_id].state_key, auth_events[a_id].state_key,
@ -2315,7 +2327,7 @@ class FederationHandler(BaseHandler):
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
state: Optional[Iterable[EventBase]] = None, state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None, claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False, backfilled: bool = False,
) -> None: ) -> None:
""" """
@ -2327,17 +2339,18 @@ class FederationHandler(BaseHandler):
context: context:
The event context. The event context.
NB that this function potentially modifies it.
state: state:
The state events used to check the event for soft-fail. If this is The state events used to check the event for soft-fail. If this is
not provided the current state events will be used. not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room claimed_auth_event_map:
at the event's position in the DAG, though occasionally (eg if the A map of (type, state_key) => event for the event's claimed auth_events.
event is an outlier), may be the auth events claimed by the remote Possibly incomplete, and possibly including events that are not yet
server. persisted, or authed, or in the right room.
Only populated where we may not already have persisted these events -
for example, when populating outliers.
backfilled: True if the event was backfilled. backfilled: True if the event was backfilled.
""" """
context = await self._check_event_auth( context = await self._check_event_auth(
@ -2345,7 +2358,7 @@ class FederationHandler(BaseHandler):
event, event,
context, context,
state=state, state=state,
auth_events=auth_events, claimed_auth_event_map=claimed_auth_event_map,
backfilled=backfilled, backfilled=backfilled,
) )
@ -2409,7 +2422,7 @@ class FederationHandler(BaseHandler):
event, event,
res, res,
state=ev_info.state, state=ev_info.state,
auth_events=ev_info.auth_events, claimed_auth_event_map=ev_info.claimed_auth_event_map,
backfilled=backfilled, backfilled=backfilled,
) )
return res return res
@ -2675,7 +2688,7 @@ class FederationHandler(BaseHandler):
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
state: Optional[Iterable[EventBase]] = None, state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None, claimed_auth_event_map: Optional[StateMap[EventBase]] = None,
backfilled: bool = False, backfilled: bool = False,
) -> EventContext: ) -> EventContext:
""" """
@ -2687,21 +2700,19 @@ class FederationHandler(BaseHandler):
context: context:
The event context. The event context.
NB that this function potentially modifies it.
state: state:
The state events used to check the event for soft-fail. If this is The state events used to check the event for soft-fail. If this is
not provided the current state events will be used. not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room claimed_auth_event_map:
at the event's position in the DAG, though occasionally (eg if the A map of (type, state_key) => event for the event's claimed auth_events.
event is an outlier), may be the auth events claimed by the remote Possibly incomplete, and possibly including events that are not yet
server. persisted, or authed, or in the right room.
Also NB that this function adds entries to it. Only populated where we may not already have persisted these events -
for example, when populating outliers, or the state for a backwards
extremity.
If this is not provided, it is calculated from the previous state IDs.
backfilled: True if the event was backfilled. backfilled: True if the event was backfilled.
Returns: Returns:
@ -2710,7 +2721,12 @@ class FederationHandler(BaseHandler):
room_version = await self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if not auth_events: if claimed_auth_event_map:
# if we have a copy of the auth events from the event, use that as the
# basis for auth.
auth_events = claimed_auth_event_map
else:
# otherwise, we calculate what the auth events *should* be, and use that
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self._event_auth_handler.compute_auth_events( auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
@ -2718,18 +2734,11 @@ class FederationHandler(BaseHandler):
auth_events_x = await self.store.get_events(auth_events_ids) auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try: try:
context = await self._update_auth_events_and_context_for_auth( (
context,
auth_events_for_auth,
) = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events origin, event, context, auth_events
) )
except Exception: except Exception:
@ -2742,9 +2751,10 @@ class FederationHandler(BaseHandler):
"Ignoring failure and continuing processing of event.", "Ignoring failure and continuing processing of event.",
event.event_id, event.event_id,
) )
auth_events_for_auth = auth_events
try: try:
event_auth.check(room_version_obj, event, auth_events=auth_events) event_auth.check(room_version_obj, event, auth_events=auth_events_for_auth)
except AuthError as e: except AuthError as e:
logger.warning("Failed auth resolution for %r because %s", event, e) logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
@ -2769,8 +2779,8 @@ class FederationHandler(BaseHandler):
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
auth_events: MutableStateMap[EventBase], input_auth_events: StateMap[EventBase],
) -> EventContext: ) -> Tuple[EventContext, StateMap[EventBase]]:
"""Helper for _check_event_auth. See there for docs. """Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it Checks whether a given event has the expected auth events. If it
@ -2787,7 +2797,7 @@ class FederationHandler(BaseHandler):
event: event:
context: context:
auth_events: input_auth_events:
Map from (event_type, state_key) to event Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room Normally, our calculated auth_events based on the state of the room
@ -2795,11 +2805,12 @@ class FederationHandler(BaseHandler):
event is an outlier), may be the auth events claimed by the remote event is an outlier), may be the auth events claimed by the remote
server. server.
Also NB that this function adds entries to it.
Returns: Returns:
updated context updated context, updated auth event map
""" """
# take a copy of input_auth_events before we modify it.
auth_events: MutableStateMap[EventBase] = dict(input_auth_events)
event_auth_events = set(event.auth_event_ids()) event_auth_events = set(event.auth_event_ids())
# missing_auth is the set of the event's auth_events which we don't yet have # missing_auth is the set of the event's auth_events which we don't yet have
@ -2828,7 +2839,7 @@ class FederationHandler(BaseHandler):
# The other side isn't around or doesn't implement the # The other side isn't around or doesn't implement the
# endpoint, so lets just bail out. # endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e1) logger.info("Failed to get event auth from remote: %s", e1)
return context return context, auth_events
seen_remotes = await self.store.have_seen_events( seen_remotes = await self.store.have_seen_events(
event.room_id, [e.event_id for e in remote_auth_chain] event.room_id, [e.event_id for e in remote_auth_chain]
@ -2859,7 +2870,10 @@ class FederationHandler(BaseHandler):
await self.state_handler.compute_event_context(e) await self.state_handler.compute_event_context(e)
) )
await self._auth_and_persist_event( await self._auth_and_persist_event(
origin, e, missing_auth_event_context, auth_events=auth origin,
e,
missing_auth_event_context,
claimed_auth_event_map=auth,
) )
if e.event_id in event_auth_events: if e.event_id in event_auth_events:
@ -2877,14 +2891,14 @@ class FederationHandler(BaseHandler):
# obviously be empty # obviously be empty
# (b) alternatively, why don't we do it earlier? # (b) alternatively, why don't we do it earlier?
logger.info("Skipping auth_event fetch for outlier") logger.info("Skipping auth_event fetch for outlier")
return context return context, auth_events
different_auth = event_auth_events.difference( different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values() e.event_id for e in auth_events.values()
) )
if not different_auth: if not different_auth:
return context return context, auth_events
logger.info( logger.info(
"auth_events refers to events which are not in our calculated auth " "auth_events refers to events which are not in our calculated auth "
@ -2910,7 +2924,7 @@ class FederationHandler(BaseHandler):
# XXX: should we reject the event in this case? It feels like we should, # XXX: should we reject the event in this case? It feels like we should,
# but then shouldn't we also do so if we've failed to fetch any of the # but then shouldn't we also do so if we've failed to fetch any of the
# auth events? # auth events?
return context return context, auth_events
# now we state-resolve between our own idea of the auth events, and the remote's # now we state-resolve between our own idea of the auth events, and the remote's
# idea of them. # idea of them.
@ -2940,7 +2954,7 @@ class FederationHandler(BaseHandler):
event, context, auth_events event, context, auth_events
) )
return context return context, auth_events
async def _update_context_for_auth_events( async def _update_context_for_auth_events(
self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase] self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]

View file

@ -75,11 +75,9 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
self.handler = self.homeserver.get_federation_handler() self.handler = self.homeserver.get_federation_handler()
self.handler._check_event_auth = ( self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed(
lambda origin, event, context, state, auth_events, backfilled: succeed(
context context
) )
)
self.client = self.homeserver.get_federation_client() self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus pdus