diff --git a/changelog.d/10746.misc b/changelog.d/10746.misc new file mode 100644 index 000000000..9a765435d --- /dev/null +++ b/changelog.d/10746.misc @@ -0,0 +1 @@ +Clean up some of the federation event authentication code for clarity. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index afeb2892d..69f8287b2 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -124,28 +124,28 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self.storage = hs.get_storage() - self.state_store = self.storage.state + self._store = hs.get_datastore() + self._storage = hs.get_storage() + self._state_store = self._storage.state - self.state_handler = hs.get_state_handler() - self.event_creation_handler = hs.get_event_creation_handler() + self._state_handler = hs.get_state_handler() + self._event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() self._message_handler = hs.get_message_handler() - self.action_generator = hs.get_action_generator() + self._action_generator = hs.get_action_generator() self._state_resolution_handler = hs.get_state_resolution_handler() # avoid a circular dependency by deferring execution here self._get_room_member_handler = hs.get_room_member_handler - self.federation_client = hs.get_federation_client() - self.third_party_event_rules = hs.get_third_party_event_rules() + self._federation_client = hs.get_federation_client() + self._third_party_event_rules = hs.get_third_party_event_rules() self._notifier = hs.get_notifier() - self.is_mine_id = hs.is_mine_id + self._is_mine_id = hs.is_mine_id self._server_name = hs.hostname self._instance_name = hs.get_instance_name() - self.config = hs.config + self._config = hs.config self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs) @@ -177,7 +177,7 @@ class FederationEventHandler: event_id = pdu.event_id # We reprocess pdus when we have seen them only as outliers - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) @@ -240,7 +240,7 @@ class FederationEventHandler: # - Fetching state if we have a hole in the graph if not pdu.internal_metadata.is_outlier(): prevs = set(pdu.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if missing_prevs: @@ -274,7 +274,7 @@ class FederationEventHandler: # Update the set of things we've seen after trying to # fetch the missing stuff - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -363,7 +363,7 @@ class FederationEventHandler: # the room, so we send it on their behalf. event.internal_metadata.send_on_behalf_of = origin - context = await self.state_handler.compute_event_context(event) + context = await self._state_handler.compute_event_context(event) context = await self._check_event_auth(origin, event, context) if context.rejected: raise SynapseError( @@ -377,7 +377,7 @@ class FederationEventHandler: # for knock events, we run the third-party event rules. It's not entirely clear # why we don't do this for other sorts of membership events. if event.membership == Membership.KNOCK: - event_allowed, _ = await self.third_party_event_rules.check_event_allowed( + event_allowed, _ = await self._third_party_event_rules.check_event_allowed( event, context ) if not event_allowed: @@ -406,7 +406,7 @@ class FederationEventHandler: prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None) prev_member_event = None if prev_member_event_id: - prev_member_event = await self.store.get_event(prev_member_event_id) + prev_member_event = await self._store.get_event(prev_member_event_id) # Check if the member should be allowed access via membership in a space. await self._event_auth_handler.check_restricted_join_rules( @@ -439,7 +439,7 @@ class FederationEventHandler: if dest == self._server_name: raise SynapseError(400, "Can't backfill from self.") - events = await self.federation_client.backfill( + events = await self._federation_client.backfill( dest, room_id, limit=limit, extremities=extremities ) @@ -471,12 +471,12 @@ class FederationEventHandler: room_id = pdu.room_id event_id = pdu.event_id - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) if not prevs - seen: return - latest_list = await self.store.get_latest_event_ids_in_room(room_id) + latest_list = await self._store.get_latest_event_ids_in_room(room_id) # We add the prev events that we have seen to the latest # list to ensure the remote server doesn't give them to us @@ -538,7 +538,7 @@ class FederationEventHandler: # All that said: Let's try increasing the timeout to 60s and see what happens. try: - missing_events = await self.federation_client.get_missing_events( + missing_events = await self._federation_client.get_missing_events( origin, room_id, earliest_events_ids=list(latest), @@ -611,7 +611,7 @@ class FederationEventHandler: event_id = event.event_id - existing = await self.store.get_event( + existing = await self._store.get_event( event_id, allow_none=True, allow_rejected=True ) if existing: @@ -676,7 +676,7 @@ class FederationEventHandler: event_id = event.event_id prevs = set(event.prev_event_ids()) - seen = await self.store.have_events_in_timeline(prevs) + seen = await self._store.have_events_in_timeline(prevs) missing_prevs = prevs - seen if not missing_prevs: @@ -693,7 +693,7 @@ class FederationEventHandler: event_map = {event_id: event} try: # Get the state of the events we know about - ours = await self.state_store.get_state_groups_ids(room_id, seen) + ours = await self._state_store.get_state_groups_ids(room_id, seen) # state_maps is a list of mappings from (type, state_key) to event_id state_maps: List[StateMap[str]] = list(ours.values()) @@ -722,13 +722,13 @@ class FederationEventHandler: for x in remote_state: event_map[x.event_id] = x - room_version = await self.store.get_room_version_id(room_id) + room_version = await self._store.get_room_version_id(room_id) state_map = await self._state_resolution_handler.resolve_events_with_store( room_id, room_version, state_maps, event_map, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore(self._store), ) # We need to give _process_received_pdu the actual state events @@ -736,7 +736,7 @@ class FederationEventHandler: # First though we need to fetch all the events that are in # state_map, so we can build up the state below. - evs = await self.store.get_events( + evs = await self._store.get_events( list(state_map.values()), get_prev_content=False, redact_behaviour=EventRedactBehaviour.AS_IS, @@ -776,7 +776,7 @@ class FederationEventHandler: ( state_event_ids, auth_event_ids, - ) = await self.federation_client.get_room_state_ids( + ) = await self._federation_client.get_room_state_ids( destination, room_id, event_id=event_id ) @@ -790,7 +790,7 @@ class FederationEventHandler: desired_events = set(state_event_ids) desired_events.add(event_id) logger.debug("Fetching %i events from cache/store", len(desired_events)) - fetched_events = await self.store.get_events( + fetched_events = await self._store.get_events( desired_events, allow_rejected=True ) @@ -811,7 +811,7 @@ class FederationEventHandler: missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events.difference_update( - await self.store.have_seen_events(room_id, missing_auth_events) + await self._store.have_seen_events(room_id, missing_auth_events) ) logger.debug("We are also missing %i auth events", len(missing_auth_events)) @@ -824,7 +824,7 @@ class FederationEventHandler: # we need to make sure we re-load from the database to get the rejected # state correct. fetched_events.update( - await self.store.get_events(missing_desired_events, allow_rejected=True) + await self._store.get_events(missing_desired_events, allow_rejected=True) ) # check for events which were in the wrong room. @@ -903,7 +903,7 @@ class FederationEventHandler: logger.debug("Processing event: %s", event) try: - context = await self.state_handler.compute_event_context( + context = await self._state_handler.compute_event_context( event, old_state=state ) await self._auth_and_persist_event( @@ -921,7 +921,7 @@ class FederationEventHandler: device_id = event.content.get("device_id") sender_key = event.content.get("sender_key") - cached_devices = await self.store.get_cached_devices_for_user(event.sender) + cached_devices = await self._store.get_cached_devices_for_user(event.sender) resync = False # Whether we should resync device lists. @@ -997,10 +997,10 @@ class FederationEventHandler: """ try: - await self.store.mark_remote_user_device_cache_as_stale(sender) + await self._store.mark_remote_user_device_cache_as_stale(sender) # Immediately attempt a resync in the background - if self.config.worker_app: + if self._config.worker_app: await self._user_device_resync(user_id=sender) else: await self._device_list_updater.user_device_resync(sender) @@ -1026,12 +1026,12 @@ class FederationEventHandler: # Skip processing a marker event if the room version doesn't # support it or the event is not from the room creator. - room_version = await self.store.get_room_version(marker_event.room_id) - create_event = await self.store.get_create_event_for_room(marker_event.room_id) + room_version = await self._store.get_room_version(marker_event.room_id) + create_event = await self._store.get_create_event_for_room(marker_event.room_id) room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR) if ( not room_version.msc2716_historical - or not self.config.experimental.msc2716_enabled + or not self._config.experimental.msc2716_enabled or marker_event.sender != room_creator ): return @@ -1056,7 +1056,7 @@ class FederationEventHandler: [insertion_event_id], ) - insertion_event = await self.store.get_event( + insertion_event = await self._store.get_event( insertion_event_id, allow_none=True ) if insertion_event is None: @@ -1074,7 +1074,7 @@ class FederationEventHandler: marker_event, ) - await self.store.insert_insertion_extremity( + await self._store.insert_insertion_extremity( insertion_event_id, marker_event.room_id ) @@ -1096,14 +1096,14 @@ class FederationEventHandler: Logs a warning if we can't find the given event. """ - room_version = await self.store.get_room_version(room_id) + room_version = await self._store.get_room_version(room_id) event_map: Dict[str, EventBase] = {} async def get_event(event_id: str): with nested_logging_context(event_id): try: - event = await self.federation_client.get_pdu( + event = await self._federation_client.get_pdu( [destination], event_id, room_version, @@ -1139,7 +1139,7 @@ class FederationEventHandler: for aid in event.auth_event_ids() if aid not in event_map ] - persisted_events = await self.store.get_events( + persisted_events = await self._store.get_events( auth_events, allow_rejected=True, ) @@ -1183,7 +1183,7 @@ class FederationEventHandler: async def prep(ev_info: _NewEventInfo): event = ev_info.event with nested_logging_context(suffix=event.event_id): - res = await self.state_handler.compute_event_context(event) + res = await self._state_handler.compute_event_context(event) res = await self._check_event_auth( origin, event, @@ -1286,7 +1286,7 @@ class FederationEventHandler: Returns: The updated context object. """ - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] if claimed_auth_event_map: @@ -1299,7 +1299,7 @@ class FederationEventHandler: auth_events_ids = self._event_auth_handler.compute_auth_events( event, prev_state_ids, for_verification=True ) - auth_events_x = await self.store.get_events(auth_events_ids) + auth_events_x = await self._store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()} try: @@ -1334,7 +1334,7 @@ class FederationEventHandler: # If we are going to send this event over federation we precaclculate # the joined hosts. if event.internal_metadata.get_send_on_behalf_of(): - await self.event_creation_handler.cache_joined_hosts_for_event( + await self._event_creation_handler.cache_joined_hosts_for_event( event, context ) @@ -1348,7 +1348,7 @@ class FederationEventHandler: if guest_access == GuestAccess.CAN_JOIN: return - current_state_map = await self.state_handler.get_current_state(event.room_id) + current_state_map = await self._state_handler.get_current_state(event.room_id) current_state = list(current_state_map.values()) await self._get_room_member_handler().kick_guest_users(current_state) @@ -1374,7 +1374,7 @@ class FederationEventHandler: if backfilled or event.internal_metadata.is_outlier(): return - extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id) + extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id) extrem_ids = set(extrem_ids_list) prev_event_ids = set(event.prev_event_ids()) @@ -1383,7 +1383,7 @@ class FederationEventHandler: # state at the event, so no point rechecking auth for soft fail. return - room_version = await self.store.get_room_version_id(event.room_id) + room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] # Calculate the "current state". @@ -1400,19 +1400,19 @@ class FederationEventHandler: # given state at the event. This should correctly handle cases # like bans, especially with state res v2. - state_sets_d = await self.state_store.get_state_groups( + state_sets_d = await self._state_store.get_state_groups( event.room_id, extrem_ids ) state_sets: List[Iterable[EventBase]] = list(state_sets_d.values()) state_sets.append(state) - current_states = await self.state_handler.resolve_events( + current_states = await self._state_handler.resolve_events( room_version, state_sets, event ) current_state_ids: StateMap[str] = { k: e.event_id for k, e in current_states.items() } else: - current_state_ids = await self.state_handler.get_current_state_ids( + current_state_ids = await self._state_handler.get_current_state_ids( event.room_id, latest_event_ids=extrem_ids ) @@ -1428,7 +1428,7 @@ class FederationEventHandler: e for k, e in current_state_ids.items() if k in auth_types ] - auth_events_map = await self.store.get_events(current_state_ids_list) + auth_events_map = await self._store.get_events(current_state_ids_list) current_auth_events = { (e.type, e.state_key): e for e in auth_events_map.values() } @@ -1499,7 +1499,9 @@ class FederationEventHandler: # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - have_events = await self.store.have_seen_events(event.room_id, missing_auth) + have_events = await self._store.have_seen_events( + event.room_id, missing_auth + ) logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) @@ -1508,7 +1510,7 @@ class FederationEventHandler: logger.info("auth_events contains unknown events: %s", missing_auth) try: try: - remote_auth_chain = await self.federation_client.get_event_auth( + remote_auth_chain = await self._federation_client.get_event_auth( origin, event.room_id, event.event_id ) except RequestSendFailed as e1: @@ -1517,7 +1519,7 @@ class FederationEventHandler: logger.info("Failed to get event auth from remote: %s", e1) return context, auth_events - seen_remotes = await self.store.have_seen_events( + seen_remotes = await self._store.have_seen_events( event.room_id, [e.event_id for e in remote_auth_chain] ) @@ -1543,7 +1545,7 @@ class FederationEventHandler: e.event_id, ) missing_auth_event_context = ( - await self.state_handler.compute_event_context(e) + await self._state_handler.compute_event_context(e) ) await self._auth_and_persist_event( origin, @@ -1584,7 +1586,7 @@ class FederationEventHandler: # XXX: currently this checks for redactions but I'm not convinced that is # necessary? - different_events = await self.store.get_events_as_list(different_auth) + different_events = await self._store.get_events_as_list(different_auth) for d in different_events: if d.room_id != event.room_id: @@ -1610,8 +1612,8 @@ class FederationEventHandler: remote_auth_events.update({(d.type, d.state_key): d for d in different_events}) remote_state = remote_auth_events.values() - room_version = await self.store.get_room_version_id(event.room_id) - new_state = await self.state_handler.resolve_events( + room_version = await self._store.get_room_version_id(event.room_id) + new_state = await self._state_handler.resolve_events( room_version, (local_state, remote_state), event ) @@ -1669,7 +1671,7 @@ class FederationEventHandler: # create a new state group as a delta from the existing one. prev_group = context.state_group - state_group = await self.state_store.store_state_group( + state_group = await self._state_store.store_state_group( event.event_id, event.room_id, prev_group=prev_group, @@ -1701,9 +1703,9 @@ class FederationEventHandler: not event.internal_metadata.is_outlier() and not backfilled and not context.rejected - and (await self.store.get_min_depth(event.room_id)) <= event.depth + and (await self._store.get_min_depth(event.room_id)) <= event.depth ): - await self.action_generator.handle_push_actions_for_event( + await self._action_generator.handle_push_actions_for_event( event, context ) @@ -1712,7 +1714,7 @@ class FederationEventHandler: ) except Exception: run_in_background( - self.store.remove_push_actions_from_staging, event.event_id + self._store.remove_push_actions_from_staging, event.event_id ) raise @@ -1737,27 +1739,27 @@ class FederationEventHandler: The stream ID after which all events have been persisted. """ if not event_and_contexts: - return self.store.get_current_events_token() + return self._store.get_current_events_token() - instance = self.config.worker.events_shard_config.get_instance(room_id) + instance = self._config.worker.events_shard_config.get_instance(room_id) if instance != self._instance_name: # Limit the number of events sent over replication. We choose 200 # here as that is what we default to in `max_request_body_size(..)` for batch in batch_iter(event_and_contexts, 200): result = await self._send_events( instance_name=instance, - store=self.store, + store=self._store, room_id=room_id, event_and_contexts=batch, backfilled=backfilled, ) return result["max_stream_id"] else: - assert self.storage.persistence + assert self._storage.persistence # Note that this returns the events that were persisted, which may not be # the same as were passed in if some were deduplicated due to transaction IDs. - events, max_stream_token = await self.storage.persistence.persist_events( + events, max_stream_token = await self._storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -1791,7 +1793,7 @@ class FederationEventHandler: # users if event.internal_metadata.is_outlier(): if event.membership != Membership.INVITE: - if not self.is_mine_id(target_user_id): + if not self._is_mine_id(target_user_id): return target_user = UserID.from_string(target_user_id) @@ -1840,4 +1842,4 @@ class FederationEventHandler: raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events") async def get_min_depth_for_context(self, context: str) -> int: - return await self.store.get_min_depth(context) + return await self._store.get_min_depth(context)