Separate creating an event context from persisting it in the federation handler (#9800)

This refactoring allows adding logic that uses the event context
before persisting it.
This commit is contained in:
Patrick Cloke 2021-04-14 12:35:28 -04:00 committed by GitHub
parent e8816c6ace
commit 936e69825a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 67 deletions

1
changelog.d/9800.feature Normal file
View file

@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

View file

@ -103,7 +103,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
"""Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes:
event: the received event
@ -807,7 +807,10 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)
context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
@ -1010,7 +1013,9 @@ class FederationHandler(BaseHandler):
)
if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@ -1023,10 +1028,12 @@ class FederationHandler(BaseHandler):
# non-outliers
assert not event.internal_metadata.is_outlier()
context = await self.state_handler.compute_event_context(event)
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True)
await self._auth_and_persist_event(dest, event, context, backfilled=True)
return events
@ -1360,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
await self._auth_and_persist_events(
destination,
room_id,
event_infos,
@ -1666,10 +1673,11 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
context = await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
context = await self._auth_and_persist_event(origin, event, context)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
"on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
@ -1878,10 +1886,11 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
await self._handle_new_event(origin, event)
context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
"on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id,
event.signatures,
)
@ -1989,16 +1998,47 @@ class FederationHandler(BaseHandler):
async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)
async def _handle_new_event(
async def _auth_and_persist_event(
self,
origin: str,
event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False,
) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
"""
Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
backfilled: True if the event was backfilled.
Returns:
The event context.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
)
try:
@ -2022,7 +2062,7 @@ class FederationHandler(BaseHandler):
return context
async def _handle_new_events(
async def _auth_and_persist_events(
self,
origin: str,
room_id: str,
@ -2040,9 +2080,13 @@ class FederationHandler(BaseHandler):
async def prep(ev_info: _NewEventInfo):
event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = await self._prep_event(
res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self._check_event_auth(
origin,
event,
res,
state=ev_info.state,
auth_events=ev_info.auth_events,
backfilled=backfilled,
@ -2177,49 +2221,6 @@ class FederationHandler(BaseHandler):
room_id, [(event, new_event_context)]
)
async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
context = await self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
) -> None:
@ -2330,19 +2331,28 @@ class FederationHandler(BaseHandler):
return missing_events
async def do_auth(
async def _check_event_auth(
self,
origin: str,
event: EventBase,
context: EventContext,
auth_events: MutableStateMap[EventBase],
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
"""
Checks whether an event should be rejected (for failing auth checks).
Args:
origin:
event:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
@ -2352,12 +2362,34 @@ class FederationHandler(BaseHandler):
server.
Also NB that this function adds entries to it.
If this is not provided, it is calculated from the previous state IDs.
backfilled: True if the event was backfilled.
Returns:
updated context object
The updated context object.
"""
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
@ -2379,6 +2411,17 @@ class FederationHandler(BaseHandler):
logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _update_auth_events_and_context_for_auth(
@ -2388,7 +2431,7 @@ class FederationHandler(BaseHandler):
context: EventContext,
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.
"""Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if
@ -2468,9 +2511,14 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True
logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id
"_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context = await self.state_handler.compute_event_context(e)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
)
await self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e

View file

@ -75,9 +75,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
self.handler._check_event_auth = (
lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
)
self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
pdus